Skip to content

Commit 22f2187

Browse files
committed
Planbuilder Authorization
1 parent c5725b1 commit 22f2187

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+706
-209
lines changed

engine.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ type Config struct {
6363
// disabled, and including any users here will enable authentication. All users in this list will have full access.
6464
// This field is only temporary, and will be removed as development on users and authentication continues.
6565
TemporaryUsers []TemporaryUser
66+
// AuthorizationHandler sets the handler that will be used for authorization on all calls. Will use the default
67+
// handler if one is not specified.
68+
AuthorizationHandler planbuilder.AuthorizationHandler
6669
}
6770

6871
// TemporaryUser is a user that will be added to the engine. This is for temporary use while the remaining features
@@ -148,6 +151,7 @@ type Engine struct {
148151
Version sql.AnalyzerVersion
149152
EventScheduler *eventscheduler.EventScheduler
150153
Parser sql.Parser
154+
AuthHandler planbuilder.AuthorizationHandler
151155
}
152156

153157
type ColumnWithRawDefault struct {
@@ -167,6 +171,13 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine {
167171
a.Catalog.MySQLDb.AddRootAccount()
168172
}
169173

174+
var authHandler planbuilder.AuthorizationHandler
175+
if cfg.AuthorizationHandler != nil {
176+
authHandler = cfg.AuthorizationHandler
177+
} else {
178+
authHandler = planbuilder.DefaultAuthorizationHandler()
179+
}
180+
170181
ls := sql.NewLockSubsystem()
171182

172183
variables.InitStatusVariables()
@@ -193,6 +204,7 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine {
193204
mu: &sync.Mutex{},
194205
EventScheduler: nil,
195206
Parser: sql.GlobalParser,
207+
AuthHandler: authHandler,
196208
}
197209
ret.ReadOnly.Store(cfg.IsReadOnly)
198210
return ret
@@ -209,7 +221,7 @@ func (e *Engine) AnalyzeQuery(
209221
ctx *sql.Context,
210222
query string,
211223
) (sql.Node, error) {
212-
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
224+
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser, e.AuthHandler)
213225
parsed, _, _, qFlags, err := binder.Parse(query, nil, false)
214226
if err != nil {
215227
return nil, err
@@ -237,7 +249,7 @@ func (e *Engine) PrepareParsedQuery(
237249
statementKey, query string,
238250
stmt sqlparser.Statement,
239251
) (sql.Node, error) {
240-
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
252+
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser, e.AuthHandler)
241253
node, _, err := binder.BindOnly(stmt, query, nil)
242254

243255
if err != nil {
@@ -495,7 +507,7 @@ func (e *Engine) BoundQueryPlan(ctx *sql.Context, query string, parsed sqlparser
495507

496508
query = sql.RemoveSpaceAndDelimiter(query, ';')
497509

498-
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
510+
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser, e.AuthHandler)
499511
binder.SetBindings(bindings)
500512

501513
// Begin a transaction if necessary (no-op if one is in flight)
@@ -549,7 +561,7 @@ func (e *Engine) preparedStatement(ctx *sql.Context, query string, parsed sqlpar
549561
preparedAst, preparedDataFound = e.PreparedDataCache.GetCachedStmt(ctx.Session.ID(), query)
550562
}
551563

552-
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
564+
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser, e.AuthHandler)
553565
if preparedDataFound {
554566
parsed = preparedAst
555567
binder.SetBindings(bindings)
@@ -782,6 +794,10 @@ func (e *Engine) EngineAnalyzer() *analyzer.Analyzer {
782794
return e.Analyzer
783795
}
784796

797+
func (e *Engine) AuthorizationHandler() planbuilder.AuthorizationHandler {
798+
return e.AuthHandler
799+
}
800+
785801
// InitializeEventScheduler initializes the EventScheduler for the engine with the given sql.Context
786802
// getter function, |ctxGetterFunc, the EventScheduler |status|, and the |period| for the event scheduler
787803
// to check for events to execute. If |period| is less than 1, then it is ignored and the default period

enginetest/engine_only_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ func TestAnalyzer_Exp(t *testing.T) {
510510
require.NoError(t, err)
511511

512512
ctx := enginetest.NewContext(harness)
513-
b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser())
513+
b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser(), e.AuthorizationHandler())
514514
parsed, _, _, _, err := b.Parse(tt.query, nil, false)
515515
require.NoError(t, err)
516516

enginetest/enginetests.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ func TestQueryPlanWithName(t *testing.T, name string, harness Harness, e QueryEn
584584
t.Run(name, func(t *testing.T) {
585585
ctx := NewContext(harness)
586586

587-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, query)
587+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), query)
588588
require.NoError(t, err)
589589

590590
node, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
@@ -611,7 +611,7 @@ func TestQueryPlanWithName(t *testing.T, name string, harness Harness, e QueryEn
611611
func TestQueryPlanWithEngine(t *testing.T, harness Harness, e QueryEngine, tt queries.QueryPlanTest, verbose bool) {
612612
t.Run(tt.Query, func(t *testing.T) {
613613
ctx := NewContext(harness)
614-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, tt.Query)
614+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), tt.Query)
615615
require.NoError(t, err)
616616

617617
node, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
@@ -1360,7 +1360,7 @@ func TestTruncate(t *testing.T, harness Harness) {
13601360
TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t5 ORDER BY 1", []sql.Row{{int64(1), int64(1)}, {int64(2), int64(2)}}, nil, nil, nil)
13611361

13621362
deleteStr := "DELETE FROM t5"
1363-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr)
1363+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), deleteStr)
13641364
require.NoError(t, err)
13651365
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
13661366
require.NoError(t, err)
@@ -1389,7 +1389,7 @@ func TestTruncate(t *testing.T, harness Harness) {
13891389
RunQueryWithContext(t, e, harness, ctx, "INSERT INTO t6parent VALUES (1,1), (2,2)")
13901390
RunQueryWithContext(t, e, harness, ctx, "INSERT INTO t6child VALUES (1,1), (2,2)")
13911391

1392-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, "DELETE FROM t6parent")
1392+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), "DELETE FROM t6parent")
13931393
require.NoError(t, err)
13941394
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
13951395
require.NoError(t, err)
@@ -1417,7 +1417,7 @@ func TestTruncate(t *testing.T, harness Harness) {
14171417
TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t7i ORDER BY 1", []sql.Row{{int64(3), int64(3)}}, nil, nil, nil)
14181418

14191419
deleteStr := "DELETE FROM t7"
1420-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr)
1420+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), deleteStr)
14211421
require.NoError(t, err)
14221422
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
14231423
require.NoError(t, err)
@@ -1445,7 +1445,7 @@ func TestTruncate(t *testing.T, harness Harness) {
14451445
TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t8 ORDER BY 1", []sql.Row{{int64(1), int64(4)}, {int64(2), int64(5)}}, nil, nil, nil)
14461446

14471447
deleteStr := "DELETE FROM t8"
1448-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr)
1448+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), deleteStr)
14491449
require.NoError(t, err)
14501450
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
14511451
require.NoError(t, err)
@@ -1474,7 +1474,7 @@ func TestTruncate(t *testing.T, harness Harness) {
14741474
TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t9 ORDER BY 1", []sql.Row{{int64(7), int64(7)}, {int64(8), int64(8)}}, nil, nil, nil)
14751475

14761476
deleteStr := "DELETE FROM t9 WHERE pk > 0"
1477-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr)
1477+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), deleteStr)
14781478
require.NoError(t, err)
14791479
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
14801480
require.NoError(t, err)
@@ -1501,7 +1501,7 @@ func TestTruncate(t *testing.T, harness Harness) {
15011501
TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t10 ORDER BY 1", []sql.Row{{int64(8), int64(8)}, {int64(9), int64(9)}}, nil, nil, nil)
15021502

15031503
deleteStr := "DELETE FROM t10 LIMIT 1000"
1504-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr)
1504+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), deleteStr)
15051505
require.NoError(t, err)
15061506
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
15071507
require.NoError(t, err)
@@ -1528,7 +1528,7 @@ func TestTruncate(t *testing.T, harness Harness) {
15281528
TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t11 ORDER BY 1", []sql.Row{{int64(1), int64(1)}, {int64(9), int64(9)}}, nil, nil, nil)
15291529

15301530
deleteStr := "DELETE FROM t11 ORDER BY 1"
1531-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr)
1531+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), deleteStr)
15321532
require.NoError(t, err)
15331533
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
15341534
require.NoError(t, err)
@@ -1559,7 +1559,7 @@ func TestTruncate(t *testing.T, harness Harness) {
15591559
TestQueryWithContext(t, ctx, e, harness, "SELECT * FROM t12b ORDER BY 1", []sql.Row{{int64(1), int64(1)}, {int64(2), int64(2)}}, nil, nil, nil)
15601560

15611561
deleteStr := "DELETE t12a, t12b FROM t12a INNER JOIN t12b WHERE t12a.pk=t12b.pk"
1562-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, deleteStr)
1562+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), deleteStr)
15631563
require.NoError(t, err)
15641564
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, qFlags)
15651565
require.NoError(t, err)

