Skip to content

Commit 3e2aa8c

Browse files
unsafesql: convert tests to table tests
This is just a bit of cleanup in the unsafesql test, many of the existing tests share a similar structure, so I created a table test layout so that they work better. Fixes: none Epic: none Release note: none
1 parent 6228e69 commit 3e2aa8c

File tree

1 file changed

+126
-132
lines changed

1 file changed

+126
-132
lines changed

pkg/sql/unsafesql/unsafesql_test.go

Lines changed: 126 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ package unsafesql_test
77

88
import (
99
"context"
10-
gosql "database/sql"
1110
"fmt"
1211
"os"
1312
"strings"
@@ -105,140 +104,135 @@ func TestAccessCheckServer(t *testing.T) {
105104
_, err := pool.Exec("CREATE TABLE foo (id INT PRIMARY KEY)")
106105
require.NoError(t, err)
107106

108-
// helper func for setting a safe connection.
109-
safeConn := func(t *testing.T) *gosql.Conn {
110-
conn, err := pool.Conn(ctx)
111-
require.NoError(t, err)
112-
_, err = conn.ExecContext(ctx, "SET allow_unsafe_internals = false")
113-
require.NoError(t, err)
114-
return conn
115-
}
116-
117-
t.Run("regular user activity is unaffected", func(t *testing.T) {
118-
conn := safeConn(t)
119-
_, err := conn.QueryContext(ctx, "SELECT * FROM foo")
120-
require.NoError(t, err)
121-
})
107+
sendQuery := func(allowUnsafe bool, internal bool, query string) error {
108+
if internal {
109+
idb := s.InternalDB().(isql.DB)
110+
err := idb.Txn(ctx, func(ctx context.Context, txn isql.Txn) error {
111+
txn.SessionData().LocalOnlySessionData.AllowUnsafeInternals = allowUnsafe
122112

123-
t.Run("accessing the system database", func(t *testing.T) {
124-
q := "SELECT * FROM system.namespace"
125-
126-
t.Run("as an external querier", func(t *testing.T) {
127-
for _, test := range []struct {
128-
AllowUnsafeInternals bool
129-
Passes bool
130-
}{
131-
{AllowUnsafeInternals: false, Passes: false},
132-
{AllowUnsafeInternals: true, Passes: true},
133-
} {
134-
t.Run(fmt.Sprintf("%t", test), func(t *testing.T) {
135-
conn := s.SQLConn(t)
136-
defer conn.Close()
137-
_, err := conn.ExecContext(ctx, "SET allow_unsafe_internals = $1", test.AllowUnsafeInternals)
138-
require.NoError(t, err)
139-
140-
_, err = conn.QueryContext(ctx, q)
141-
if test.Passes {
142-
require.NoError(t, err)
143-
} else {
144-
checkUnsafeErr(t, err)
145-
}
146-
})
113+
_, err := txn.QueryBuffered(ctx, "internal-query", txn.KV(), query)
114+
return err
115+
})
116+
if err != nil {
117+
return err
147118
}
148-
})
149-
150-
t.Run("as an internal querier", func(t *testing.T) {
151-
for _, test := range []struct {
152-
AllowUnsafeInternals bool
153-
Passes bool
154-
}{
155-
{AllowUnsafeInternals: false, Passes: true},
156-
{AllowUnsafeInternals: true, Passes: true},
157-
} {
158-
t.Run(fmt.Sprintf("%t", test), func(t *testing.T) {
159-
idb := s.InternalDB().(isql.DB)
160-
err := idb.Txn(ctx, func(ctx context.Context, txn isql.Txn) error {
161-
txn.SessionData().LocalOnlySessionData.AllowUnsafeInternals = test.AllowUnsafeInternals
162-
163-
_, err := txn.QueryBuffered(ctx, "internal-query", txn.KV(), q)
164-
return err
165-
})
166-
167-
require.NoError(t, err)
168-
169-
if test.Passes {
170-
require.NoError(t, err)
171-
} else {
172-
checkUnsafeErr(t, err)
173-
}
174-
})
119+
} else {
120+
conn, err := pool.Conn(ctx)
121+
if err != nil {
122+
return err
175123
}
176-
})
177-
})
178-
179-
t.Run("accessing the crdb_internal schema", func(t *testing.T) {
180-
t.Run("supported table allowed", func(t *testing.T) {
181-
conn := safeConn(t)
182-
183-
// Supported crdb_internal tables should be allowed even when allow_unsafe_internals = false
184-
_, err := conn.QueryContext(ctx, "SELECT * FROM crdb_internal.zones")
185-
require.NoError(t, err, "supported crdb_internal table (zones) should be accessible when allow_unsafe_internals = false")
186-
})
187-
188-
t.Run("unsupported table denied", func(t *testing.T) {
189-
conn := safeConn(t)
190-
191-
// Unsupported crdb_internal tables should be denied when allow_unsafe_internals = false
192-
_, err := conn.QueryContext(ctx, "SELECT * FROM crdb_internal.gossip_alerts")
193-
checkUnsafeErr(t, err)
194-
})
195-
})
196-
197-
// The functionality for this lies in the optbuilder package file,
198-
// but it is tested here as that package does not setup a test server.
199-
t.Run("accessing crdb_internal builtins", func(t *testing.T) {
200-
t.Run("non crdb_internal builtin allowed", func(t *testing.T) {
201-
conn := safeConn(t)
202-
203-
// Non crdb_internal tables should be allowed.
204-
_, err := conn.QueryContext(ctx, "SELECT * FROM generate_series(1,5)")
205-
require.NoError(t, err)
206-
})
207-
208-
t.Run("crdb_internal builtin not allowed", func(t *testing.T) {
209-
conn := safeConn(t)
210-
211-
// Unsupported crdb_internal builtins should be denied.
212-
_, err := conn.QueryContext(ctx, "SELECT * FROM crdb_internal.tenant_span_stats()")
213-
checkUnsafeErr(t, err)
214-
})
215-
})
216-
217-
// The functionality for this check also lives in the optbuilder package
218-
// but is tested here.
219-
t.Run("skips delegation", func(t *testing.T) {
220-
t.Run("delegation is allowed", func(t *testing.T) {
221-
conn := safeConn(t)
222-
223-
// tests delegation to builtins
224-
_, err := conn.ExecContext(ctx, "show grants")
225-
require.NoError(t, err)
226-
227-
// tests delegation to crdb_internal tables
228-
_, err = conn.ExecContext(ctx, "show databases")
229-
require.NoError(t, err)
230-
})
231-
232-
t.Run("underlying tables which delegates rely on are not", func(t *testing.T) {
233-
conn := safeConn(t)
234-
235-
// tests delegation to builtins
236-
_, err := conn.ExecContext(ctx, "SELECT * FROM crdb_internal.privilege_name('DELETE')")
237-
checkUnsafeErr(t, err)
124+
_, err = conn.ExecContext(ctx, "SET allow_unsafe_internals = $1", allowUnsafe)
125+
if err != nil {
126+
return err
127+
}
128+
_, err = conn.QueryContext(ctx, query)
129+
if err != nil {
130+
return err
131+
}
132+
}
133+
return nil
134+
}
238135

239-
// tests delegation to crdb_internal tables
240-
_, err = conn.ExecContext(ctx, "SELECT * FROM crdb_internal.databases")
241-
checkUnsafeErr(t, err)
136+
for _, test := range []struct {
137+
Query string
138+
Internal bool
139+
AllowUnsafeInternals bool
140+
Passes bool
141+
}{
142+
// Regular tables aren't considered unsafe.
143+
{
144+
Query: "SELECT * FROM foo",
145+
Passes: true,
146+
},
147+
// Tests on the system objects.
148+
{
149+
Query: "SELECT * FROM system.namespace",
150+
Internal: true,
151+
AllowUnsafeInternals: false,
152+
Passes: true,
153+
},
154+
{
155+
Query: "SELECT * FROM system.namespace",
156+
Internal: true,
157+
AllowUnsafeInternals: true,
158+
Passes: true,
159+
},
160+
{
161+
Query: "SELECT * FROM system.namespace",
162+
Internal: false,
163+
AllowUnsafeInternals: false,
164+
Passes: false,
165+
},
166+
{
167+
Query: "SELECT * FROM system.namespace",
168+
Internal: false,
169+
AllowUnsafeInternals: true,
170+
Passes: true,
171+
},
172+
// Tests on unsupported crdb_internal objects.
173+
{
174+
Query: "SELECT * FROM crdb_internal.gossip_alerts",
175+
AllowUnsafeInternals: false,
176+
Passes: false,
177+
},
178+
{
179+
Query: "SELECT * FROM crdb_internal.gossip_alerts",
180+
AllowUnsafeInternals: true,
181+
Passes: true,
182+
},
183+
// Tests on supported crdb_internal objects.
184+
{
185+
Query: "SELECT * FROM crdb_internal.zones",
186+
AllowUnsafeInternals: false,
187+
Passes: true,
188+
},
189+
{
190+
Query: "SELECT * FROM crdb_internal.zones",
191+
AllowUnsafeInternals: true,
192+
Passes: true,
193+
},
194+
// Non-crdb_internal functions pass
195+
{
196+
Query: "SELECT * FROM generate_series(1, 5)",
197+
Passes: true,
198+
},
199+
// Crdb_internal functions require the override.
200+
{
201+
Query: "SELECT * FROM crdb_internal.tenant_span_stats()",
202+
Passes: false,
203+
},
204+
{
205+
Query: "SELECT * FROM crdb_internal.tenant_span_stats()",
206+
AllowUnsafeInternals: true,
207+
Passes: true,
208+
},
209+
// Tests on delegate behavior.
210+
{
211+
Query: "SHOW GRANTS",
212+
Passes: true,
213+
},
214+
{
215+
// this query is what show grants is using under the hood.
216+
Query: "SELECT * FROM crdb_internal.privilege_name('SELECT')",
217+
Passes: false,
218+
},
219+
{
220+
Query: "SHOW DATABASES",
221+
Passes: true,
222+
},
223+
{
224+
// this query is what show databases is using under the hood.
225+
Query: "SELECT * FROM crdb_internal.databases",
226+
Passes: false,
227+
},
228+
} {
229+
t.Run(fmt.Sprintf("query=%s,internal=%t,allowUnsafe=%t", test.Query, test.Internal, test.AllowUnsafeInternals), func(t *testing.T) {
230+
err := sendQuery(test.AllowUnsafeInternals, test.Internal, test.Query)
231+
if test.Passes {
232+
require.NoError(t, err)
233+
} else {
234+
checkUnsafeErr(t, err)
235+
}
242236
})
243-
})
237+
}
244238
}

0 commit comments

Comments
 (0)