@@ -7,7 +7,6 @@ package unsafesql_test
7
7
8
8
import (
9
9
"context"
10
- gosql "database/sql"
11
10
"fmt"
12
11
"os"
13
12
"strings"
@@ -105,140 +104,135 @@ func TestAccessCheckServer(t *testing.T) {
105
104
_ , err := pool .Exec ("CREATE TABLE foo (id INT PRIMARY KEY)" )
106
105
require .NoError (t , err )
107
106
108
- // helper func for setting a safe connection.
109
- safeConn := func (t * testing.T ) * gosql.Conn {
110
- conn , err := pool .Conn (ctx )
111
- require .NoError (t , err )
112
- _ , err = conn .ExecContext (ctx , "SET allow_unsafe_internals = false" )
113
- require .NoError (t , err )
114
- return conn
115
- }
116
-
117
- t .Run ("regular user activity is unaffected" , func (t * testing.T ) {
118
- conn := safeConn (t )
119
- _ , err := conn .QueryContext (ctx , "SELECT * FROM foo" )
120
- require .NoError (t , err )
121
- })
107
+ sendQuery := func (allowUnsafe bool , internal bool , query string ) error {
108
+ if internal {
109
+ idb := s .InternalDB ().(isql.DB )
110
+ err := idb .Txn (ctx , func (ctx context.Context , txn isql.Txn ) error {
111
+ txn .SessionData ().LocalOnlySessionData .AllowUnsafeInternals = allowUnsafe
122
112
123
- t .Run ("accessing the system database" , func (t * testing.T ) {
124
- q := "SELECT * FROM system.namespace"
125
-
126
- t .Run ("as an external querier" , func (t * testing.T ) {
127
- for _ , test := range []struct {
128
- AllowUnsafeInternals bool
129
- Passes bool
130
- }{
131
- {AllowUnsafeInternals : false , Passes : false },
132
- {AllowUnsafeInternals : true , Passes : true },
133
- } {
134
- t .Run (fmt .Sprintf ("%t" , test ), func (t * testing.T ) {
135
- conn := s .SQLConn (t )
136
- defer conn .Close ()
137
- _ , err := conn .ExecContext (ctx , "SET allow_unsafe_internals = $1" , test .AllowUnsafeInternals )
138
- require .NoError (t , err )
139
-
140
- _ , err = conn .QueryContext (ctx , q )
141
- if test .Passes {
142
- require .NoError (t , err )
143
- } else {
144
- checkUnsafeErr (t , err )
145
- }
146
- })
113
+ _ , err := txn .QueryBuffered (ctx , "internal-query" , txn .KV (), query )
114
+ return err
115
+ })
116
+ if err != nil {
117
+ return err
147
118
}
148
- })
149
-
150
- t .Run ("as an internal querier" , func (t * testing.T ) {
151
- for _ , test := range []struct {
152
- AllowUnsafeInternals bool
153
- Passes bool
154
- }{
155
- {AllowUnsafeInternals : false , Passes : true },
156
- {AllowUnsafeInternals : true , Passes : true },
157
- } {
158
- t .Run (fmt .Sprintf ("%t" , test ), func (t * testing.T ) {
159
- idb := s .InternalDB ().(isql.DB )
160
- err := idb .Txn (ctx , func (ctx context.Context , txn isql.Txn ) error {
161
- txn .SessionData ().LocalOnlySessionData .AllowUnsafeInternals = test .AllowUnsafeInternals
162
-
163
- _ , err := txn .QueryBuffered (ctx , "internal-query" , txn .KV (), q )
164
- return err
165
- })
166
-
167
- require .NoError (t , err )
168
-
169
- if test .Passes {
170
- require .NoError (t , err )
171
- } else {
172
- checkUnsafeErr (t , err )
173
- }
174
- })
119
+ } else {
120
+ conn , err := pool .Conn (ctx )
121
+ if err != nil {
122
+ return err
175
123
}
176
- })
177
- })
178
-
179
- t .Run ("accessing the crdb_internal schema" , func (t * testing.T ) {
180
- t .Run ("supported table allowed" , func (t * testing.T ) {
181
- conn := safeConn (t )
182
-
183
- // Supported crdb_internal tables should be allowed even when allow_unsafe_internals = false
184
- _ , err := conn .QueryContext (ctx , "SELECT * FROM crdb_internal.zones" )
185
- require .NoError (t , err , "supported crdb_internal table (zones) should be accessible when allow_unsafe_internals = false" )
186
- })
187
-
188
- t .Run ("unsupported table denied" , func (t * testing.T ) {
189
- conn := safeConn (t )
190
-
191
- // Unsupported crdb_internal tables should be denied when allow_unsafe_internals = false
192
- _ , err := conn .QueryContext (ctx , "SELECT * FROM crdb_internal.gossip_alerts" )
193
- checkUnsafeErr (t , err )
194
- })
195
- })
196
-
197
- // The functionality for this lies in the optbuilder package file,
198
- // but it is tested here as that package does not setup a test server.
199
- t .Run ("accessing crdb_internal builtins" , func (t * testing.T ) {
200
- t .Run ("non crdb_internal builtin allowed" , func (t * testing.T ) {
201
- conn := safeConn (t )
202
-
203
- // Non crdb_internal tables should be allowed.
204
- _ , err := conn .QueryContext (ctx , "SELECT * FROM generate_series(1,5)" )
205
- require .NoError (t , err )
206
- })
207
-
208
- t .Run ("crdb_internal builtin not allowed" , func (t * testing.T ) {
209
- conn := safeConn (t )
210
-
211
- // Unsupported crdb_internal builtins should be denied.
212
- _ , err := conn .QueryContext (ctx , "SELECT * FROM crdb_internal.tenant_span_stats()" )
213
- checkUnsafeErr (t , err )
214
- })
215
- })
216
-
217
- // The functionality for this check also lives in the optbuilder package
218
- // but is tested here.
219
- t .Run ("skips delegation" , func (t * testing.T ) {
220
- t .Run ("delegation is allowed" , func (t * testing.T ) {
221
- conn := safeConn (t )
222
-
223
- // tests delegation to builtins
224
- _ , err := conn .ExecContext (ctx , "show grants" )
225
- require .NoError (t , err )
226
-
227
- // tests delegation to crdb_internal tables
228
- _ , err = conn .ExecContext (ctx , "show databases" )
229
- require .NoError (t , err )
230
- })
231
-
232
- t .Run ("underlying tables which delegates rely on are not" , func (t * testing.T ) {
233
- conn := safeConn (t )
234
-
235
- // tests delegation to builtins
236
- _ , err := conn .ExecContext (ctx , "SELECT * FROM crdb_internal.privilege_name('DELETE')" )
237
- checkUnsafeErr (t , err )
124
+ _ , err = conn .ExecContext (ctx , "SET allow_unsafe_internals = $1" , allowUnsafe )
125
+ if err != nil {
126
+ return err
127
+ }
128
+ _ , err = conn .QueryContext (ctx , query )
129
+ if err != nil {
130
+ return err
131
+ }
132
+ }
133
+ return nil
134
+ }
238
135
239
- // tests delegation to crdb_internal tables
240
- _ , err = conn .ExecContext (ctx , "SELECT * FROM crdb_internal.databases" )
241
- checkUnsafeErr (t , err )
136
+ for _ , test := range []struct {
137
+ Query string
138
+ Internal bool
139
+ AllowUnsafeInternals bool
140
+ Passes bool
141
+ }{
142
+ // Regular tables aren't considered unsafe.
143
+ {
144
+ Query : "SELECT * FROM foo" ,
145
+ Passes : true ,
146
+ },
147
+ // Tests on the system objects.
148
+ {
149
+ Query : "SELECT * FROM system.namespace" ,
150
+ Internal : true ,
151
+ AllowUnsafeInternals : false ,
152
+ Passes : true ,
153
+ },
154
+ {
155
+ Query : "SELECT * FROM system.namespace" ,
156
+ Internal : true ,
157
+ AllowUnsafeInternals : true ,
158
+ Passes : true ,
159
+ },
160
+ {
161
+ Query : "SELECT * FROM system.namespace" ,
162
+ Internal : false ,
163
+ AllowUnsafeInternals : false ,
164
+ Passes : false ,
165
+ },
166
+ {
167
+ Query : "SELECT * FROM system.namespace" ,
168
+ Internal : false ,
169
+ AllowUnsafeInternals : true ,
170
+ Passes : true ,
171
+ },
172
+ // Tests on unsupported crdb_internal objects.
173
+ {
174
+ Query : "SELECT * FROM crdb_internal.gossip_alerts" ,
175
+ AllowUnsafeInternals : false ,
176
+ Passes : false ,
177
+ },
178
+ {
179
+ Query : "SELECT * FROM crdb_internal.gossip_alerts" ,
180
+ AllowUnsafeInternals : true ,
181
+ Passes : true ,
182
+ },
183
+ // Tests on supported crdb_internal objects.
184
+ {
185
+ Query : "SELECT * FROM crdb_internal.zones" ,
186
+ AllowUnsafeInternals : false ,
187
+ Passes : true ,
188
+ },
189
+ {
190
+ Query : "SELECT * FROM crdb_internal.zones" ,
191
+ AllowUnsafeInternals : true ,
192
+ Passes : true ,
193
+ },
194
+ // Non-crdb_internal functions pass
195
+ {
196
+ Query : "SELECT * FROM generate_series(1, 5)" ,
197
+ Passes : true ,
198
+ },
199
+ // Crdb_internal functions require the override.
200
+ {
201
+ Query : "SELECT * FROM crdb_internal.tenant_span_stats()" ,
202
+ Passes : false ,
203
+ },
204
+ {
205
+ Query : "SELECT * FROM crdb_internal.tenant_span_stats()" ,
206
+ AllowUnsafeInternals : true ,
207
+ Passes : true ,
208
+ },
209
+ // Tests on delegate behavior.
210
+ {
211
+ Query : "SHOW GRANTS" ,
212
+ Passes : true ,
213
+ },
214
+ {
215
+ // this query is what show grants is using under the hood.
216
+ Query : "SELECT * FROM crdb_internal.privilege_name('SELECT')" ,
217
+ Passes : false ,
218
+ },
219
+ {
220
+ Query : "SHOW DATABASES" ,
221
+ Passes : true ,
222
+ },
223
+ {
224
+ // this query is what show databases is using under the hood.
225
+ Query : "SELECT * FROM crdb_internal.databases" ,
226
+ Passes : false ,
227
+ },
228
+ } {
229
+ t .Run (fmt .Sprintf ("query=%s,internal=%t,allowUnsafe=%t" , test .Query , test .Internal , test .AllowUnsafeInternals ), func (t * testing.T ) {
230
+ err := sendQuery (test .AllowUnsafeInternals , test .Internal , test .Query )
231
+ if test .Passes {
232
+ require .NoError (t , err )
233
+ } else {
234
+ checkUnsafeErr (t , err )
235
+ }
242
236
})
243
- })
237
+ }
244
238
}
0 commit comments