enginetest/evaluation.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ func injectBindVarsAndPrepare(
526526
}
527527
}
528528

529-
b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser())
529+
b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser(), e.AuthorizationHandler())
530530
b.SetParserOptions(sql.LoadSqlMode(ctx).ParserOptions())
531531
resPlan, _, err := b.BindOnly(parsed, q, nil)
532532
if err != nil {

enginetest/harness.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup"
2424
"github.com/dolthub/go-mysql-server/server"
2525
"github.com/dolthub/go-mysql-server/sql"
26+
"github.com/dolthub/go-mysql-server/sql/planbuilder"
2627
)
2728

2829
// Harness provides a way for database integrators to validate their implementation against the standard set of queries
@@ -173,3 +174,11 @@ type ResultEvaluationHarness interface {
173174
// EvaluateExpectedErrorKind compares expected error kinds to actual errors and emits failed test assertions in the
174175
EvaluateExpectedErrorKind(t *testing.T, expected *errors.Kind, err error)
175176
}
177+
178+
// AuthorizingHarness specifies the AuthorizationHandler that should be used.
179+
type AuthorizingHarness interface {
180+
Harness
181+
182+
// AuthorizationHandler returns the handler to use.
183+
AuthorizationHandler() planbuilder.AuthorizationHandler
184+
}

enginetest/initialization.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ func NewEngineWithProvider(_ *testing.T, harness Harness, provider sql.DatabaseP
103103
if idh, ok := harness.(IndexDriverHarness); ok {
104104
idh.InitializeIndexDriver(engine.Analyzer.Catalog.AllDatabases(NewContext(harness)))
105105
}
106+
if ah, ok := harness.(AuthorizingHarness); ok {
107+
engine.AuthHandler = ah.AuthorizationHandler()
108+
}
106109

107110
return engine
108111
}

enginetest/join_planning_tests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1777,7 +1777,7 @@ func evalJoinTypeTest(t *testing.T, harness Harness, e QueryEngine, query string
17771777
}
17781778

17791779
func analyzeQuery(ctx *sql.Context, e QueryEngine, query string) (sql.Node, error) {
1780-
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, query)
1780+
parsed, qFlags, err := planbuilder.Parse(ctx, e.EngineAnalyzer().Catalog, e.AuthorizationHandler(), query)
17811781
if err != nil {
17821782
return nil, err
17831783
}

enginetest/mysqlshim/table.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ func (t Table) getCreateTable() (*plan.CreateTable, error) {
381381
return nil, sql.ErrTableNotFound.New(t.name)
382382
}
383383
// TODO add catalog
384-
createTableNode, _, err := planbuilder.Parse(sql.NewEmptyContext(), sql.MapCatalog{Tables: map[string]sql.Table{t.name: t}}, rows[0][1].(string))
384+
createTableNode, _, err := planbuilder.Parse(sql.NewEmptyContext(), sql.MapCatalog{Tables: map[string]sql.Table{t.name: t}}, nil, rows[0][1].(string))
385385
if err != nil {
386386
return nil, err
387387
}

enginetest/plangen/cmd/plangen/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func generatePlansForSuite(spec PlanSpec, w *bytes.Buffer) error {
165165

166166
if !tt.Skip {
167167
ctx := enginetest.NewContextWithEngine(harness, engine)
168-
binder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, sql.NewMysqlParser())
168+
binder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, sql.NewMysqlParser(), engine.AuthorizationHandler())
169169
parsed, _, _, qFlags, err := binder.Parse(tt.Query, nil, false)
170170
if err != nil {
171171
exit(fmt.Errorf("%w\nfailed to parse query: %s", err, tt.Query))

enginetest/queries/priv_auth_queries.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1780,10 +1780,10 @@ var UserPrivTests = []UserPrivilegeTest{
17801780
},
17811781
},
17821782
{
1783-
User: "rand_user1",
1784-
Host: "54.244.85.252",
1785-
Query: "SELECT * FROM mydb.test;",
1786-
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
1783+
User: "rand_user1",
1784+
Host: "54.244.85.252",
1785+
Query: "SELECT * FROM mydb.test;",
1786+
ExpectedErrStr: "Access denied for user 'rand_user1' (errno 1045) (sqlstate 28000)",
17871787
},
17881788
{
17891789
User: "rand_user2",
@@ -1804,10 +1804,10 @@ var UserPrivTests = []UserPrivilegeTest{
18041804
},
18051805
},
18061806
{
1807-
User: "rand_user2",
1808-
Host: "54.244.85.252",
1809-
Query: "SELECT * FROM mydb.test2;",
1810-
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
1807+
User: "rand_user2",
1808+
Host: "54.244.85.252",
1809+
Query: "SELECT * FROM mydb.test2;",
1810+
ExpectedErrStr: "Access denied for user 'rand_user2' (errno 1045) (sqlstate 28000)",
18111811
},
18121812
},
18131813
},

0 commit comments

Comments
 (0)