diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index 5727e5677c..e8515616de 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -866,10 +866,6 @@ func (s SimpleTableFunction) WithChildren(_ ...sql.Node) (sql.Node, error) { return s, nil } -func (s SimpleTableFunction) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (SimpleTableFunction) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 06a9bb1915..b2453710d9 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -1028,6 +1028,9 @@ func newMergableIndex(dbs []sql.Database, tableName string, exprs ...sql.Express if db == nil { return nil } + if tableRevision, ok := table.(*memory.TableRevision); ok { + table = tableRevision.Table + } return &memory.Index{ DB: db.Name(), DriverName: memory.IndexDriverId, diff --git a/enginetest/queries/priv_auth_queries.go b/enginetest/queries/priv_auth_queries.go index 8b5ef3e694..3925889f65 100644 --- a/enginetest/queries/priv_auth_queries.go +++ b/enginetest/queries/priv_auth_queries.go @@ -1780,10 +1780,10 @@ var UserPrivTests = []UserPrivilegeTest{ }, }, { - User: "rand_user1", - Host: "54.244.85.252", - Query: "SELECT * FROM mydb.test;", - ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + User: "rand_user1", + Host: "54.244.85.252", + Query: "SELECT * FROM mydb.test;", + ExpectedErrStr: "Access denied for user 'rand_user1' (errno 1045) (sqlstate 28000)", }, { User: "rand_user2", @@ -1804,10 +1804,10 @@ var UserPrivTests = []UserPrivilegeTest{ }, }, { - User: "rand_user2", - Host: "54.244.85.252", - Query: "SELECT * FROM mydb.test2;", - ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + User: "rand_user2", + Host: "54.244.85.252", + Query: "SELECT * FROM mydb.test2;", + ExpectedErrStr: "Access denied for user 'rand_user2' (errno 1045) (sqlstate 28000)", }, }, }, diff --git a/go.mod b/go.mod index 4d1f10112b..1183344539 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20241028204000-267861bc75a0 + github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 56fd29ccd5..c328f81c8c 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20241028204000-267861bc75a0 h1:eeKypNsi1nQmjWxSAAWT6tvRsDWdmll03BozAUUIE4E= -github.com/dolthub/vitess v0.0.0-20241028204000-267861bc75a0/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683 h1:2/RJeUfNAXS7mbBnEr9C36htiCJHk5XldDPzhxtEsME= +github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/memory/exponential_dist_table.go b/memory/exponential_dist_table.go index 0369932967..29af52070f 100644 --- a/memory/exponential_dist_table.go +++ b/memory/exponential_dist_table.go @@ -108,10 +108,6 @@ func (s ExponentialDistTable) WithChildren(_ ...sql.Node) (sql.Node, error) { return s, nil } -func (s ExponentialDistTable) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (ExponentialDistTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 diff --git a/memory/normal_dist_table.go b/memory/normal_dist_table.go index 21d4a8e79d..79e558a5b5 100644 --- a/memory/normal_dist_table.go +++ b/memory/normal_dist_table.go @@ -119,10 +119,6 @@ func (s NormalDistTable) WithChildren(_ ...sql.Node) (sql.Node, error) { return s, nil } -func (s NormalDistTable) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (NormalDistTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 diff --git a/memory/sequence_table.go b/memory/sequence_table.go index 2b0ca64125..e92fa77008 100644 --- a/memory/sequence_table.go +++ b/memory/sequence_table.go @@ -101,10 +101,6 @@ func (s IntSequenceTable) WithChildren(_ ...sql.Node) (sql.Node, error) { return s, nil } -func (s IntSequenceTable) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (IntSequenceTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 diff --git a/memory/table_function.go b/memory/table_function.go index e52866571c..060e22508c 100644 --- a/memory/table_function.go +++ b/memory/table_function.go @@ -91,10 +91,6 @@ func (s TableFunc) WithChildren(_ ...sql.Node) (sql.Node, error) { return s, nil } -func (s TableFunc) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (TableFunc) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 diff --git a/sql/analyzer/catalog.go b/sql/analyzer/catalog.go index c8e458790a..a7b45bad3e 100644 --- a/sql/analyzer/catalog.go +++ b/sql/analyzer/catalog.go @@ -34,6 +34,7 @@ type Catalog struct { StatsProvider sql.StatsProvider DbProvider sql.DatabaseProvider + AuthHandler sql.AuthorizationHandler builtInFunctions function.Registry // BinlogReplicaController holds an optional controller that receives forwarded binlog @@ -64,7 +65,7 @@ type sessionLocks map[uint32]dbLocks // NewCatalog returns a new empty Catalog with the given provider func NewCatalog(provider sql.DatabaseProvider) *Catalog { - return &Catalog{ + c := &Catalog{ MySQLDb: mysql_db.CreateEmptyMySQLDb(), InfoSchema: information_schema.NewInformationSchemaDatabase(), DbProvider: provider, @@ -72,6 +73,8 @@ func NewCatalog(provider sql.DatabaseProvider) *Catalog { StatsProvider: memory.NewStatsProv(), locks: make(sessionLocks), } + c.AuthHandler = sql.GetAuthorizationHandlerFactory().CreateHandler(c) + return c } func (c *Catalog) HasBinlogReplicaController() bool { @@ -109,7 +112,7 @@ func (c *Catalog) AllDatabases(ctx *sql.Context) []sql.Database { dbs = append(dbs, c.InfoSchema) if c.MySQLDb.Enabled() { - dbs = append(dbs, mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider).AllDatabases(ctx)...) + dbs = append(dbs, mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider, c.AuthHandler).AllDatabases(ctx)...) } else { dbs = append(dbs, c.DbProvider.AllDatabases(ctx)...) } @@ -162,7 +165,7 @@ func (c *Catalog) HasDatabase(ctx *sql.Context, db string) bool { if db == "information_schema" { return true } else if c.MySQLDb.Enabled() { - return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider).HasDatabase(ctx, db) + return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider, c.AuthHandler).HasDatabase(ctx, db) } else { return c.DbProvider.HasDatabase(ctx, db) } @@ -173,7 +176,7 @@ func (c *Catalog) Database(ctx *sql.Context, db string) (sql.Database, error) { if strings.ToLower(db) == "information_schema" { return c.InfoSchema, nil } else if c.MySQLDb.Enabled() { - return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider).Database(ctx, db) + return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider, c.AuthHandler).Database(ctx, db) } else { return c.DbProvider.Database(ctx, db) } @@ -440,6 +443,10 @@ func (c *Catalog) DataLength(ctx *sql.Context, db string, table sql.Table) (uint return st.DataLength(ctx) } +func (c *Catalog) AuthorizationHandler() sql.AuthorizationHandler { + return c.AuthHandler +} + func getStatisticsTable(table sql.Table, prevTable sql.Table) (sql.StatisticsTable, bool) { // Some TableNodes return themselves for UnderlyingTable, so we need to check for that if table == prevTable { diff --git a/sql/analyzer/load_triggers.go b/sql/analyzer/load_triggers.go index 9c9dcf5373..bcbc652444 100644 --- a/sql/analyzer/load_triggers.go +++ b/sql/analyzer/load_triggers.go @@ -105,6 +105,7 @@ func loadTriggersFromDb(ctx *sql.Context, a *Analyzer, db sql.Database) ([]*plan for _, trigger := range triggers { var parsedTrigger sql.Node sqlMode := sql.NewSqlModeFromString(trigger.SqlMode) + // TODO: should perhaps add the auth query handler to the analyzer? does this even use auth? parsedTrigger, _, err = planbuilder.ParseWithOptions(ctx, a.Catalog, trigger.CreateStatement, sqlMode.ParserOptions()) if err != nil { return nil, err diff --git a/sql/analyzer/node_batches.go b/sql/analyzer/node_batches.go index 78fa6e234a..2f48169aa6 100644 --- a/sql/analyzer/node_batches.go +++ b/sql/analyzer/node_batches.go @@ -29,10 +29,6 @@ func getBatchesForNode(node sql.Node) ([]*Batch, bool) { Id: applyForeignKeysId, Apply: applyForeignKeys, }, - { - Id: validatePrivilegesId, - Apply: validatePrivileges, - }, { Id: validateReadOnlyDatabaseId, Apply: validateReadOnlyDatabase, @@ -74,10 +70,6 @@ func getBatchesForNode(node sql.Node) ([]*Batch, bool) { Id: applyForeignKeysId, Apply: applyForeignKeys, }, - { - Id: validatePrivilegesId, - Apply: validatePrivileges, - }, { Id: optimizeJoinsId, Apply: optimizeJoins, @@ -123,10 +115,6 @@ func getBatchesForNode(node sql.Node) ([]*Batch, bool) { Id: applyForeignKeysId, Apply: applyForeignKeys, }, - { - Id: validatePrivilegesId, - Apply: validatePrivileges, - }, { Id: optimizeJoinsId, Apply: optimizeJoins, diff --git a/sql/analyzer/privileges.go b/sql/analyzer/privileges.go deleted file mode 100644 index 06fc04b859..0000000000 --- a/sql/analyzer/privileges.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2021 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package analyzer - -import ( - "github.com/dolthub/vitess/go/mysql" - - "github.com/dolthub/go-mysql-server/sql/transform" - - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/mysql_db" - "github.com/dolthub/go-mysql-server/sql/plan" -) - -// validatePrivileges verifies the given statement (node n) by checking that the calling user has the necessary privileges -// to execute it. -// TODO: add the remaining statements that interact with the grant tables -func validatePrivileges(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { - mysqlDb := a.Catalog.MySQLDb - - switch n.(type) { - case *plan.CreateUser, *plan.DropUser, *plan.RenameUser, *plan.CreateRole, *plan.DropRole, - *plan.Grant, *plan.GrantRole, *plan.GrantProxy, *plan.Revoke, *plan.RevokeRole, *plan.RevokeAll, *plan.RevokeProxy: - mysqlDb.SetEnabled(true) - } - if !mysqlDb.Enabled() { - return n, transform.SameTree, nil - } - - client := ctx.Session.Client() - user := func() *mysql_db.User { - rd := mysqlDb.Reader() - defer rd.Close() - return mysqlDb.GetUser(rd, client.User, client.Address, false) - }() - if user == nil { - return nil, transform.SameTree, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", ctx.Session.Client().User) - } - - // TODO: this is incorrect, getTable returns only the first table, there could be others in the tree - if plan.IsDualTable(getTable(n)) { - return n, transform.SameTree, nil - } - if !n.CheckPrivileges(ctx, mysqlDb) { - return nil, transform.SameTree, sql.ErrPrivilegeCheckFailed.New(user.UserHostToString("'")) - } - return n, transform.SameTree, nil -} diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 9970e56279..f4b4c7e5bc 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -50,9 +50,8 @@ var OnceBeforeDefault = []Rule{ {validateReadOnlyTransactionId, validateReadOnlyTransaction}, {validateDatabaseSetId, validateDatabaseSet}, {validateDeleteFromId, validateDeleteFrom}, - {validatePrivilegesId, validatePrivileges}, // Ensure that checking privileges happens after db, table & table function resolution - {simplifyFiltersId, simplifyFilters}, //TODO inline? - {pushNotFiltersId, pushNotFilters}, //TODO inline? + {simplifyFiltersId, simplifyFilters}, //TODO inline? + {pushNotFiltersId, pushNotFilters}, //TODO inline? {hoistOutOfScopeFiltersId, hoistOutOfScopeFilters}, } diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index 188d2125d8..70d43f4e13 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -55,6 +55,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan var procToRegister *plan.Procedure var parsedProcedure sql.Node b := planbuilder.New(ctx, a.Catalog, nil, nil) + b.DisableAuth() b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) if err != nil { @@ -290,6 +291,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop } var parsedProcedure sql.Node b := planbuilder.New(ctx, a.Catalog, nil, nil) + b.DisableAuth() b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) if call.AsOf() != nil { asOf, err := call.AsOf().Eval(ctx, nil) diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index e94f056b74..52e762adcd 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -194,6 +194,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, } b := planbuilder.New(ctx, a.Catalog, nil, nil) + b.DisableAuth() prevActive := b.TriggerCtx().Active b.TriggerCtx().Active = true defer func() { diff --git a/sql/analyzer/validation_rules_test.go b/sql/analyzer/validation_rules_test.go index 2763d60329..7913f78a91 100644 --- a/sql/analyzer/validation_rules_test.go +++ b/sql/analyzer/validation_rules_test.go @@ -764,9 +764,6 @@ func (dummyNode) Schema() sql.Schema { return ni func (dummyNode) Children() []sql.Node { return nil } func (dummyNode) RowIter(*sql.Context, sql.Row) (sql.RowIter, error) { return nil, nil } func (dummyNode) WithChildren(...sql.Node) (sql.Node, error) { return nil, nil } -func (dummyNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} func (dummyNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 } diff --git a/sql/auth.go b/sql/auth.go new file mode 100644 index 0000000000..dea287a734 --- /dev/null +++ b/sql/auth.go @@ -0,0 +1,88 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +// AuthorizationQueryState contains any state that should be retained for an entire query, so that it is loaded once and +// then reused. If an error was created in NewQueryState, then that error should be retained here and returned in +// HandleAuth. +type AuthorizationQueryState interface { + // Error returns the error contained within the state if one exists, otherwise returns nil. + Error() error + // AuthorizationQueryStateImpl has no actual usage besides enforcing that the returned object adheres to this + // interface. + AuthorizationQueryStateImpl() +} + +// AuthorizationHandlerFactory creates an AuthorizationHandler, which will be used for all authorization needs. +type AuthorizationHandlerFactory interface { + // CreateHandler creates an AuthorizationHandler from the given catalog. + CreateHandler(cat Catalog) AuthorizationHandler +} + +// AuthorizationHandler handles the authorization of queries, generally through the use of a privilege system. This +// handler exists to create handlers that will operate on a single query. +type AuthorizationHandler interface { + // NewQueryState returns some kind of state that should be retained for an entire query. This is for the purposes of + // optimization. If nothing needs to be retained between calls for a single query, then it is valid for this to + // return nil. If an error occurs, then the state should contain the error so that it may be returned in HandleAuth. + NewQueryState(ctx *Context) AuthorizationQueryState + // HandleAuth checks whether the authentication information is valid. The state may be nil, therefore this function + // may need to internally construct the state if a nil one is provided. + HandleAuth(ctx *Context, state AuthorizationQueryState, auth ast.AuthInformation) error + // HandleAuthNode handles the authentication of nodes that implement AuthorizationCheckerNode. These are often used + // by integrators that do not modify the AST, and instead prefer a node-based form of authentication. + HandleAuthNode(ctx *Context, state AuthorizationQueryState, node AuthorizationCheckerNode) error + // CheckDatabase returns nil when access to the given database is permitted. If the database name is empty, it + // should be assumed that the context's current database is being checked. + CheckDatabase(ctx *Context, state AuthorizationQueryState, dbName string) error + // CheckSchema returns nil when access to the given schema is permitted. If the database or schema name is empty, + // it should be assumed that the context's current database and schema are being checked. + CheckSchema(ctx *Context, state AuthorizationQueryState, dbName string, schemaName string) error + // CheckTable returns nil when access to the given table is permitted. If the database or schema name is empty, + // it should be assumed that the context's current database and schema are being checked. Should return an error if + // the table name is empty. + CheckTable(ctx *Context, state AuthorizationQueryState, dbName string, schemaName string, tableName string) error +} + +// AuthorizationCheckerNode is a node that implements its own authorization checking. +type AuthorizationCheckerNode interface { + Node + // CheckAuth performs any authorization needed for the node, returning true if authorization succeeded. + CheckAuth(ctx *Context, checker PrivilegedOperationChecker) bool +} + +// globalAuthorizationHandlerFactory is the factory that is used when creating a sql.Catalog. This should never be +// fetched directly, instead it should be fetched using GetAuthorizationHandlerFactory. +var globalAuthorizationHandlerFactory AuthorizationHandlerFactory + +// SetAuthorizationHandlerFactory sets the desired authorization factory. Defaults to the factory that expects an AST +// that was generated by Vitess. This factory is used when creating a sql.Catalog, so it must be changed before the +// creation of a catalog. If set to nil, then all auth-related queries will succeed. +func SetAuthorizationHandlerFactory(factory AuthorizationHandlerFactory) { + globalAuthorizationHandlerFactory = factory +} + +// GetAuthorizationHandlerFactory returns the global AuthorizationHandlerFactory that was set using +// SetAuthorizationHandlerFactory. +func GetAuthorizationHandlerFactory() AuthorizationHandlerFactory { + if globalAuthorizationHandlerFactory != nil { + return globalAuthorizationHandlerFactory + } + return emptyAuthorizationHandlerFactory{} +} diff --git a/sql/auth_empty.go b/sql/auth_empty.go new file mode 100644 index 0000000000..6fb3f1b804 --- /dev/null +++ b/sql/auth_empty.go @@ -0,0 +1,64 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +// emptyAuthorizationHandlerFactory is the AuthorizationHandlerFactory for emptyAuthorizationHandler. +type emptyAuthorizationHandlerFactory struct{} + +var _ AuthorizationHandlerFactory = emptyAuthorizationHandlerFactory{} + +// CreateHandler implements the AuthorizationHandlerFactory interface. +func (emptyAuthorizationHandlerFactory) CreateHandler(cat Catalog) AuthorizationHandler { + return emptyAuthorizationHandler{} +} + +// emptyAuthorizationHandler will always return a "true" result. +type emptyAuthorizationHandler struct{} + +var _ AuthorizationHandler = emptyAuthorizationHandler{} + +// NewQueryState implements the AuthorizationHandler interface. +func (emptyAuthorizationHandler) NewQueryState(ctx *Context) AuthorizationQueryState { + return nil +} + +// HandleAuth implements the AuthorizationHandler interface. +func (emptyAuthorizationHandler) HandleAuth(ctx *Context, aqs AuthorizationQueryState, auth ast.AuthInformation) error { + return nil +} + +// HandleAuthNode implements the AuthorizationHandler interface. +func (emptyAuthorizationHandler) HandleAuthNode(ctx *Context, state AuthorizationQueryState, node AuthorizationCheckerNode) error { + return nil +} + +// CheckDatabase implements the AuthorizationHandler interface. +func (emptyAuthorizationHandler) CheckDatabase(ctx *Context, aqs AuthorizationQueryState, dbName string) error { + return nil +} + +// CheckSchema implements the AuthorizationHandler interface. +func (emptyAuthorizationHandler) CheckSchema(ctx *Context, aqs AuthorizationQueryState, dbName string, schemaName string) error { + return nil +} + +// CheckTable implements the AuthorizationHandler interface. +func (emptyAuthorizationHandler) CheckTable(ctx *Context, aqs AuthorizationQueryState, dbName string, schemaName string, tableName string) error { + return nil +} diff --git a/sql/catalog.go b/sql/catalog.go index 67ccf37a94..b3d078e653 100644 --- a/sql/catalog.go +++ b/sql/catalog.go @@ -44,6 +44,9 @@ type Catalog interface { // UnlockTables unlocks all tables locked by the session id given UnlockTables(ctx *Context, id uint32) error + + // AuthorizationHandler returns the AuthorizationHandler that is used by the catalog. + AuthorizationHandler() AuthorizationHandler } // CatalogTable is a Table that depends on a Catalog. diff --git a/sql/catalog_map.go b/sql/catalog_map.go index 0b7062df8e..3f23b03a6b 100644 --- a/sql/catalog_map.go +++ b/sql/catalog_map.go @@ -148,3 +148,7 @@ func (t MapCatalog) DropDbStats(ctx *Context, db string, flush bool) error { //TODO implement me panic("implement me") } + +func (t MapCatalog) AuthorizationHandler() AuthorizationHandler { + return GetAuthorizationHandlerFactory().CreateHandler(t) +} diff --git a/sql/core.go b/sql/core.go index 7b652dfe64..d898cc0511 100644 --- a/sql/core.go +++ b/sql/core.go @@ -79,11 +79,7 @@ type Node interface { // the current number of children. They must be given in the same order // as they are returned by Children. WithChildren(children ...Node) (Node, error) - // CheckPrivileges passes the operations representative of this Node to the PrivilegedOperationChecker to determine - // whether a user (contained in the context, along with their active roles) has the necessary privileges to execute - // this node (and its children). - CheckPrivileges(ctx *Context, opChecker PrivilegedOperationChecker) bool - + // IsReadOnly returns whether the node is read-only. IsReadOnly() bool } diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 29b6782534..e866bf1e63 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -203,12 +203,6 @@ func (utf *UnresolvedTableFunction) WithChildren(node ...sql.Node) (sql.Node, er panic("no expected children for unresolved table function") } -// CheckPrivileges implements the Node interface -func (utf UnresolvedTableFunction) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - panic("attempting to check privileges on an unresolved table function") - return false -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (UnresolvedTableFunction) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/information_schema/columns_table.go b/sql/information_schema/columns_table.go index 01360dde1c..ca946ae917 100644 --- a/sql/information_schema/columns_table.go +++ b/sql/information_schema/columns_table.go @@ -374,6 +374,7 @@ func getRowsFromViews(ctx *sql.Context, catalog sql.Catalog, db DbWithNames, pri } privSetDb := privSet.Database(db.Database.Name()) for _, view := range views { + // TODO: figure out how auth works in this case node, _, err := planbuilder.Parse(ctx, catalog, view.CreateViewStatement) if err != nil { continue // sometimes views contains views from other databases diff --git a/sql/information_schema/information_schema.go b/sql/information_schema/information_schema.go index 676a1c5255..0b6f284c33 100644 --- a/sql/information_schema/information_schema.go +++ b/sql/information_schema/information_schema.go @@ -1939,6 +1939,7 @@ func triggersRowIter(ctx *Context, c Catalog) (RowIter, error) { var triggerPlans []*plan.CreateTrigger for _, trigger := range triggers { triggerSqlMode := NewSqlModeFromString(trigger.SqlMode) + // TODO: figure out how auth works in this case parsedTrigger, _, err := planbuilder.ParseWithOptions(ctx, c, trigger.CreateStatement, triggerSqlMode.ParserOptions()) if err != nil { return nil, err diff --git a/sql/information_schema/routines_table.go b/sql/information_schema/routines_table.go index 1273367863..efb3589909 100644 --- a/sql/information_schema/routines_table.go +++ b/sql/information_schema/routines_table.go @@ -155,6 +155,7 @@ func routinesRowIter(ctx *Context, c Catalog, p map[string][]*plan.Procedure) (R } // todo shortcircuit routineDef->procedure.CreateProcedureString? + // TODO: figure out how auth works in this case parsedProcedure, _, err := planbuilder.Parse(ctx, c, procedure.CreateProcedureString) if err != nil { continue diff --git a/sql/information_schema/views_table.go b/sql/information_schema/views_table.go index 1ffbae434c..a9ac29661f 100644 --- a/sql/information_schema/views_table.go +++ b/sql/information_schema/views_table.go @@ -82,6 +82,7 @@ func viewsRowIter(ctx *Context, catalog Catalog) (RowIter, error) { if !hasGlobalShowViewPriv && !hasDbShowViewPriv && !privTblSet.Has(PrivilegeType_ShowView) { continue } + // TODO: figure out how auth works in this case parsedView, _, err := planbuilder.ParseWithOptions(ctx, catalog, view.CreateViewStatement, NewSqlModeFromString(view.SqlMode).ParserOptions()) if err != nil { continue diff --git a/sql/mysql_db/privileged_database_provider.go b/sql/mysql_db/privileged_database_provider.go index 0c2dab2b20..d4d9b3dec3 100644 --- a/sql/mysql_db/privileged_database_provider.go +++ b/sql/mysql_db/privileged_database_provider.go @@ -29,6 +29,7 @@ import ( type PrivilegedDatabaseProvider struct { grantTables *MySQLDb provider sql.DatabaseProvider + authHandler sql.AuthorizationHandler } var _ sql.DatabaseProvider = PrivilegedDatabaseProvider{} @@ -37,10 +38,11 @@ var _ sql.DatabaseProvider = PrivilegedDatabaseProvider{} // analyzer when Grant Tables are disabled (and Grant Tables may be enabled or disabled at any time), a new // PrivilegedDatabaseProvider is returned whenever the sql.DatabaseProvider is needed (as long as Grant Tables are // enabled) rather than wrapping a sql.DatabaseProvider when it is provided to the analyzer. -func NewPrivilegedDatabaseProvider(grantTables *MySQLDb, p sql.DatabaseProvider) sql.DatabaseProvider { +func NewPrivilegedDatabaseProvider(grantTables *MySQLDb, p sql.DatabaseProvider, authHandler sql.AuthorizationHandler) sql.DatabaseProvider { return PrivilegedDatabaseProvider{ grantTables: grantTables, provider: p, + authHandler: authHandler, } } @@ -73,7 +75,7 @@ func (pdp PrivilegedDatabaseProvider) Database(ctx *sql.Context, name string) (s return nil, providerErr } - return NewPrivilegedDatabase(pdp.grantTables, db), nil + return NewPrivilegedDatabase(pdp.grantTables, db, pdp.authHandler), nil } // HasDatabase implements the interface sql.DatabaseProvider. @@ -114,7 +116,7 @@ func (pdp PrivilegedDatabaseProvider) AllDatabases(ctx *sql.Context) []sql.Datab } if privilegeSetCount > 0 || privilegeSet.Database(checkName).HasPrivileges() { - databasesWithAccess = append(databasesWithAccess, NewPrivilegedDatabase(pdp.grantTables, db)) + databasesWithAccess = append(databasesWithAccess, NewPrivilegedDatabase(pdp.grantTables, db, pdp.authHandler)) } } return databasesWithAccess @@ -131,6 +133,7 @@ func (pdp PrivilegedDatabaseProvider) usernameFromCtx(ctx *sql.Context) string { type PrivilegedDatabase struct { grantTables *MySQLDb db sql.Database + authHandler sql.AuthorizationHandler //TODO: this should also handle views as the relevant privilege exists } @@ -150,10 +153,11 @@ var _ sql.ViewDatabase = PrivilegedDatabase{} var _ fulltext.Database = PrivilegedDatabase{} // NewPrivilegedDatabase returns a new PrivilegedDatabase. -func NewPrivilegedDatabase(grantTables *MySQLDb, db sql.Database) sql.Database { +func NewPrivilegedDatabase(grantTables *MySQLDb, db sql.Database, authHandler sql.AuthorizationHandler) sql.Database { return PrivilegedDatabase{ grantTables: grantTables, db: db, + authHandler: authHandler, } } @@ -168,50 +172,25 @@ func (pdb PrivilegedDatabase) GetTableInsensitive(ctx *sql.Context, tblName stri if adb, ok := pdb.db.(sql.AliasedDatabase); ok { checkName = adb.AliasedName() } - - privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) - dbSet := privSet.Database(checkName) - // If there are no usable privileges for this database then the table is inaccessible. - if privSet.Count() == 0 && !dbSet.HasPrivileges() { - return nil, false, sql.ErrDatabaseAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), checkName) - } - - tblSet := dbSet.Table(tblName) - // If the user has no global static privileges, database-level privileges, or table-relevant privileges then the - // table is not accessible. - if privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() { - return nil, false, sql.ErrTableAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), tblName) + if err := pdb.authHandler.CheckTable(ctx, nil, checkName, "", tblName); err != nil { + return nil, false, err } return pdb.db.GetTableInsensitive(ctx, tblName) } // GetTableNames implements the interface sql.Database. func (pdb PrivilegedDatabase) GetTableNames(ctx *sql.Context) ([]string, error) { - var tablesWithAccess []string - var err error - privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) - checkName := pdb.db.Name() if adb, ok := pdb.db.(sql.AliasedDatabase); ok { checkName = adb.AliasedName() } - - dbSet := privSet.Database(checkName) - // If there are no usable privileges for this database then no table is accessible. - privSetCount := privSet.Count() - if privSetCount == 0 && !dbSet.HasPrivileges() { - return nil, nil - } - tblNames, err := pdb.db.GetTableNames(ctx) if err != nil { return nil, err } - dbSetCount := dbSet.Count() + var tablesWithAccess []string for _, tblName := range tblNames { - // If the user has any global static privileges, database-level privileges, or table-relevant privileges then a - // table is accessible. - if privSetCount > 0 || dbSetCount > 0 || dbSet.Table(tblName).HasPrivileges() { + if err = pdb.authHandler.CheckTable(ctx, nil, checkName, "", tblName); err == nil { tablesWithAccess = append(tablesWithAccess, tblName) } } @@ -224,25 +203,12 @@ func (pdb PrivilegedDatabase) GetTableInsensitiveAsOf(ctx *sql.Context, tblName if !ok { return nil, false, sql.ErrAsOfNotSupported.New(pdb.db.Name()) } - - privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) - checkName := pdb.db.Name() if adb, ok := pdb.db.(sql.AliasedDatabase); ok { checkName = adb.AliasedName() } - - dbSet := privSet.Database(checkName) - // If there are no usable privileges for this database then the table is inaccessible. - if privSet.Count() == 0 && !dbSet.HasPrivileges() { - return nil, false, sql.ErrDatabaseAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), checkName) - } - - tblSet := dbSet.Table(tblName) - // If the user has no global static privileges, database-level privileges, or table-relevant privileges then the - // table is not accessible. - if privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() { - return nil, false, sql.ErrTableAccessDeniedForUser.New(pdb.usernameFromCtx(ctx), tblName) + if err := pdb.authHandler.CheckTable(ctx, nil, checkName, "", tblName); err != nil { + return nil, false, err } return db.GetTableInsensitiveAsOf(ctx, tblName, asOf) } @@ -253,36 +219,20 @@ func (pdb PrivilegedDatabase) GetTableNamesAsOf(ctx *sql.Context, asOf interface if !ok { return nil, nil } - - var tablesWithAccess []string - var err error - privSet := pdb.grantTables.UserActivePrivilegeSet(ctx) - checkName := pdb.db.Name() if adb, ok := pdb.db.(sql.AliasedDatabase); ok { checkName = adb.AliasedName() } - - dbSet := privSet.Database(checkName) - // If there are no usable privileges for this database then no table is accessible. - if privSet.Count() == 0 && !dbSet.HasPrivileges() { - return nil, nil - } - tblNames, err := db.GetTableNamesAsOf(ctx, asOf) if err != nil { return nil, err } - privSetCount := privSet.Count() - dbSetCount := dbSet.Count() + var tablesWithAccess []string for _, tblName := range tblNames { - // If the user has any global static privileges, database-level privileges, or table-relevant privileges then a - // table is accessible. - if privSetCount > 0 || dbSetCount > 0 && dbSet.Table(tblName).HasPrivileges() { + if err = pdb.authHandler.CheckTable(ctx, nil, checkName, "", tblName); err == nil { tablesWithAccess = append(tablesWithAccess, tblName) } } - return tablesWithAccess, nil } diff --git a/sql/plan/alter_auto_increment.go b/sql/plan/alter_auto_increment.go index ac8dcba862..357f9b3a8f 100644 --- a/sql/plan/alter_auto_increment.go +++ b/sql/plan/alter_auto_increment.go @@ -59,15 +59,6 @@ func (p *AlterAutoIncrement) IsReadOnly() bool { return false } -// CheckPrivileges implements the interface sql.Node. -func (p *AlterAutoIncrement) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: p.Database().Name(), - Table: getTableName(p.Table), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (p *AlterAutoIncrement) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/alter_check.go b/sql/plan/alter_check.go index 1fa08c9494..adf6e5d9f0 100644 --- a/sql/plan/alter_check.go +++ b/sql/plan/alter_check.go @@ -101,16 +101,6 @@ func (c *CreateCheck) Children() []sql.Node { return []sql.Node{c.Table} } -// CheckPrivileges implements the interface sql.Node. -func (c *CreateCheck) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - db := c.Table.Database() - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(db), - Table: getTableName(c.Table), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (c *CreateCheck) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -140,17 +130,6 @@ func (d *DropCheck) WithChildren(children ...sql.Node) (sql.Node, error) { return NewAlterDropCheck(children[0].(*ResolvedTable), d.Name), nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DropCheck) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - db := d.Table.Database() - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(db), - Table: getTableName(d.Table), - } - - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (d *DropCheck) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -216,16 +195,6 @@ func (d DropConstraint) WithChildren(children ...sql.Node) (sql.Node, error) { return nd, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DropConstraint) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - db := GetDatabase(d.Child) - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(db), - Table: getTableName(d.Child), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - func (d *DropConstraint) IsReadOnly() bool { return false } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/plan/alter_default.go b/sql/plan/alter_default.go index f90419089c..4c1c196a13 100644 --- a/sql/plan/alter_default.go +++ b/sql/plan/alter_default.go @@ -91,15 +91,6 @@ func (d *AlterDefaultSet) Children() []sql.Node { return []sql.Node{d.Table} } -// CheckPrivileges implements the interface sql.Node. -func (d *AlterDefaultSet) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: d.Database().Name(), - Table: getTableName(d.Table), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (d *AlterDefaultSet) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -218,18 +209,6 @@ func (d *AlterDefaultDrop) WithExpressions(exprs ...sql.Expression) (sql.Node, e return &nd, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *AlterDefaultDrop) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: d.Db.Name(), - Table: getTableName(d.Table), - Column: d.ColumnName, - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (d *AlterDefaultDrop) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/alter_event.go b/sql/plan/alter_event.go index 8c919eeb00..ab1e681f5c 100644 --- a/sql/plan/alter_event.go +++ b/sql/plan/alter_event.go @@ -215,18 +215,6 @@ func (a *AlterEvent) WithChildren(children ...sql.Node) (sql.Node, error) { return &na, nil } -// CheckPrivileges implements the sql.Node interface. -func (a *AlterEvent) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{Database: a.Db.Name()} - hasPriv := opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Event)) - - if a.AlterName && a.RenameToDb != "" { - subject = sql.PrivilegeCheckSubject{Database: a.RenameToDb} - hasPriv = hasPriv && opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Event)) - } - return hasPriv -} - // Database implements the sql.Databaser interface. func (a *AlterEvent) Database() sql.Database { return a.Db diff --git a/sql/plan/alter_foreign_key.go b/sql/plan/alter_foreign_key.go index 5adabe35a2..4f80b00caf 100644 --- a/sql/plan/alter_foreign_key.go +++ b/sql/plan/alter_foreign_key.go @@ -69,17 +69,6 @@ func (p *CreateForeignKey) WithChildren(children ...sql.Node) (sql.Node, error) return NillaryWithChildren(p, children...) } -// CheckPrivileges implements the interface sql.Node. -func (p *CreateForeignKey) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: p.FkDef.ParentDatabase, - Table: p.FkDef.ParentTable, - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_References)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateForeignKey) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -368,16 +357,6 @@ func (p *DropForeignKey) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(p, children...) } -// CheckPrivileges implements the interface sql.Node. -func (p *DropForeignKey) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: p.database, - Table: p.Table, - } - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropForeignKey) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -450,15 +429,6 @@ func (p *RenameForeignKey) WithChildren(children ...sql.Node) (sql.Node, error) return NillaryWithChildren(p, children...) } -// CheckPrivileges implements the interface sql.Node. -func (p *RenameForeignKey) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: p.database, - Table: p.Table, - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (p *RenameForeignKey) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/alter_index.go b/sql/plan/alter_index.go index c81c5e4f01..4e830be4fb 100644 --- a/sql/plan/alter_index.go +++ b/sql/plan/alter_index.go @@ -197,16 +197,6 @@ func (p AlterIndex) WithExpressions(expressions ...sql.Expression) (sql.Node, er return newIdx, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *AlterIndex) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(p.ddlNode.Database()), - Table: getTableName(p.Table), - } - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Index)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*AlterIndex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/alter_pk.go b/sql/plan/alter_pk.go index 024b189b2b..06c33c02d5 100644 --- a/sql/plan/alter_pk.go +++ b/sql/plan/alter_pk.go @@ -144,17 +144,6 @@ func (a AlterPK) WithDatabase(database sql.Database) (sql.Node, error) { return &a, nil } -// CheckPrivileges implements the interface sql.Node. -func (a *AlterPK) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: a.Database().Name(), - Table: getTableName(a.Table), - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*AlterPK) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/alter_table.go b/sql/plan/alter_table.go index fae31aa78e..c59bac765e 100644 --- a/sql/plan/alter_table.go +++ b/sql/plan/alter_table.go @@ -87,26 +87,6 @@ func (r *RenameTable) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(r, children...) } -// CheckPrivileges implements the interface sql.Node. -func (r *RenameTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - var operations []sql.PrivilegedOperation - for _, oldName := range r.OldNames { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(r.Db), - Table: oldName, - } - operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop)) - } - for _, newName := range r.NewNames { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(r.Db), - Table: newName, - } - operations = append(operations, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Create, sql.PrivilegeType_Insert)) - } - return opChecker.UserHasPrivileges(ctx, operations...) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*RenameTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -396,16 +376,6 @@ func (a AddColumn) WithChildren(children ...sql.Node) (sql.Node, error) { return &a, nil } -// CheckPrivileges implements the interface sql.Node. -func (a *AddColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(a.Db), - Table: getTableName(a.Table), - } - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*AddColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -597,15 +567,6 @@ func (d DropColumn) WithChildren(children ...sql.Node) (sql.Node, error) { return &d, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DropColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(d.Db), - Table: getTableName(d.Table), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -755,17 +716,6 @@ func (r RenameColumn) WithChildren(children ...sql.Node) (sql.Node, error) { return &r, nil } -// CheckPrivileges implements the interface sql.Node. -func (r *RenameColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(r.Db), - Table: getTableName(r.Table), - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*RenameColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -862,16 +812,6 @@ func (m *ModifyColumn) WithChildren(children ...sql.Node) (sql.Node, error) { return &nm, nil } -// CheckPrivileges implements the interface sql.Node. -func (m *ModifyColumn) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(m.Db), - Table: getTableName(m.Table), - } - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ModifyColumn) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -1028,13 +968,3 @@ func (atc *AlterTableCollation) WithChildren(children ...sql.Node) (sql.Node, er natc.Table = children[0] return &natc, nil } - -// CheckPrivileges implements the interface sql.Node. -func (atc *AlterTableCollation) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(atc.Db), - Table: getTableName(atc.Table), - } - - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} diff --git a/sql/plan/alter_user.go b/sql/plan/alter_user.go index 7ffa2ef15b..7aa6b50252 100644 --- a/sql/plan/alter_user.go +++ b/sql/plan/alter_user.go @@ -82,30 +82,6 @@ func (a *AlterUser) WithChildren(children ...sql.Node) (sql.Node, error) { return a, nil } -// CheckPrivileges implements the interface sql.Node. -func (a *AlterUser) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // From the MySQL reference on ALTER USER: - // https://dev.mysql.com/doc/refman/8.0/en/alter-user.html - // ALTER USER generally requires either the global `CREATE USER` privilege, or the `UPDATE` privilege - // for the `mysql` system schema. - if opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation( - sql.PrivilegeCheckSubject{Database: "mysql"}, sql.PrivilegeType_Update)) { - return true - } else if opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation( - sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) { - return true - } - - // There are several exceptions to the general privilege requirements. Currently, the only relevant one is - // that any client who connects to the server using a non-anonymous account can change the password for that account. - authenticatedUser := ctx.Session.Client() - if a.User.Name == authenticatedUser.User { - return true - } - - return false -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (a *AlterUser) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/analyze.go b/sql/plan/analyze.go index 29f042911a..a77f4897f8 100644 --- a/sql/plan/analyze.go +++ b/sql/plan/analyze.go @@ -84,11 +84,6 @@ func (n *AnalyzeTable) WithChildren(_ ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *AnalyzeTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*AnalyzeTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/begin_end_block.go b/sql/plan/begin_end_block.go index 463ce2fb87..c930111ce2 100644 --- a/sql/plan/begin_end_block.go +++ b/sql/plan/begin_end_block.go @@ -81,11 +81,6 @@ func (b *BeginEndBlock) WithChildren(children ...sql.Node) (sql.Node, error) { return &newBeginEndBlock, nil } -// CheckPrivileges implements the interface sql.Node. -func (b *BeginEndBlock) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return b.Block.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (b *BeginEndBlock) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return b.Block.CollationCoercibility(ctx) diff --git a/sql/plan/block.go b/sql/plan/block.go index 916c3c0fcd..4b93f2e83e 100644 --- a/sql/plan/block.go +++ b/sql/plan/block.go @@ -132,16 +132,6 @@ func (b *Block) WithParamReference(pRef *expression.ProcedureReference) sql.Node return &nb } -// CheckPrivileges implements the interface sql.Node. -func (b *Block) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - for _, statement := range b.statements { - if !statement.CheckPrivileges(ctx, opChecker) { - return false - } - } - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (b *Block) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { // The last SELECT used in the block takes priority diff --git a/sql/plan/cached_results.go b/sql/plan/cached_results.go index a83a4aa55a..2e653e798a 100644 --- a/sql/plan/cached_results.go +++ b/sql/plan/cached_results.go @@ -102,11 +102,6 @@ func (n *CachedResults) WithChildren(children ...sql.Node) (sql.Node, error) { return &nn, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *CachedResults) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return n.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (n *CachedResults) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, n.Child) diff --git a/sql/plan/call.go b/sql/plan/call.go index 78eedced4e..60a6f83008 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -90,34 +90,6 @@ func (c *Call) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(c, children...) } -// CheckPrivileges implements the interface sql.Node. -func (c *Call) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // Procedure permissions checking is performed in the same way MySQL does it, with an exception where - // procedures which are marked as AdminOnly. These procedures are only accessible to users with explicit Execute - // permissions on the procedure in question. - - adminOnly := false - if c.cat != nil { - paramCount := len(c.Params) - proc, err := c.cat.ExternalStoredProcedure(ctx, c.Name, paramCount) - // Not finding the procedure isn't great - but that's going to surface with a better error later in the - // query execution. For the permission check, we'll proceed as though the procedure exists, and is not AdminOnly. - if proc != nil && err == nil && proc.AdminOnly { - adminOnly = true - } - } - - if !adminOnly { - subject := sql.PrivilegeCheckSubject{Database: c.Database().Name()} - if opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Execute)) { - return true - } - } - - subject := sql.PrivilegeCheckSubject{Database: c.Database().Name(), Routine: c.Name, IsProcedure: true} - return opChecker.RoutineAdminCheck(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Execute)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (c *Call) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return c.Procedure.CollationCoercibility(ctx) diff --git a/sql/plan/case.go b/sql/plan/case.go index 1881b6d483..f12de2ccc7 100644 --- a/sql/plan/case.go +++ b/sql/plan/case.go @@ -121,11 +121,6 @@ func (c *CaseStatement) WithExpressions(exprs ...sql.Expression) (sql.Node, erro }, nil } -// CheckPrivileges implements the interface sql.Node. -func (c *CaseStatement) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return c.IfElse.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (c *CaseStatement) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return c.IfElse.CollationCoercibility(ctx) @@ -164,11 +159,6 @@ func (e ElseCaseError) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(e, children...) } -// CheckPrivileges implements the interface sql.Node. -func (e ElseCaseError) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (e ElseCaseError) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/close.go b/sql/plan/close.go index f127e4938d..5214db302f 100644 --- a/sql/plan/close.go +++ b/sql/plan/close.go @@ -68,11 +68,6 @@ func (c *Close) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(c, children...) } -// CheckPrivileges implements the interface sql.Node. -func (c *Close) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Close) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/concat.go b/sql/plan/concat.go index c7eda6c78f..4a429f3019 100644 --- a/sql/plan/concat.go +++ b/sql/plan/concat.go @@ -59,11 +59,6 @@ func (c *Concat) WithChildren(children ...sql.Node) (sql.Node, error) { return NewConcat(children[0], children[1]), nil } -// CheckPrivileges implements the interface sql.Node. -func (c *Concat) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return c.left.CheckPrivileges(ctx, opChecker) && c.right.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Concat) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { // As this is similar to UNION, it isn't possible to determine what the resulting coercibility may be diff --git a/sql/plan/create_index.go b/sql/plan/create_index.go index 868b3627af..5b41a63193 100644 --- a/sql/plan/create_index.go +++ b/sql/plan/create_index.go @@ -137,17 +137,6 @@ func (c *CreateIndex) WithChildren(children ...sql.Node) (sql.Node, error) { return &nc, nil } -// CheckPrivileges implements the interface sql.Node. -func (c *CreateIndex) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(GetDatabase(c.Table)), - Table: getTableName(c.Table), - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Index)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateIndex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/create_role.go b/sql/plan/create_role.go index ff1f8d7334..7e87b04220 100644 --- a/sql/plan/create_role.go +++ b/sql/plan/create_role.go @@ -95,14 +95,6 @@ func (n *CreateRole) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *CreateRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // Both CREATE ROLE and CREATE USER are valid privileges, so we use an OR - subject := sql.PrivilegeCheckSubject{} - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_CreateRole)) || - opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_CreateUser)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateRole) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/create_user.go b/sql/plan/create_user.go index ff5e80945f..153b415647 100644 --- a/sql/plan/create_user.go +++ b/sql/plan/create_user.go @@ -92,12 +92,6 @@ func (n *CreateUser) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *CreateUser) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateUser) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/create_view.go b/sql/plan/create_view.go index b42f84d13d..1c6f8e5541 100644 --- a/sql/plan/create_view.go +++ b/sql/plan/create_view.go @@ -107,14 +107,6 @@ func (cv *CreateView) WithChildren(children ...sql.Node) (sql.Node, error) { return &newCreate, nil } -// CheckPrivileges implements the interface sql.Node. -func (cv *CreateView) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{Database: cv.database.Name()} - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_CreateView)) && - cv.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateView) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/dbddl.go b/sql/plan/dbddl.go index 1e018c81ab..04504230b0 100644 --- a/sql/plan/dbddl.go +++ b/sql/plan/dbddl.go @@ -62,11 +62,6 @@ func (c *CreateDB) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(c, children...) } -// CheckPrivileges implements the interface sql.Node. -func (c *CreateDB) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Create)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateDB) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -156,11 +151,6 @@ func (d *DropDB) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(d, children...) } -// CheckPrivileges implements the interface sql.Node. -func (d *DropDB) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Drop)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropDB) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -216,14 +206,6 @@ func (c *AlterDB) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(c, children...) } -// CheckPrivileges implements the interface sql.Node. -func (c *AlterDB) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: c.Database(ctx), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Alter)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*AlterDB) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/ddl.go b/sql/plan/ddl.go index d1c9a187ce..082e934b41 100644 --- a/sql/plan/ddl.go +++ b/sql/plan/ddl.go @@ -307,16 +307,6 @@ func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 1) } -// CheckPrivileges implements the Node interface. -func (c *CreateTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - priv := sql.PrivilegeType_Create - if c.temporary { - priv = sql.PrivilegeType_CreateTempTable - } - subject := sql.PrivilegeCheckSubject{Database: CheckPrivilegeNameForDatabase(c.Db)} - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, priv)) -} - // IsReadOnly implements the Node interface. func (c *CreateTable) IsReadOnly() bool { return false @@ -572,21 +562,6 @@ func (d *DropTable) WithChildren(children ...sql.Node) (sql.Node, error) { return &nd, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DropTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - for _, tbl := range d.Tables { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(GetDatabase(tbl)), - Table: getTableName(tbl), - } - - if !opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Drop)) { - return false - } - } - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/ddl_event.go b/sql/plan/ddl_event.go index c9f8cbb216..f99138711b 100644 --- a/sql/plan/ddl_event.go +++ b/sql/plan/ddl_event.go @@ -125,14 +125,6 @@ func (c *CreateEvent) WithChildren(children ...sql.Node) (sql.Node, error) { return &nc, nil } -// CheckPrivileges implements the interface sql.Node. -func (c *CreateEvent) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: c.Db.Name(), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Event)) -} - // Database implements the sql.Databaser interface. func (c *CreateEvent) Database() sql.Database { return c.Db @@ -647,14 +639,6 @@ func (d *DropEvent) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(d, children...) } -// CheckPrivileges implements the interface sql.Node. -func (d *DropEvent) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: d.Db.Name(), - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Event)) -} - // WithDatabase implements the sql.Databaser interface. func (d *DropEvent) WithDatabase(database sql.Database) (sql.Node, error) { nde := *d diff --git a/sql/plan/ddl_procedure.go b/sql/plan/ddl_procedure.go index 6fb568559a..ff824959d1 100644 --- a/sql/plan/ddl_procedure.go +++ b/sql/plan/ddl_procedure.go @@ -110,15 +110,6 @@ func (c *CreateProcedure) WithChildren(children ...sql.Node) (sql.Node, error) { return &nc, nil } -// CheckPrivileges implements the interface sql.Node. -func (c *CreateProcedure) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: c.Db.Name(), - } - - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_CreateRoutine)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/ddl_trigger.go b/sql/plan/ddl_trigger.go index 9bd4d0d087..733135f6cb 100644 --- a/sql/plan/ddl_trigger.go +++ b/sql/plan/ddl_trigger.go @@ -110,17 +110,6 @@ func (c *CreateTrigger) WithChildren(children ...sql.Node) (sql.Node, error) { return &nc, nil } -// CheckPrivileges implements the interface sql.Node. -func (c *CreateTrigger) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: GetDatabaseName(c.Table), - Table: getTableName(c.Table), - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Trigger)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*CreateTrigger) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/declare_condition.go b/sql/plan/declare_condition.go index 06535fcfd9..c3fb8a89fb 100644 --- a/sql/plan/declare_condition.go +++ b/sql/plan/declare_condition.go @@ -74,11 +74,6 @@ func (d *DeclareCondition) WithChildren(children ...sql.Node) (sql.Node, error) return NillaryWithChildren(d, children...) } -// CheckPrivileges implements the interface sql.Node. -func (d *DeclareCondition) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DeclareCondition) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/declare_cursor.go b/sql/plan/declare_cursor.go index 9638e69b11..6fc9850003 100644 --- a/sql/plan/declare_cursor.go +++ b/sql/plan/declare_cursor.go @@ -82,11 +82,6 @@ func (d *DeclareCursor) WithChildren(children ...sql.Node) (sql.Node, error) { return &nd, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DeclareCursor) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return d.Select.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DeclareCursor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/declare_handler.go b/sql/plan/declare_handler.go index 38c750e4ac..45d9f23b88 100644 --- a/sql/plan/declare_handler.go +++ b/sql/plan/declare_handler.go @@ -104,11 +104,6 @@ func (d *DeclareHandler) WithChildren(children ...sql.Node) (sql.Node, error) { return &nd, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DeclareHandler) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DeclareHandler) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/declare_variables.go b/sql/plan/declare_variables.go index 5fde5ccbf6..6a6d37437f 100644 --- a/sql/plan/declare_variables.go +++ b/sql/plan/declare_variables.go @@ -72,11 +72,6 @@ func (d *DeclareVariables) WithChildren(children ...sql.Node) (sql.Node, error) return NillaryWithChildren(d, children...) } -// CheckPrivileges implements the interface sql.Node. -func (d *DeclareVariables) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DeclareVariables) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/delete.go b/sql/plan/delete.go index a518936555..dca2b75d60 100644 --- a/sql/plan/delete.go +++ b/sql/plan/delete.go @@ -113,32 +113,6 @@ func (p *DeleteFrom) WithChildren(children ...sql.Node) (sql.Node, error) { return NewDeleteFrom(children[0], p.explicitTargets), nil } -// CheckPrivileges implements the interface sql.Node. -func (p *DeleteFrom) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // TODO: If column values are retrieved then the SELECT privilege is required - // For example: "DELETE FROM table WHERE z > 0" - // We would need SELECT privileges on the "z" column as it's retrieving values - - for _, target := range p.GetDeleteTargets() { - deletable, err := GetDeletable(target) - if err != nil { - ctx.GetLogger().Warnf("unable to determine deletable table from delete target: %v", target) - return false - } - - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(GetDatabase(target)), - Table: deletable.Name(), - } - op := sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Delete) - if opChecker.UserHasPrivileges(ctx, op) == false { - return false - } - } - - return true -} - func GetDeletable(node sql.Node) (sql.DeletableTable, error) { switch node := node.(type) { case sql.DeletableTable: diff --git a/sql/plan/describe.go b/sql/plan/describe.go index c728ad0c4a..37a62453e4 100644 --- a/sql/plan/describe.go +++ b/sql/plan/describe.go @@ -55,11 +55,6 @@ func (d *Describe) WithChildren(children ...sql.Node) (sql.Node, error) { return NewDescribe(children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (d *Describe) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return d.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Describe) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -100,11 +95,6 @@ func (d *DescribeQuery) WithChildren(node ...sql.Node) (sql.Node, error) { return d, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DescribeQuery) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return d.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DescribeQuery) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/distinct.go b/sql/plan/distinct.go index 51a1f51595..05a0f2e6a9 100644 --- a/sql/plan/distinct.go +++ b/sql/plan/distinct.go @@ -47,11 +47,6 @@ func (d *Distinct) WithChildren(children ...sql.Node) (sql.Node, error) { return NewDistinct(children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (d *Distinct) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return d.Child.CheckPrivileges(ctx, opChecker) -} - func (d *Distinct) IsReadOnly() bool { return d.Child.IsReadOnly() } @@ -118,11 +113,6 @@ func (d *OrderedDistinct) WithChildren(children ...sql.Node) (sql.Node, error) { return NewOrderedDistinct(children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (d *OrderedDistinct) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return d.Child.CheckPrivileges(ctx, opChecker) -} - func (d *OrderedDistinct) IsReadOnly() bool { return d.Child.IsReadOnly() } diff --git a/sql/plan/drop_index.go b/sql/plan/drop_index.go index 43bb8e1365..93592b78cb 100644 --- a/sql/plan/drop_index.go +++ b/sql/plan/drop_index.go @@ -80,16 +80,6 @@ func (d *DropIndex) WithChildren(children ...sql.Node) (sql.Node, error) { return &nd, nil } -// CheckPrivileges implements the interface sql.Node. -func (d *DropIndex) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: GetDatabaseName(d.Table), - Table: getTableName(d.Table), - } - - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Index)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropIndex) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/drop_procedure.go b/sql/plan/drop_procedure.go index 4c5e0c8e1b..75a3f90b74 100644 --- a/sql/plan/drop_procedure.go +++ b/sql/plan/drop_procedure.go @@ -76,13 +76,6 @@ func (d *DropProcedure) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(d, children...) } -// CheckPrivileges implements the interface sql.Node. -func (d *DropProcedure) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{Database: d.Db.Name()} - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_AlterRoutine)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/drop_role.go b/sql/plan/drop_role.go index 1ffc29f115..6a5937ffbd 100644 --- a/sql/plan/drop_role.go +++ b/sql/plan/drop_role.go @@ -96,14 +96,6 @@ func (n *DropRole) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *DropRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{} - // Both DROP ROLE and CREATE USER are valid privileges, so we use an OR - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_DropRole)) || - opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_CreateUser)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropRole) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/drop_trigger.go b/sql/plan/drop_trigger.go index 894bbd4ccb..f6a9cc2200 100644 --- a/sql/plan/drop_trigger.go +++ b/sql/plan/drop_trigger.go @@ -76,15 +76,6 @@ func (d *DropTrigger) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(d, children...) } -// CheckPrivileges implements the interface sql.Node. -func (d *DropTrigger) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: d.Db.Name(), - Table: d.TriggerName, - } - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Trigger)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropTrigger) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/drop_user.go b/sql/plan/drop_user.go index eced42cd99..67dbdef468 100644 --- a/sql/plan/drop_user.go +++ b/sql/plan/drop_user.go @@ -97,11 +97,6 @@ func (n *DropUser) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *DropUser) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DropUser) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/drop_view.go b/sql/plan/drop_view.go index fa4a53a4be..7bf1794c2a 100644 --- a/sql/plan/drop_view.go +++ b/sql/plan/drop_view.go @@ -83,15 +83,6 @@ func (dv *SingleDropView) WithChildren(children ...sql.Node) (sql.Node, error) { return dv, nil } -// CheckPrivileges implements the interface sql.Node. -func (dv *SingleDropView) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: dv.database.Name(), - } - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Drop)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*SingleDropView) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -173,16 +164,6 @@ func (dvs *DropView) WithChildren(children ...sql.Node) (sql.Node, error) { return newDrop, nil } -// CheckPrivileges implements the interface sql.Node. -func (dvs *DropView) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - for _, child := range dvs.children { - if !child.CheckPrivileges(ctx, opChecker) { - return false - } - } - return true -} - func (dvs *DropView) IsReadOnly() bool { return false } diff --git a/sql/plan/empty_table.go b/sql/plan/empty_table.go index ccb4e9cca8..a17a919b54 100644 --- a/sql/plan/empty_table.go +++ b/sql/plan/empty_table.go @@ -103,11 +103,6 @@ func (e *EmptyTable) WithChildren(children ...sql.Node) (sql.Node, error) { return e, nil } -// CheckPrivileges implements the interface sql.Node. -func (e *EmptyTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*EmptyTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/external_procedure.go b/sql/plan/external_procedure.go index b20f5d467e..1fb59fa9f8 100644 --- a/sql/plan/external_procedure.go +++ b/sql/plan/external_procedure.go @@ -94,12 +94,6 @@ func (n *ExternalProcedure) WithExpressions(expressions ...sql.Expression) (sql. return &nn, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *ExternalProcedure) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: when DEFINER is implemented for stored procedures then this should be added - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ExternalProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/fetch.go b/sql/plan/fetch.go index 262a1bbdb8..0158ace315 100644 --- a/sql/plan/fetch.go +++ b/sql/plan/fetch.go @@ -103,11 +103,6 @@ func (f *Fetch) WithChildren(children ...sql.Node) (sql.Node, error) { return f, nil } -// CheckPrivileges implements the interface sql.Node. -func (f *Fetch) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Fetch) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/filter.go b/sql/plan/filter.go index c3853c2425..f2c0691112 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -53,11 +53,6 @@ func (f *Filter) WithChildren(children ...sql.Node) (sql.Node, error) { return NewFilter(f.Expression, children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (f *Filter) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return f.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (f *Filter) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, f.UnaryNode.Child) diff --git a/sql/plan/flush.go b/sql/plan/flush.go index 849417ccda..d06fa3a9dd 100644 --- a/sql/plan/flush.go +++ b/sql/plan/flush.go @@ -65,12 +65,6 @@ func (f *FlushPrivileges) WithChildren(children ...sql.Node) (sql.Node, error) { return f, nil } -// CheckPrivileges implements the interface sql.Node. -func (f *FlushPrivileges) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{Database: "mysql"} - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Reload)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*FlushPrivileges) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/foreign_key_handler.go b/sql/plan/foreign_key_handler.go index 5aa2a2c3db..beb1146c0f 100644 --- a/sql/plan/foreign_key_handler.go +++ b/sql/plan/foreign_key_handler.go @@ -95,11 +95,6 @@ func (n *ForeignKeyHandler) WithChildren(children ...sql.Node) (sql.Node, error) return &nn, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *ForeignKeyHandler) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return n.OriginalNode.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ForeignKeyHandler) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/grant.go b/sql/plan/grant.go index 97bcb61dab..1f393f3e56 100644 --- a/sql/plan/grant.go +++ b/sql/plan/grant.go @@ -39,6 +39,7 @@ type Grant struct { var _ sql.Node = (*Grant)(nil) var _ sql.Databaser = (*Grant)(nil) var _ sql.CollationCoercible = (*Grant)(nil) +var _ sql.AuthorizationCheckerNode = (*Grant)(nil) // Schema implements the interface sql.Node. func (n *Grant) Schema() sql.Schema { @@ -89,8 +90,8 @@ func (n *Grant) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *Grant) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { +// CheckAuth implements the interface sql.AuthorizationCheckerNode. +func (n *Grant) CheckAuth(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { subject := sql.PrivilegeCheckSubject{Database: "mysql"} if opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Update)) { @@ -548,6 +549,7 @@ type GrantRole struct { var _ sql.Node = (*GrantRole)(nil) var _ sql.Databaser = (*GrantRole)(nil) var _ sql.CollationCoercible = (*GrantRole)(nil) +var _ sql.AuthorizationCheckerNode = (*GrantRole)(nil) // NewGrantRole returns a new GrantRole node. func NewGrantRole(roles []UserName, users []UserName, withAdmin bool) *GrantRole { @@ -612,8 +614,8 @@ func (n *GrantRole) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *GrantRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { +// CheckAuth implements the interface sql.AuthorizationCheckerNode. +func (n *GrantRole) CheckAuth(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { if opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Super)) { return true @@ -662,6 +664,7 @@ type GrantProxy struct { var _ sql.Node = (*GrantProxy)(nil) var _ sql.CollationCoercible = (*GrantProxy)(nil) +var _ sql.AuthorizationCheckerNode = (*GrantProxy)(nil) // NewGrantProxy returns a new GrantProxy node. func NewGrantProxy(on UserName, to []UserName, withGrant bool) *GrantProxy { @@ -708,8 +711,8 @@ func (n *GrantProxy) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *GrantProxy) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { +// CheckAuth implements the interface sql.AuthorizationCheckerNode. +func (n *GrantProxy) CheckAuth(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { //TODO: add this when proxy support is added return true } diff --git a/sql/plan/group_by.go b/sql/plan/group_by.go index 8e6942d120..84b0dc85f5 100644 --- a/sql/plan/group_by.go +++ b/sql/plan/group_by.go @@ -104,11 +104,6 @@ func (g *GroupBy) WithChildren(children ...sql.Node) (sql.Node, error) { return NewGroupBy(g.SelectedExprs, g.GroupByExprs, children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (g *GroupBy) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return g.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (g *GroupBy) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, g.Child) diff --git a/sql/plan/hash_lookup.go b/sql/plan/hash_lookup.go index dfa98bee82..216b1bc7b3 100644 --- a/sql/plan/hash_lookup.go +++ b/sql/plan/hash_lookup.go @@ -104,11 +104,6 @@ func (n *HashLookup) WithChildren(children ...sql.Node) (sql.Node, error) { return &nn, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *HashLookup) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return n.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (n *HashLookup) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, n.Child) diff --git a/sql/plan/having.go b/sql/plan/having.go index 9f02a93455..11893c0dd8 100644 --- a/sql/plan/having.go +++ b/sql/plan/having.go @@ -53,11 +53,6 @@ func (h *Having) WithChildren(children ...sql.Node) (sql.Node, error) { return NewHaving(h.Cond, children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (h *Having) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return h.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (h *Having) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, h.Child) diff --git a/sql/plan/histogram.go b/sql/plan/histogram.go index fa0c76177c..adb1ab9b71 100644 --- a/sql/plan/histogram.go +++ b/sql/plan/histogram.go @@ -75,10 +75,6 @@ func (u *UpdateHistogram) WithChildren(children ...sql.Node) (sql.Node, error) { return u, nil } -func (u *UpdateHistogram) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - func (u *UpdateHistogram) IsReadOnly() bool { return false } @@ -138,10 +134,6 @@ func (d *DropHistogram) WithChildren(_ ...sql.Node) (sql.Node, error) { return d, nil } -func (d *DropHistogram) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - func (d *DropHistogram) IsReadOnly() bool { return false } diff --git a/sql/plan/if_else.go b/sql/plan/if_else.go index ec20b9168c..523f098987 100644 --- a/sql/plan/if_else.go +++ b/sql/plan/if_else.go @@ -86,11 +86,6 @@ func (ic *IfConditional) WithChildren(children ...sql.Node) (sql.Node, error) { return &nic, nil } -// CheckPrivileges implements the interface sql.Node. -func (ic *IfConditional) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return ic.Body.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (ic *IfConditional) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, ic.Body) @@ -220,19 +215,6 @@ func (ieb *IfElseBlock) WithChildren(children ...sql.Node) (sql.Node, error) { return NewIfElse(ifConditionals, children[len(children)-1]), nil } -// CheckPrivileges implements the interface sql.Node. -func (ieb *IfElseBlock) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - for _, ifBlock := range ieb.IfConditionals { - if !ifBlock.CheckPrivileges(ctx, opChecker) { - return false - } - } - if ieb.Else != nil { - return ieb.Else.CheckPrivileges(ctx, opChecker) - } - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (ieb *IfElseBlock) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { // We'll only be able to know which branch was taken during the RowIter, so we can't rely on that here. diff --git a/sql/plan/indexed_table_access.go b/sql/plan/indexed_table_access.go index 756707d328..0d0b91a5f5 100644 --- a/sql/plan/indexed_table_access.go +++ b/sql/plan/indexed_table_access.go @@ -247,10 +247,6 @@ func (i *IndexedTableAccess) Database() sql.Database { return i.TableNode.Database() } -func (i *IndexedTableAccess) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return i.TableNode.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (i *IndexedTableAccess) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return i.TableNode.CollationCoercibility(ctx) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 78bf38d9d2..c52acf843d 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -162,22 +162,6 @@ func (ii *InsertInto) WithChildren(children ...sql.Node) (sql.Node, error) { return &np, nil } -// CheckPrivileges implements the interface sql.Node. -func (ii *InsertInto) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(ii.db), - Table: getTableName(ii.Destination), - } - - if ii.IsReplace { - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Insert, sql.PrivilegeType_Delete)) - } else { - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Insert)) - } -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*InsertInto) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -375,11 +359,6 @@ func (id InsertDestination) WithChildren(children ...sql.Node) (sql.Node, error) return &id, nil } -// CheckPrivileges implements the interface sql.Node. -func (id *InsertDestination) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return id.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (id *InsertDestination) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, id.Child) diff --git a/sql/plan/into.go b/sql/plan/into.go index 2e3a069632..85761ede23 100644 --- a/sql/plan/into.go +++ b/sql/plan/into.go @@ -117,11 +117,6 @@ func (i *Into) WithChildren(children ...sql.Node) (sql.Node, error) { return &ni, nil } -// CheckPrivileges implements the interface sql.Node. -func (i *Into) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return i.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (i *Into) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, i.Child) diff --git a/sql/plan/iterate.go b/sql/plan/iterate.go index 01fce462bc..83b5b8cc4a 100644 --- a/sql/plan/iterate.go +++ b/sql/plan/iterate.go @@ -65,11 +65,6 @@ func (i *Iterate) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(i, children...) } -// CheckPrivileges implements the interface sql.Node. -func (i *Iterate) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Iterate) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/join.go b/sql/plan/join.go index cbdaecb871..9e689c74e0 100644 --- a/sql/plan/join.go +++ b/sql/plan/join.go @@ -363,10 +363,6 @@ func (j *JoinNode) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { return &ret, nil } -func (j *JoinNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return j.left.CheckPrivileges(ctx, opChecker) && j.right.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*JoinNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { // Joins make use of coercibility, but they don't return anything themselves diff --git a/sql/plan/json_table.go b/sql/plan/json_table.go index feff8ed3ea..0382e8e679 100644 --- a/sql/plan/json_table.go +++ b/sql/plan/json_table.go @@ -253,11 +253,6 @@ func (t *JSONTable) WithChildren(children ...sql.Node) (sql.Node, error) { return t, nil } -// CheckPrivileges implements the sql.Node interface -func (t *JSONTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*JSONTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/kill.go b/sql/plan/kill.go index a024d6ba48..3b88fe847d 100644 --- a/sql/plan/kill.go +++ b/sql/plan/kill.go @@ -68,13 +68,6 @@ func (k *Kill) WithChildren(children ...sql.Node) (sql.Node, error) { return k, nil } -// CheckPrivileges implements the interface sql.Node. -func (k *Kill) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: If the user doesn't have the SUPER privilege, they should still be able to kill their own threads - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Super)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Kill) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/leave.go b/sql/plan/leave.go index 74e8c40ebe..c1b11ad16b 100644 --- a/sql/plan/leave.go +++ b/sql/plan/leave.go @@ -64,11 +64,6 @@ func (l *Leave) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(l, children...) } -// CheckPrivileges implements the interface sql.Node. -func (l *Leave) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Leave) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/limit.go b/sql/plan/limit.go index 6178afa0b1..e65e0b26b2 100644 --- a/sql/plan/limit.go +++ b/sql/plan/limit.go @@ -76,11 +76,6 @@ func (l *Limit) WithChildren(children ...sql.Node) (sql.Node, error) { return &nl, nil } -// CheckPrivileges implements the interface sql.Node. -func (l *Limit) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return l.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (l *Limit) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, l.Child) diff --git a/sql/plan/load_data.go b/sql/plan/load_data.go index ec81ec639d..18f97a8ff7 100644 --- a/sql/plan/load_data.go +++ b/sql/plan/load_data.go @@ -96,11 +96,6 @@ func (l *LoadData) WithChildren(children ...sql.Node) (sql.Node, error) { return &nl, nil } -// CheckPrivileges implements the interface sql.Node. -func (l *LoadData) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_File)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*LoadData) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/lock.go b/sql/plan/lock.go index 1ad1f9c25b..0f70bfb7b4 100644 --- a/sql/plan/lock.go +++ b/sql/plan/lock.go @@ -101,19 +101,6 @@ func (t *LockTables) WithChildren(children ...sql.Node) (sql.Node, error) { return &LockTables{t.Catalog, locks}, nil } -// CheckPrivileges implements the interface sql.Node. -func (t *LockTables) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - operations := make([]sql.PrivilegedOperation, len(t.Locks)) - for i, tableLock := range t.Locks { - subject := sql.PrivilegeCheckSubject{ - Database: GetDatabaseName(tableLock.Table), - Table: getTableName(tableLock.Table), - } - operations[i] = sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select, sql.PrivilegeType_LockTables) - } - return opChecker.UserHasPrivileges(ctx, operations...) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*LockTables) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -171,12 +158,6 @@ func (t *UnlockTables) WithChildren(children ...sql.Node) (sql.Node, error) { return t, nil } -// CheckPrivileges implements the interface sql.Node. -func (t *UnlockTables) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: Can't quite figure out the privileges for this one, needs more testing - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*UnlockTables) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/loop.go b/sql/plan/loop.go index 05882713a7..d014546c08 100644 --- a/sql/plan/loop.go +++ b/sql/plan/loop.go @@ -123,11 +123,6 @@ func (l *Loop) WithParamReference(pRef *expression.ProcedureReference) sql.Node return &nl } -// CheckPrivileges implements the interface sql.Node. -func (l *Loop) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return l.Block.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (l *Loop) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return l.Block.CollationCoercibility(ctx) diff --git a/sql/plan/namedwindows.go b/sql/plan/namedwindows.go index edc3029f6a..280dfd6cbb 100644 --- a/sql/plan/namedwindows.go +++ b/sql/plan/namedwindows.go @@ -88,11 +88,6 @@ func (n *NamedWindows) WithChildren(nodes ...sql.Node) (sql.Node, error) { return NewNamedWindows(n.WindowDefs, nodes[0]), nil } -// CheckPrivileges implements sql.Node -func (n *NamedWindows) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return n.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (n *NamedWindows) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, n.Child) diff --git a/sql/plan/nothing.go b/sql/plan/nothing.go index 30dc013cb0..2b0d10365b 100644 --- a/sql/plan/nothing.go +++ b/sql/plan/nothing.go @@ -42,11 +42,6 @@ func (n Nothing) WithChildren(children ...sql.Node) (sql.Node, error) { return NothingImpl, nil } -// CheckPrivileges implements the interface sql.Node. -func (Nothing) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (Nothing) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/offset.go b/sql/plan/offset.go index f023198993..adff1522de 100644 --- a/sql/plan/offset.go +++ b/sql/plan/offset.go @@ -65,11 +65,6 @@ func (o *Offset) WithChildren(children ...sql.Node) (sql.Node, error) { return NewOffset(o.Offset, children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (o *Offset) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return o.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (o *Offset) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, o.Child) diff --git a/sql/plan/open.go b/sql/plan/open.go index 4ff44c9d10..dc0814c9e5 100644 --- a/sql/plan/open.go +++ b/sql/plan/open.go @@ -68,11 +68,6 @@ func (o *Open) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(o, children...) } -// CheckPrivileges implements the interface sql.Node. -func (o *Open) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Open) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/prepare.go b/sql/plan/prepare.go index 6c4cbf08bc..2bc11aca17 100644 --- a/sql/plan/prepare.go +++ b/sql/plan/prepare.go @@ -75,11 +75,6 @@ func (p *PrepareQuery) WithChildren(children ...sql.Node) (sql.Node, error) { return p, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *PrepareQuery) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return p.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*PrepareQuery) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -131,11 +126,6 @@ func (p *ExecuteQuery) WithChildren(children ...sql.Node) (sql.Node, error) { panic("ExecuteQuery methods shouldn't be used") } -// CheckPrivileges implements the interface sql.Node. -func (p *ExecuteQuery) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - panic("ExecuteQuery methods shouldn't be used") -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ExecuteQuery) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -189,11 +179,6 @@ func (p *DeallocateQuery) WithChildren(children ...sql.Node) (sql.Node, error) { return p, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *DeallocateQuery) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*DeallocateQuery) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index 945792a497..fc21b5b96a 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -167,11 +167,6 @@ func (p *Procedure) WithChildren(children ...sql.Node) (sql.Node, error) { return &np, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *Procedure) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return p.Body.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (p *Procedure) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, p.Body) diff --git a/sql/plan/procedure_resolved_table.go b/sql/plan/procedure_resolved_table.go index c1d8db601c..9e3088e32a 100644 --- a/sql/plan/procedure_resolved_table.go +++ b/sql/plan/procedure_resolved_table.go @@ -95,11 +95,6 @@ func (t *ProcedureResolvedTable) WithChildren(children ...sql.Node) (sql.Node, e return nt, err } -// CheckPrivileges implements the interface sql.Node. -func (t *ProcedureResolvedTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return t.ResolvedTable.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (t *ProcedureResolvedTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return t.ResolvedTable.CollationCoercibility(ctx) diff --git a/sql/plan/processlist.go b/sql/plan/processlist.go index 881fab1658..8bd26ab1e2 100644 --- a/sql/plan/processlist.go +++ b/sql/plan/processlist.go @@ -58,11 +58,6 @@ func (p *ShowProcessList) WithChildren(children ...sql.Node) (sql.Node, error) { return p, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *ShowProcessList) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Process)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowProcessList) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/project.go b/sql/plan/project.go index d329571b9c..3256b0f05a 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -167,11 +167,6 @@ func (p *Project) WithChildren(children ...sql.Node) (sql.Node, error) { return &np, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *Project) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return p.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (p *Project) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, p.Child) diff --git a/sql/plan/range_heap.go b/sql/plan/range_heap.go index d0ab4e2b57..72d0f55aff 100644 --- a/sql/plan/range_heap.go +++ b/sql/plan/range_heap.go @@ -73,8 +73,4 @@ func (s *RangeHeap) WithChildren(children ...sql.Node) (sql.Node, error) { return &s2, nil } -func (s *RangeHeap) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return s.Child.CheckPrivileges(ctx, opChecker) -} - var _ sql.Node = (*RangeHeap)(nil) diff --git a/sql/plan/recursive_cte.go b/sql/plan/recursive_cte.go index 13861e417d..6442bb05f8 100644 --- a/sql/plan/recursive_cte.go +++ b/sql/plan/recursive_cte.go @@ -171,10 +171,6 @@ func (r *RecursiveCte) Children() []sql.Node { return r.union.Children() } -func (r *RecursiveCte) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return r.union.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*RecursiveCte) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -309,11 +305,6 @@ func (r *RecursiveTable) WithChildren(node ...sql.Node) (sql.Node, error) { return r, nil } -// CheckPrivileges implements the interface sql.Node. -func (r *RecursiveTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*RecursiveTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/releaser.go b/sql/plan/releaser.go index b0ac6b3c92..f90ddf27de 100644 --- a/sql/plan/releaser.go +++ b/sql/plan/releaser.go @@ -51,11 +51,6 @@ func (r *Releaser) WithChildren(children ...sql.Node) (sql.Node, error) { return &Releaser{children[0], r.Release}, nil } -// CheckPrivileges implements the interface sql.Node. -func (r *Releaser) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return r.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (r *Releaser) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, r.Child) diff --git a/sql/plan/rename_user.go b/sql/plan/rename_user.go index 5cc2fb70c3..0f8bc1afd9 100644 --- a/sql/plan/rename_user.go +++ b/sql/plan/rename_user.go @@ -76,11 +76,6 @@ func (n *RenameUser) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *RenameUser) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*RenameUser) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/replication_commands.go b/sql/plan/replication_commands.go index 90ee2e5f4c..a34f9e620a 100644 --- a/sql/plan/replication_commands.go +++ b/sql/plan/replication_commands.go @@ -113,11 +113,6 @@ func (c *ChangeReplicationSource) WithChildren(children ...sql.Node) (sql.Node, return &newNode, nil } -func (c *ChangeReplicationSource) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, - sql.NewDynamicPrivilegedOperation(DynamicPrivilege_ReplicationSlaveAdmin)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ChangeReplicationSource) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -187,11 +182,6 @@ func (c *ChangeReplicationFilter) WithChildren(children ...sql.Node) (sql.Node, return &newNode, nil } -func (c *ChangeReplicationFilter) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, - sql.NewDynamicPrivilegedOperation(DynamicPrivilege_ReplicationSlaveAdmin)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ChangeReplicationFilter) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -247,11 +237,6 @@ func (s *StartReplica) WithChildren(children ...sql.Node) (sql.Node, error) { return &newNode, nil } -func (s *StartReplica) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, - sql.NewDynamicPrivilegedOperation(DynamicPrivilege_ReplicationSlaveAdmin)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*StartReplica) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -307,11 +292,6 @@ func (s *StopReplica) WithChildren(children ...sql.Node) (sql.Node, error) { return &newNode, nil } -func (s *StopReplica) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, - sql.NewDynamicPrivilegedOperation(DynamicPrivilege_ReplicationSlaveAdmin)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*StopReplica) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -375,10 +355,6 @@ func (r *ResetReplica) WithChildren(children ...sql.Node) (sql.Node, error) { return &newNode, nil } -func (r *ResetReplica) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Reload)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ResetReplica) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/resolved_table.go b/sql/plan/resolved_table.go index 78529e34ac..15e47eea87 100644 --- a/sql/plan/resolved_table.go +++ b/sql/plan/resolved_table.go @@ -225,26 +225,6 @@ func (t *ResolvedTable) WithChildren(children ...sql.Node) (sql.Node, error) { return t, nil } -// CheckPrivileges implements the interface sql.Node. -func (t *ResolvedTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // It is assumed that if we've landed upon this node, then we're doing a SELECT operation. Most other nodes that - // may contain a TableNode will have their own privilege checks, so we should only end up here if the parent - // nodes are things such as indexed access, filters, limits, etc. - if IsDualTable(t) { - return true - } - - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(t.SqlDatabase), - Table: t.Table.Name(), - } - if subject.Database == sql.InformationSchemaDatabaseName { - return true - } - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ResolvedTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/revoke.go b/sql/plan/revoke.go index c15007a38e..d718d39e63 100644 --- a/sql/plan/revoke.go +++ b/sql/plan/revoke.go @@ -36,6 +36,7 @@ type Revoke struct { var _ sql.Node = (*Revoke)(nil) var _ sql.Databaser = (*Revoke)(nil) var _ sql.CollationCoercible = (*Revoke)(nil) +var _ sql.AuthorizationCheckerNode = (*Revoke)(nil) // Schema implements the interface sql.Node. func (n *Revoke) Schema() sql.Schema { @@ -86,8 +87,8 @@ func (n *Revoke) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *Revoke) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { +// CheckAuth implements the interface sql.AuthorizationCheckerNode. +func (n *Revoke) CheckAuth(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { subject := sql.PrivilegeCheckSubject{Database: "mysql"} if opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Update)) { @@ -431,6 +432,7 @@ type RevokeAll struct { var _ sql.Node = (*RevokeAll)(nil) var _ sql.CollationCoercible = (*RevokeAll)(nil) +var _ sql.AuthorizationCheckerNode = (*RevokeAll)(nil) // NewRevokeAll returns a new RevokeAll node. func NewRevokeAll(users []UserName) *RevokeAll { @@ -475,8 +477,8 @@ func (n *RevokeAll) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *RevokeAll) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { +// CheckAuth implements the interface sql.AuthorizationCheckerNode. +func (n *RevokeAll) CheckAuth(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { createUser := sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser) superUser := sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Super) @@ -503,6 +505,7 @@ type RevokeRole struct { var _ sql.Node = (*RevokeRole)(nil) var _ sql.Databaser = (*RevokeRole)(nil) var _ sql.CollationCoercible = (*RevokeRole)(nil) +var _ sql.AuthorizationCheckerNode = (*RevokeRole)(nil) // NewRevokeRole returns a new RevokeRole node. func NewRevokeRole(roles []UserName, users []UserName) *RevokeRole { @@ -566,8 +569,8 @@ func (n *RevokeRole) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *RevokeRole) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { +// CheckAuth implements the interface sql.AuthorizationCheckerNode. +func (n *RevokeRole) CheckAuth(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { if opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Super)) { return true @@ -616,6 +619,7 @@ type RevokeProxy struct { var _ sql.Node = (*RevokeProxy)(nil) var _ sql.CollationCoercible = (*RevokeProxy)(nil) +var _ sql.AuthorizationCheckerNode = (*RevokeProxy)(nil) // NewRevokeProxy returns a new RevokeProxy node. func NewRevokeProxy(on UserName, from []UserName) *RevokeProxy { @@ -661,8 +665,8 @@ func (n *RevokeProxy) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *RevokeProxy) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { +// CheckAuth implements the interface sql.AuthorizationCheckerNode. +func (n *RevokeProxy) CheckAuth(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { //TODO: add this when proxy support is added return true } diff --git a/sql/plan/set.go b/sql/plan/set.go index bb0ea4a93b..51e22d06cd 100644 --- a/sql/plan/set.go +++ b/sql/plan/set.go @@ -58,12 +58,6 @@ func (s *Set) WithChildren(children ...sql.Node) (sql.Node, error) { return s, nil } -// CheckPrivileges implements the interface sql.Node. -func (s *Set) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: determine which variables cannot be set without a privilege check - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Set) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/set_op.go b/sql/plan/set_op.go index 0269f71abc..0969971432 100644 --- a/sql/plan/set_op.go +++ b/sql/plan/set_op.go @@ -198,11 +198,6 @@ func (s *SetOp) WithChildren(children ...sql.Node) (sql.Node, error) { return &ret, nil } -// CheckPrivileges implements the interface sql.Node. -func (s *SetOp) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return s.left.CheckPrivileges(ctx, opChecker) && s.right.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*SetOp) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { // Unions are able to return differing values, therefore they cannot be used to determine coercibility diff --git a/sql/plan/show_binlog_status.go b/sql/plan/show_binlog_status.go index f46abaafff..9cbbb116a2 100644 --- a/sql/plan/show_binlog_status.go +++ b/sql/plan/show_binlog_status.go @@ -78,10 +78,6 @@ func (s *ShowBinlogStatus) WithChildren(children ...sql.Node) (sql.Node, error) return &newNode, nil } -func (s *ShowBinlogStatus) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_ReplicationClient)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowBinlogStatus) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_binlogs.go b/sql/plan/show_binlogs.go index 5ad52d3ac3..8bdc8707c2 100644 --- a/sql/plan/show_binlogs.go +++ b/sql/plan/show_binlogs.go @@ -76,10 +76,6 @@ func (s *ShowBinlogs) WithChildren(children ...sql.Node) (sql.Node, error) { return &newNode, nil } -func (s *ShowBinlogs) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_ReplicationClient)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowBinlogs) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_charset.go b/sql/plan/show_charset.go index 4e7ac6357a..79385a23bb 100644 --- a/sql/plan/show_charset.go +++ b/sql/plan/show_charset.go @@ -49,11 +49,6 @@ func (sc *ShowCharset) WithChildren(children ...sql.Node) (sql.Node, error) { return sc, nil } -// CheckPrivileges implements the interface sql.Node. -func (sc *ShowCharset) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - func (sc *ShowCharset) IsReadOnly() bool { return true } diff --git a/sql/plan/show_create_database.go b/sql/plan/show_create_database.go index 5fa45cebc3..3bc1eee9fc 100644 --- a/sql/plan/show_create_database.go +++ b/sql/plan/show_create_database.go @@ -84,12 +84,6 @@ func (s *ShowCreateDatabase) WithChildren(children ...sql.Node) (sql.Node, error return s, nil } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowCreateDatabase) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // The database won't be visible during the resolution step if the user doesn't have the correct privileges - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowCreateDatabase) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_create_event.go b/sql/plan/show_create_event.go index e4502de2eb..f355c1b301 100644 --- a/sql/plan/show_create_event.go +++ b/sql/plan/show_create_event.go @@ -78,12 +78,6 @@ func (s *ShowCreateEvent) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(s, children...) } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowCreateEvent) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // TODO: figure out what privileges are needed here - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowCreateEvent) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_create_procedure.go b/sql/plan/show_create_procedure.go index ac49b78431..22f4ff03fe 100644 --- a/sql/plan/show_create_procedure.go +++ b/sql/plan/show_create_procedure.go @@ -131,23 +131,6 @@ func (s *ShowCreateProcedure) WithChildren(children ...sql.Node) (sql.Node, erro return NillaryWithChildren(s, children...) } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowCreateProcedure) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // TODO: set definer - // TODO: dynamic privilege SHOW ROUTINE - // According to: https://dev.mysql.com/doc/refman/8.0/en/show-create-procedure.html - // Must have Global SELECT, SHOW_ROUTINE, CREATE_ROUTINE, ALTER_ROUTINE, or EXECUTE privileges. - - dbSubject := sql.PrivilegeCheckSubject{ - Database: s.db.Name(), - } - - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Select)) || - opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(dbSubject, sql.PrivilegeType_CreateRoutine)) || - opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(dbSubject, sql.PrivilegeType_AlterRoutine)) || - opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(dbSubject, sql.PrivilegeType_Execute)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowCreateProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_create_table.go b/sql/plan/show_create_table.go index d7ab9b22d9..0eef39f838 100644 --- a/sql/plan/show_create_table.go +++ b/sql/plan/show_create_table.go @@ -93,12 +93,6 @@ func (sc ShowCreateTable) WithChildren(children ...sql.Node) (sql.Node, error) { return &sc, nil } -// CheckPrivileges implements the interface sql.Node. -func (sc *ShowCreateTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // The table won't be visible during the resolution step if the user doesn't have the correct privileges - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowCreateTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_create_trigger.go b/sql/plan/show_create_trigger.go index 50abb43fab..9f32d7dcef 100644 --- a/sql/plan/show_create_trigger.go +++ b/sql/plan/show_create_trigger.go @@ -79,12 +79,6 @@ func (s *ShowCreateTrigger) WithChildren(children ...sql.Node) (sql.Node, error) return NillaryWithChildren(s, children...) } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowCreateTrigger) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: figure out what privileges are needed here - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowCreateTrigger) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_events.go b/sql/plan/show_events.go index f2cda094c9..95337bfd89 100644 --- a/sql/plan/show_events.go +++ b/sql/plan/show_events.go @@ -158,12 +158,6 @@ func (s *ShowEvents) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(s, children...) } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowEvents) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: figure out what privileges are needed here - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowEvents) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_grants.go b/sql/plan/show_grants.go index 26f39a6438..94a6898c5f 100644 --- a/sql/plan/show_grants.go +++ b/sql/plan/show_grants.go @@ -98,16 +98,6 @@ func (n *ShowGrants) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *ShowGrants) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - if n.CurrentUser { - return true - } - - subject := sql.PrivilegeCheckSubject{Database: "mysql"} - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowGrants) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_indexes.go b/sql/plan/show_indexes.go index 321563d952..840ea0d9b8 100644 --- a/sql/plan/show_indexes.go +++ b/sql/plan/show_indexes.go @@ -49,12 +49,6 @@ func (n *ShowIndexes) WithChildren(children ...sql.Node) (sql.Node, error) { }, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *ShowIndexes) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: figure out what privileges are required - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowIndexes) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_privileges.go b/sql/plan/show_privileges.go index 40534e8769..b0433cf48b 100644 --- a/sql/plan/show_privileges.go +++ b/sql/plan/show_privileges.go @@ -66,11 +66,6 @@ func (n *ShowPrivileges) WithChildren(children ...sql.Node) (sql.Node, error) { return n, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *ShowPrivileges) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowPrivileges) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_replica_status.go b/sql/plan/show_replica_status.go index 73cf7fa69b..efa5ed1439 100644 --- a/sql/plan/show_replica_status.go +++ b/sql/plan/show_replica_status.go @@ -128,10 +128,6 @@ func (s *ShowReplicaStatus) WithChildren(children ...sql.Node) (sql.Node, error) return &newNode, nil } -func (s *ShowReplicaStatus) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_ReplicationClient)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowReplicaStatus) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_status.go b/sql/plan/show_status.go index c31dc00e68..5bca6c02fa 100644 --- a/sql/plan/show_status.go +++ b/sql/plan/show_status.go @@ -111,11 +111,6 @@ func (s *ShowStatus) WithChildren(node ...sql.Node) (sql.Node, error) { return NewShowStatus(s.isGlobal), nil } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowStatus) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowStatus) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_tables.go b/sql/plan/show_tables.go index a587ff2d55..7c84779d49 100644 --- a/sql/plan/show_tables.go +++ b/sql/plan/show_tables.go @@ -105,12 +105,6 @@ func (p *ShowTables) WithChildren(children ...sql.Node) (sql.Node, error) { return p, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *ShowTables) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // Some tables won't be visible during the resolution step if the user doesn't have the correct privileges - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowTables) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/show_triggers.go b/sql/plan/show_triggers.go index ca9a4fcccb..e01ceb4422 100644 --- a/sql/plan/show_triggers.go +++ b/sql/plan/show_triggers.go @@ -79,12 +79,6 @@ func (s *ShowTriggers) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(s, children...) } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowTriggers) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: figure out what privileges are needed here - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowTriggers) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/showcolumns.go b/sql/plan/showcolumns.go index 890bb875f7..79db2f985a 100644 --- a/sql/plan/showcolumns.go +++ b/sql/plan/showcolumns.go @@ -124,12 +124,6 @@ func (s *ShowColumns) WithChildren(children ...sql.Node) (sql.Node, error) { return &ss, nil } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowColumns) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // The table won't be visible during the resolution step if the user doesn't have the correct privileges - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowColumns) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/showdatabases.go b/sql/plan/showdatabases.go index 4ba77e2ae0..52f00d42c0 100644 --- a/sql/plan/showdatabases.go +++ b/sql/plan/showdatabases.go @@ -64,13 +64,6 @@ func (p *ShowDatabases) WithChildren(children ...sql.Node) (sql.Node, error) { return p, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *ShowDatabases) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: Having the "SHOW DATABASES" privilege should allow one to see all databases - // Currently, only shows databases that the user has access to - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowDatabases) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/showtablestatus.go b/sql/plan/showtablestatus.go index 7f94a24346..74a4c0d0ec 100644 --- a/sql/plan/showtablestatus.go +++ b/sql/plan/showtablestatus.go @@ -91,12 +91,6 @@ func (s *ShowTableStatus) WithChildren(children ...sql.Node) (sql.Node, error) { return s, nil } -// CheckPrivileges implements the interface sql.Node. -func (s *ShowTableStatus) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // Some tables won't be visible in RowIter if the user doesn't have the correct privileges - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowTableStatus) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/showvariables.go b/sql/plan/showvariables.go index cf5ca1b12a..55a8670a1b 100644 --- a/sql/plan/showvariables.go +++ b/sql/plan/showvariables.go @@ -58,11 +58,6 @@ func (sv *ShowVariables) WithChildren(children ...sql.Node) (sql.Node, error) { return sv, nil } -// CheckPrivileges implements the interface sql.Node. -func (sv *ShowVariables) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*ShowVariables) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/showwarnings.go b/sql/plan/showwarnings.go index 410861f5e7..7132fd4e1a 100644 --- a/sql/plan/showwarnings.go +++ b/sql/plan/showwarnings.go @@ -39,11 +39,6 @@ func (sw ShowWarnings) WithChildren(children ...sql.Node) (sql.Node, error) { return sw, nil } -// CheckPrivileges implements the interface sql.Node. -func (sw ShowWarnings) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (ShowWarnings) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/signal.go b/sql/plan/signal.go index 21c0f9b931..0ec23483cd 100644 --- a/sql/plan/signal.go +++ b/sql/plan/signal.go @@ -254,11 +254,6 @@ func (s Signal) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { return &s, nil } -// CheckPrivileges implements the interface sql.Node. -func (s *Signal) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Signal) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 @@ -344,11 +339,6 @@ func (s *SignalName) WithChildren(children ...sql.Node) (sql.Node, error) { return NillaryWithChildren(s, children...) } -// CheckPrivileges implements the interface sql.Node. -func (s *SignalName) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*SignalName) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/sort.go b/sql/plan/sort.go index cc01724a0a..c780e602c1 100644 --- a/sql/plan/sort.go +++ b/sql/plan/sort.go @@ -100,11 +100,6 @@ func (s *Sort) WithChildren(children ...sql.Node) (sql.Node, error) { return NewSort(s.SortFields, children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (s *Sort) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return s.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (s *Sort) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, s.Child) @@ -205,11 +200,6 @@ func (n *TopN) WithChildren(children ...sql.Node) (sql.Node, error) { return topn, nil } -// CheckPrivileges implements the interface sql.Node. -func (n *TopN) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return n.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (n *TopN) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, n.Child) diff --git a/sql/plan/spatial_ref.go b/sql/plan/spatial_ref.go index ca1bd40b45..e2ba975b24 100644 --- a/sql/plan/spatial_ref.go +++ b/sql/plan/spatial_ref.go @@ -92,13 +92,3 @@ func (n *CreateSpatialRefSys) WithChildren(children ...sql.Node) (sql.Node, erro nn := *n return &nn, nil } - -// CheckPrivileges implements the interface sql.Node -func (n *CreateSpatialRefSys) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: "mysql", - Table: "st_spatial_references_systems", - } - - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Insert)) -} diff --git a/sql/plan/str_expr.go b/sql/plan/str_expr.go index 37660e1105..8567bd991d 100644 --- a/sql/plan/str_expr.go +++ b/sql/plan/str_expr.go @@ -55,11 +55,6 @@ func (s *StrExpr) WithChildren(children ...sql.Node) (sql.Node, error) { panic("implement me") } -func (s *StrExpr) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO implement me - panic("implement me") -} - func (s *StrExpr) IsReadOnly() bool { //TODO implement me panic("implement me") diff --git a/sql/plan/subquery.go b/sql/plan/subquery.go index a1d26157c8..2d0d5f7147 100644 --- a/sql/plan/subquery.go +++ b/sql/plan/subquery.go @@ -101,11 +101,6 @@ func (srn *StripRowNode) WithChildren(children ...sql.Node) (sql.Node, error) { return NewStripRowNode(children[0], srn.NumCols), nil } -// CheckPrivileges implements the interface sql.Node. -func (srn *StripRowNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return srn.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (srn *StripRowNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, srn.Child) @@ -149,11 +144,6 @@ func (p *PrependNode) WithChildren(children ...sql.Node) (sql.Node, error) { return NewPrependNode(children[0], p.Row), nil } -// CheckPrivileges implements the interface sql.Node. -func (p *PrependNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return p.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (p *PrependNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, p.Child) @@ -317,10 +307,6 @@ func (m *Max1Row) WithChildren(children ...sql.Node) (sql.Node, error) { return &ret, nil } -func (m *Max1Row) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return m.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (m *Max1Row) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, m.Child) diff --git a/sql/plan/subqueryalias.go b/sql/plan/subqueryalias.go index d0fa43dc15..0f444a0d9f 100644 --- a/sql/plan/subqueryalias.go +++ b/sql/plan/subqueryalias.go @@ -120,11 +120,6 @@ func (sq *SubqueryAlias) WithChildren(children ...sql.Node) (sql.Node, error) { return &nn, nil } -// CheckPrivileges implements the interface sql.Node. -func (sq *SubqueryAlias) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return sq.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (sq *SubqueryAlias) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, sq.Child) diff --git a/sql/plan/table_copier.go b/sql/plan/table_copier.go index b3744305b0..dafb41875f 100644 --- a/sql/plan/table_copier.go +++ b/sql/plan/table_copier.go @@ -137,14 +137,6 @@ func (tc *TableCopier) WithChildren(...sql.Node) (sql.Node, error) { return tc, nil } -// CheckPrivileges implements the interface sql.Node. -func (tc *TableCopier) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: add a new branch when the INSERT optimization is added - subject := sql.PrivilegeCheckSubject{Database: tc.db.Name()} - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Create)) && - tc.Source.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*TableCopier) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/table_count.go b/sql/plan/table_count.go index 083feac51d..e7c705e586 100644 --- a/sql/plan/table_count.go +++ b/sql/plan/table_count.go @@ -77,7 +77,3 @@ func (t *TableCountLookup) Children() []sql.Node { func (t *TableCountLookup) WithChildren(_ ...sql.Node) (sql.Node, error) { return t, nil } - -func (t *TableCountLookup) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool { - return true -} diff --git a/sql/plan/tablealias.go b/sql/plan/tablealias.go index 946a700286..de92e1fd3f 100644 --- a/sql/plan/tablealias.go +++ b/sql/plan/tablealias.go @@ -110,14 +110,6 @@ func (t *TableAlias) WithChildren(children ...sql.Node) (sql.Node, error) { return ret, nil } -// CheckPrivileges implements the interface sql.Node. -func (t *TableAlias) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - if t.UnaryNode != nil { - return t.UnaryNode.Child.CheckPrivileges(ctx, opChecker) - } - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (t *TableAlias) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { if t.UnaryNode != nil { diff --git a/sql/plan/transaction.go b/sql/plan/transaction.go index 1ad2358b98..f3b7d5e1b9 100644 --- a/sql/plan/transaction.go +++ b/sql/plan/transaction.go @@ -27,11 +27,6 @@ func (transactionNode) Children() []sql.Node { return nil } -// CheckPrivileges implements the interface sql.Node. -func (transactionNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*transactionNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/transformed_named_node.go b/sql/plan/transformed_named_node.go index 5863e0d067..5f69671ee0 100644 --- a/sql/plan/transformed_named_node.go +++ b/sql/plan/transformed_named_node.go @@ -52,11 +52,6 @@ func (n *TransformedNamedNode) WithChildren(children ...sql.Node) (sql.Node, err return NewTransformedNamedNode(children[0], n.name), nil } -// CheckPrivileges implements the interface sql.Node. -func (n *TransformedNamedNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return n.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (n *TransformedNamedNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, n.Child) diff --git a/sql/plan/trigger.go b/sql/plan/trigger.go index d65bac4b31..2aab41f1ea 100644 --- a/sql/plan/trigger.go +++ b/sql/plan/trigger.go @@ -88,17 +88,6 @@ func (t *TriggerExecutor) WithChildren(children ...sql.Node) (sql.Node, error) { return NewTriggerExecutor(children[0], children[1], t.TriggerEvent, t.TriggerTime, t.TriggerDefinition), nil } -// CheckPrivileges implements the interface sql.Node. -func (t *TriggerExecutor) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // TODO: Figure out exactly how triggers work, not exactly clear whether trigger creator AND user needs the privileges - subject := sql.PrivilegeCheckSubject{ - Database: GetDatabaseName(t.right), - Table: getTableName(t.right), - } - return t.left.CheckPrivileges(ctx, opChecker) && - opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Trigger)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (t *TriggerExecutor) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, t.left) diff --git a/sql/plan/trigger_begin_end_block.go b/sql/plan/trigger_begin_end_block.go index 2d83a24554..a0370d7cf9 100644 --- a/sql/plan/trigger_begin_end_block.go +++ b/sql/plan/trigger_begin_end_block.go @@ -43,11 +43,6 @@ func (b *TriggerBeginEndBlock) WithChildren(children ...sql.Node) (sql.Node, err return NewTriggerBeginEndBlock(NewBeginEndBlock(b.BeginEndBlock.Label, NewBlock(children))), nil } -// CheckPrivileges implements the interface sql.Node. -func (b *TriggerBeginEndBlock) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return b.Block.CheckPrivileges(ctx, opChecker) -} - // WithParamReference implements the interface expression.ProcedureReferencable. func (b *TriggerBeginEndBlock) WithParamReference(pRef *expression.ProcedureReference) sql.Node { nb := *b diff --git a/sql/plan/truncate.go b/sql/plan/truncate.go index 7f84dbe1da..17c92ee932 100644 --- a/sql/plan/truncate.go +++ b/sql/plan/truncate.go @@ -93,17 +93,6 @@ func (p *Truncate) WithChildren(children ...sql.Node) (sql.Node, error) { return &nt, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *Truncate) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: p.db, - Table: getTableName(p.Child), - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Drop)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Truncate) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/unresolved.go b/sql/plan/unresolved.go index 9320867a50..fe6cc838b8 100644 --- a/sql/plan/unresolved.go +++ b/sql/plan/unresolved.go @@ -110,17 +110,6 @@ func (t *UnresolvedTable) AsOf() sql.Expression { return t.asOf } -// CheckPrivileges implements the interface sql.Node. -func (t *UnresolvedTable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - subject := sql.PrivilegeCheckSubject{ - Database: t.Database().Name(), - Table: t.name, - } - - return opChecker.UserHasPrivileges(ctx, - sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Select)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*UnresolvedTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/update.go b/sql/plan/update.go index 1d5f519bcc..c3ac76cce0 100644 --- a/sql/plan/update.go +++ b/sql/plan/update.go @@ -181,19 +181,6 @@ func (u *Update) WithChildren(children ...sql.Node) (sql.Node, error) { return &np, nil } -// CheckPrivileges implements the interface sql.Node. -func (u *Update) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - //TODO: If column values are retrieved then the SELECT privilege is required - // For example: "UPDATE table SET x = y + 1 WHERE z > 0" - // We would need SELECT privileges on both the "y" and "z" columns as they're retrieving values - subject := sql.PrivilegeCheckSubject{ - Database: CheckPrivilegeNameForDatabase(u.DB()), - Table: getTableName(u.Child), - } - // TODO: this needs a real database, fix it - return opChecker.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Update)) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Update) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/update_join.go b/sql/plan/update_join.go index d997d119ee..96984b201c 100644 --- a/sql/plan/update_join.go +++ b/sql/plan/update_join.go @@ -71,11 +71,6 @@ func (u *UpdateJoin) IsReadOnly() bool { return false } -// CheckPrivileges implements the interface sql.Node. -func (u *UpdateJoin) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return u.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (u *UpdateJoin) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, u.Child) diff --git a/sql/plan/update_source.go b/sql/plan/update_source.go index 0b894670f9..4a10ed8f97 100644 --- a/sql/plan/update_source.go +++ b/sql/plan/update_source.go @@ -134,11 +134,6 @@ func (u *UpdateSource) WithChildren(children ...sql.Node) (sql.Node, error) { return NewUpdateSource(children[0], u.Ignore, u.UpdateExprs), nil } -// CheckPrivileges implements the interface sql.Node. -func (u *UpdateSource) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return u.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (u *UpdateSource) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, u.Child) diff --git a/sql/plan/use.go b/sql/plan/use.go index 8ea881100c..524f5c4e21 100644 --- a/sql/plan/use.go +++ b/sql/plan/use.go @@ -90,13 +90,6 @@ func (u *Use) WithChildren(children ...sql.Node) (sql.Node, error) { return u, nil } -// CheckPrivileges implements the interface sql.Node. -func (u *Use) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - // The given database will not be visible if the user does not have the appropriate privileges, so we can just - // return true here. - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Use) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/values.go b/sql/plan/values.go index 51c2109dc9..9368daafbc 100644 --- a/sql/plan/values.go +++ b/sql/plan/values.go @@ -163,11 +163,6 @@ func (p *Values) WithChildren(children ...sql.Node) (sql.Node, error) { return p, nil } -// CheckPrivileges implements the interface sql.Node. -func (p *Values) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*Values) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/sql/plan/window.go b/sql/plan/window.go index a69f705db4..eb49cdfaa1 100644 --- a/sql/plan/window.go +++ b/sql/plan/window.go @@ -93,11 +93,6 @@ func (w *Window) WithChildren(children ...sql.Node) (sql.Node, error) { return NewWindow(w.SelectExprs, children[0]), nil } -// CheckPrivileges implements the interface sql.Node. -func (w *Window) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return w.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (w *Window) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, w.Child) diff --git a/sql/plan/with.go b/sql/plan/with.go index fb1f40ef16..d6da3df35a 100644 --- a/sql/plan/with.go +++ b/sql/plan/with.go @@ -85,11 +85,6 @@ func (w *With) WithChildren(children ...sql.Node) (sql.Node, error) { return NewWith(children[0], w.CTEs, w.Recursive), nil } -// CheckPrivileges implements the interface sql.Node. -func (w *With) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return w.Child.CheckPrivileges(ctx, opChecker) -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (w *With) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.GetCoercibility(ctx, w.Child) diff --git a/sql/planbuilder/auth_default.go b/sql/planbuilder/auth_default.go new file mode 100644 index 0000000000..6299593ffc --- /dev/null +++ b/sql/planbuilder/auth_default.go @@ -0,0 +1,510 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package planbuilder + +import ( + "fmt" + "strconv" + "strings" + + "github.com/dolthub/vitess/go/mysql" + ast "github.com/dolthub/vitess/go/vt/sqlparser" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/mysql_db" +) + +// defaultAuthorizationQueryState contains query-specific state for defaultAuthorizationHandler. +type defaultAuthorizationQueryState struct { + enabled bool + db *mysql_db.MySQLDb + user *mysql_db.User + privSet mysql_db.PrivilegeSet + err error +} + +var _ sql.AuthorizationQueryState = defaultAuthorizationQueryState{} + +// AuthorizationQueryStateImpl implements the AuthorizationQueryState interface. +func (state defaultAuthorizationQueryState) Error() error { + return state.err +} + +// AuthorizationQueryStateImpl implements the AuthorizationQueryState interface. +func (defaultAuthorizationQueryState) AuthorizationQueryStateImpl() {} + +// defaultAuthorizationHandlerFactory is the AuthorizationHandlerFactory for defaultAuthorizationHandler. +type defaultAuthorizationHandlerFactory struct{} + +var _ sql.AuthorizationHandlerFactory = defaultAuthorizationHandlerFactory{} + +// CreateHandler implements the AuthorizationHandlerFactory interface. +func (defaultAuthorizationHandlerFactory) CreateHandler(cat sql.Catalog) sql.AuthorizationHandler { + return defaultAuthorizationHandler{ + cat: cat, + } +} + +// defaultAuthorizationHandler handles authorization for ASTs that were generated directly by the Vitess SQL parser. +type defaultAuthorizationHandler struct { + cat sql.Catalog +} + +var _ sql.AuthorizationHandler = defaultAuthorizationHandler{} + +// NewQueryState implements the AuthorizationHandler interface. +func (h defaultAuthorizationHandler) NewQueryState(ctx *sql.Context) sql.AuthorizationQueryState { + db, err := h.cat.Database(ctx, "mysql") + if err != nil { + // If we can't load the MySQL database, then we'll assume that it's been disabled + return defaultAuthorizationQueryState{ + enabled: false, + } + } + mysqlDb, ok := db.(*mysql_db.MySQLDb) + if !ok { + // If we can't load the MySQL database, then we'll assume that it's been disabled + return defaultAuthorizationQueryState{ + enabled: false, + } + } + var user *mysql_db.User + var privSet mysql_db.PrivilegeSet + enabled := mysqlDb.Enabled() + if enabled { + client := ctx.Session.Client() + user = func() *mysql_db.User { + rd := mysqlDb.Reader() + defer rd.Close() + return mysqlDb.GetUser(rd, client.User, client.Address, false) + }() + if user == nil { + return defaultAuthorizationQueryState{ + err: mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%s'", client.User), + } + } + privSet = mysqlDb.UserActivePrivilegeSet(ctx) + } + return defaultAuthorizationQueryState{ + enabled: enabled, + db: mysqlDb, + user: user, + privSet: privSet, + err: nil, + } +} + +// HandleAuth implements the AuthorizationHandler interface. +func (h defaultAuthorizationHandler) HandleAuth(ctx *sql.Context, aqs sql.AuthorizationQueryState, auth ast.AuthInformation) error { + if aqs == nil { + aqs = h.NewQueryState(ctx) + } + state := aqs.(defaultAuthorizationQueryState) + if state.err != nil || !state.enabled { + return state.err + } + + var err error + hasPrivileges := true + var privilegeTypes []sql.PrivilegeType + switch auth.AuthType { + case ast.AuthType_IGNORE: + // This means that authorization is being handled elsewhere (such as a child or parent), and should be ignored here + return nil + case ast.AuthType_ALTER: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Alter} + case ast.AuthType_ALTER_ROUTINE: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_AlterRoutine} + case ast.AuthType_ALTER_USER: + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{Database: "mysql"}, sql.PrivilegeType_Update)) || + state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) || + state.user.User == auth.TargetNames[0] + case ast.AuthType_CALL: + hasPrivileges, err = h.call(ctx, state, auth) + if err != nil { + return err + } + case ast.AuthType_CREATE: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Create} + case ast.AuthType_CREATE_ROLE: + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateRole)) || + state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) + case ast.AuthType_CREATE_ROUTINE: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_CreateRoutine} + case ast.AuthType_CREATE_TEMP: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_CreateTempTable} + case ast.AuthType_CREATE_USER: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_CreateUser} + case ast.AuthType_CREATE_VIEW: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_CreateView} + case ast.AuthType_DELETE: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Delete} + case ast.AuthType_DROP: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Drop} + case ast.AuthType_DROP_ROLE: + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_DropRole)) || + state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_CreateUser)) + case ast.AuthType_EVENT: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Event} + case ast.AuthType_FILE: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_File} + case ast.AuthType_FOREIGN_KEY: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_References} + case ast.AuthType_GRANT_PRIVILEGE: + hasPrivileges = h.grantAndRevoke(ctx, state, auth) + case ast.AuthType_GRANT_PROXY: + hasPrivileges = h.grantAndRevoke(ctx, state, auth) + case ast.AuthType_GRANT_ROLE: + hasPrivileges = h.grantAndRevoke(ctx, state, auth) + case ast.AuthType_INDEX: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Index} + case ast.AuthType_INSERT: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Insert} + case ast.AuthType_LOCK: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Select, sql.PrivilegeType_LockTables} + case ast.AuthType_PROCESS: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Process} + case ast.AuthType_RELOAD: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Reload} + case ast.AuthType_RENAME: + hasPrivileges, err = h.renameTables(ctx, state, auth) + if err != nil { + return err + } + case ast.AuthType_REPLACE: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Insert, sql.PrivilegeType_Delete} + case ast.AuthType_REPLICATION: + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewDynamicPrivilegedOperation("replication_slave_admin")) + case ast.AuthType_REPLICATION_CLIENT: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_ReplicationClient} + case ast.AuthType_REVOKE_ALL: + hasPrivileges = h.grantAndRevoke(ctx, state, auth) + case ast.AuthType_REVOKE_PRIVILEGE: + hasPrivileges = h.grantAndRevoke(ctx, state, auth) + case ast.AuthType_REVOKE_PROXY: + hasPrivileges = h.grantAndRevoke(ctx, state, auth) + case ast.AuthType_REVOKE_ROLE: + hasPrivileges = h.grantAndRevoke(ctx, state, auth) + case ast.AuthType_SELECT: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Select} + case ast.AuthType_SHOW: + // This a placeholder for some of the SHOW commands, as we don't yet know what permissions they should have + case ast.AuthType_SHOW_CREATE_PROCEDURE: + subject := sql.PrivilegeCheckSubject{ + Database: h.authDatabaseName(ctx, auth.TargetNames[0]), + } + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, sql.PrivilegeType_Select)) || + state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_CreateRoutine)) || + state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_AlterRoutine)) || + state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Execute)) + case ast.AuthType_SUPER: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Super} + case ast.AuthType_TRIGGER: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Trigger} + case ast.AuthType_UPDATE: + privilegeTypes = []sql.PrivilegeType{sql.PrivilegeType_Update} + case ast.AuthType_VISIBLE: + hasPrivileges, err = h.visible(ctx, state, &auth) + if err != nil { + return err + } + default: + if len(auth.AuthType) == 0 { + return fmt.Errorf("AuthType is empty") + } else { + return fmt.Errorf("AuthType not handled: `%s`", auth.AuthType) + } + } + + switch auth.TargetType { + case ast.AuthTargetType_Ignore: + // This means that the AuthType did not need a TargetType, so we can safely ignore it + case ast.AuthTargetType_DatabaseIdentifiers: + for i := 0; i < len(auth.TargetNames) && hasPrivileges; i++ { + dbName := auth.TargetNames[i] + if strings.EqualFold(dbName, "information_schema") { + continue + } + if err = h.authCheckDatabaseTableNames(ctx, state, dbName, ""); err != nil { + return err + } + subject := sql.PrivilegeCheckSubject{ + Database: h.authDatabaseName(ctx, dbName), + } + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...)) + } + case ast.AuthTargetType_Global: + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(sql.PrivilegeCheckSubject{}, privilegeTypes...)) && hasPrivileges + case ast.AuthTargetType_MultipleTableIdentifiers: + for i := 0; i < len(auth.TargetNames) && hasPrivileges; i += 2 { + dbName := auth.TargetNames[i] + tableName := auth.TargetNames[i+1] + if strings.EqualFold(dbName, "information_schema") { + continue + } + if err = h.authCheckDatabaseTableNames(ctx, state, dbName, tableName); err != nil { + return err + } + subject := sql.PrivilegeCheckSubject{ + Database: h.authDatabaseName(ctx, dbName), + Table: tableName, + } + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...)) + } + case ast.AuthTargetType_SingleTableIdentifier: + dbName := auth.TargetNames[0] + tableName := auth.TargetNames[1] + if strings.EqualFold(dbName, "information_schema") { + return nil + } + if err = h.authCheckDatabaseTableNames(ctx, state, dbName, tableName); err != nil { + return err + } + subject := sql.PrivilegeCheckSubject{ + Database: h.authDatabaseName(ctx, dbName), + Table: tableName, + } + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...)) && hasPrivileges + case ast.AuthTargetType_TableColumn: + dbName := auth.TargetNames[0] + tableName := auth.TargetNames[1] + colName := auth.TargetNames[2] + if strings.EqualFold(dbName, "information_schema") { + return nil + } + if err = h.authCheckDatabaseTableNames(ctx, state, dbName, tableName); err != nil { + return err + } + subject := sql.PrivilegeCheckSubject{ + Database: h.authDatabaseName(ctx, dbName), + Table: tableName, + Column: colName, + } + hasPrivileges = state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, privilegeTypes...)) && hasPrivileges + case ast.AuthTargetType_TODO: + // This is similar to IGNORE, except we're meant to replace this at some point + default: + if len(auth.TargetType) == 0 { + return fmt.Errorf("TargetType is unexpectedly empty") + } else { + return fmt.Errorf("TargetType not handled: `%s`", auth.TargetType) + } + } + + if !hasPrivileges { + return sql.ErrPrivilegeCheckFailed.New(state.user.UserHostToString("'")) + } + return nil +} + +// HandleAuthNode implements the AuthorizationHandler interface. +func (h defaultAuthorizationHandler) HandleAuthNode(ctx *sql.Context, aqs sql.AuthorizationQueryState, node sql.AuthorizationCheckerNode) error { + if aqs == nil { + aqs = h.NewQueryState(ctx) + } + state := aqs.(defaultAuthorizationQueryState) + if state.err != nil || !state.enabled { + return state.err + } + if !node.CheckAuth(ctx, state.db) { + return sql.ErrPrivilegeCheckFailed.New(state.user.UserHostToString("'")) + } + return nil +} + +// CheckDatabase implements the AuthorizationHandler interface. +func (h defaultAuthorizationHandler) CheckDatabase(ctx *sql.Context, aqs sql.AuthorizationQueryState, dbName string) error { + if aqs == nil { + aqs = h.NewQueryState(ctx) + } + state := aqs.(defaultAuthorizationQueryState) + if state.err != nil || !state.enabled { + return state.err + } + return h.authCheckDatabaseTableNames(ctx, state, dbName, "") +} + +// CheckSchema implements the AuthorizationHandler interface. +func (h defaultAuthorizationHandler) CheckSchema(ctx *sql.Context, aqs sql.AuthorizationQueryState, dbName string, schemaName string) error { + if aqs == nil { + aqs = h.NewQueryState(ctx) + } + state := aqs.(defaultAuthorizationQueryState) + if state.err != nil || !state.enabled { + return state.err + } + // Since MySQL/Vitess doesn't use schemas, this will just check the database only + return h.authCheckDatabaseTableNames(ctx, state, dbName, "") +} + +// CheckTable implements the AuthorizationHandler interface. +func (h defaultAuthorizationHandler) CheckTable(ctx *sql.Context, aqs sql.AuthorizationQueryState, dbName string, schemaName string, tableName string) error { + if aqs == nil { + aqs = h.NewQueryState(ctx) + } + state := aqs.(defaultAuthorizationQueryState) + if state.err != nil || !state.enabled { + return state.err + } + if len(tableName) == 0 { + return sql.ErrTableAccessDeniedForUser.New(state.user.UserHostToString("'"), tableName) + } + // Since MySQL/Vitess doesn't use schemas, it's ignored + return h.authCheckDatabaseTableNames(ctx, state, dbName, tableName) +} + +// call handles the CALL type. +func (h defaultAuthorizationHandler) call(ctx *sql.Context, state defaultAuthorizationQueryState, auth ast.AuthInformation) (bool, error) { + if len(auth.TargetNames) != 3 { + return false, fmt.Errorf("CALL expected 3 TargetNames") + } + dbName := h.authDatabaseName(ctx, auth.TargetNames[0]) + procName := auth.TargetNames[1] + paramCount, err := strconv.Atoi(auth.TargetNames[2]) + if err != nil { + return false, fmt.Errorf("CALL auth encountered error:\n%s", err.Error()) + } + if err = h.authCheckDatabaseTableNames(ctx, state, dbName, ""); err != nil { + return false, err + } + + // Procedure permissions checking is performed in the same way MySQL does it, with an exception where + // procedures which are marked as AdminOnly. These procedures are only accessible to users with explicit Execute + // permissions on the procedure in question. + + adminOnly := false + if h.cat != nil { + proc, err := h.cat.ExternalStoredProcedure(ctx, procName, paramCount) + // Not finding the procedure isn't great - but that's going to surface with a better error later in the + // query execution. For the permission check, we'll proceed as though the procedure exists, and is not AdminOnly. + if proc != nil && err == nil && proc.AdminOnly { + adminOnly = true + } + } + + if !adminOnly { + subject := sql.PrivilegeCheckSubject{ + Database: dbName, + } + if state.db.UserHasPrivileges(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Execute)) { + return true, nil + } + } + + subject := sql.PrivilegeCheckSubject{ + Database: dbName, + Routine: procName, + IsProcedure: true, + } + return state.db.RoutineAdminCheck(ctx, sql.NewPrivilegedOperation(subject, sql.PrivilegeType_Execute)), nil +} + +// grantAndRevoke handles the GRANT and REVOKE types. +func (h defaultAuthorizationHandler) grantAndRevoke(ctx *sql.Context, state defaultAuthorizationQueryState, auth ast.AuthInformation) bool { + // TODO: move all the logic to functions on the handler, rather than deferring to the nodes, but the nodes will still be inputs + node, ok := auth.Extra.(sql.AuthorizationCheckerNode) + if !ok { + return false + } + return node.CheckAuth(ctx, state.db) +} + +// renameTables handles the RENAME type. +func (h defaultAuthorizationHandler) renameTables(ctx *sql.Context, state defaultAuthorizationQueryState, auth ast.AuthInformation) (bool, error) { + // Names are given in groups of 4: from_db, from_table, to_db, to_table + if len(auth.TargetNames)%4 != 0 { + return false, fmt.Errorf("expected tables in groups of 4") + } + var operations []sql.PrivilegedOperation + for i := 0; i < len(auth.TargetNames); i += 4 { + fromSubject := sql.PrivilegeCheckSubject{ + Database: h.authDatabaseName(ctx, auth.TargetNames[i]), + Table: auth.TargetNames[i+1], + } + toSubject := sql.PrivilegeCheckSubject{ + Database: h.authDatabaseName(ctx, auth.TargetNames[i+2]), + Table: auth.TargetNames[i+3], + } + operations = append(operations, + sql.NewPrivilegedOperation(fromSubject, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop), + sql.NewPrivilegedOperation(toSubject, sql.PrivilegeType_Create, sql.PrivilegeType_Insert)) + } + return state.db.UserHasPrivileges(ctx, operations...), nil +} + +// visible handles the VISIBLE type. +func (h defaultAuthorizationHandler) visible(ctx *sql.Context, state defaultAuthorizationQueryState, auth *ast.AuthInformation) (bool, error) { + // We clear the TargetType on the AuthInformation so that it's ignored by later steps + targetType := auth.TargetType + auth.TargetType = ast.AuthTargetType_Ignore + + switch targetType { + case ast.AuthTargetType_DatabaseIdentifiers: + for _, dbName := range auth.TargetNames { + if err := h.authCheckDatabaseTableNames(ctx, state, dbName, ""); err != nil { + return false, err + } + } + return true, nil + case ast.AuthTargetType_TODO: + return true, nil + default: + if len(auth.TargetType) == 0 { + return false, fmt.Errorf("TargetType is unexpectedly empty") + } else { + return false, fmt.Errorf("TargetType not handled: `%s`", auth.TargetType) + } + } +} + +// authDatabaseName uses the current database from the context if a database is not specified, otherwise it returns the +// given database name. +func (h defaultAuthorizationHandler) authDatabaseName(ctx *sql.Context, dbName string) string { + if len(dbName) == 0 { + dbName = ctx.GetCurrentDatabase() + } + // Revision databases take the form "dbname/revision", so we must split the revision from the database name + splitDbName := strings.SplitN(dbName, "/", 2) + return splitDbName[0] +} + +// authCheckDatabaseTableNames errors if the user does not have access to the database or table in any capacity, +// regardless of the command. +func (h defaultAuthorizationHandler) authCheckDatabaseTableNames(ctx *sql.Context, state defaultAuthorizationQueryState, dbName string, tableName string) error { + if strings.EqualFold(dbName, "information_schema") { + return nil + } + dbName = h.authDatabaseName(ctx, dbName) + dbSet := state.privSet.Database(dbName) + // If there are no usable privileges for this database then the table is inaccessible. + if state.privSet.Count() == 0 && !dbSet.HasPrivileges() { + return sql.ErrDatabaseAccessDeniedForUser.New(state.user.UserHostToString("'"), dbName) + } + if len(tableName) > 0 { + tblSet := dbSet.Table(tableName) + // If the user has no global static privileges, database-level privileges, or table-relevant privileges then the + // table is not accessible. + if state.privSet.Count() == 0 && dbSet.Count() == 0 && !tblSet.HasPrivileges() { + return sql.ErrTableAccessDeniedForUser.New(state.user.UserHostToString("'"), tableName) + } + } + return nil +} + +// init sets the factory to use to this one by default. If this is changed, it will be from an integrator, and therefore +// this is guaranteed to run before the integrator's init function (if one is used). +func init() { + sql.SetAuthorizationHandlerFactory(defaultAuthorizationHandlerFactory{}) +} diff --git a/sql/planbuilder/binlog_replication.go b/sql/planbuilder/binlog_replication.go index 43b4cda307..4302e1099f 100644 --- a/sql/planbuilder/binlog_replication.go +++ b/sql/planbuilder/binlog_replication.go @@ -25,6 +25,9 @@ import ( ) func (b *Builder) buildChangeReplicationSource(inScope *scope, n *ast.ChangeReplicationSource) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() convertedOptions := make([]binlogreplication.ReplicationOption, 0, len(n.Options)) for _, option := range n.Options { @@ -64,6 +67,9 @@ func (b *Builder) buildReplicationOption(inScope *scope, option *ast.Replication } func (b *Builder) buildChangeReplicationFilter(inScope *scope, n *ast.ChangeReplicationFilter) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() convertedOptions := make([]binlogreplication.ReplicationOption, 0, len(n.Options)) for _, option := range n.Options { diff --git a/sql/planbuilder/builder.go b/sql/planbuilder/builder.go index 6bf6abd8de..ed3a5aefde 100644 --- a/sql/planbuilder/builder.go +++ b/sql/planbuilder/builder.go @@ -49,9 +49,11 @@ type Builder struct { nesting int // EventScheduler is used to communicate with the event scheduler // for any EVENT related statements. It can be nil if EventScheduler is not defined. - scheduler sql.EventScheduler - parser sql.Parser - qFlags *sql.QueryFlags + scheduler sql.EventScheduler + parser sql.Parser + qFlags *sql.QueryFlags + authEnabled bool + authQueryState sql.AuthorizationQueryState } // BindvarContext holds bind variable replacement literals. @@ -111,14 +113,20 @@ func New(ctx *sql.Context, cat sql.Catalog, es sql.EventScheduler, p sql.Parser) if p == nil { p = sql.NewMysqlParser() } + var state sql.AuthorizationQueryState + if cat != nil { + state = cat.AuthorizationHandler().NewQueryState(ctx) + } return &Builder{ - ctx: ctx, - cat: cat, - scheduler: es, - parserOpts: sql.LoadSqlMode(ctx).ParserOptions(), - f: &factory{}, - parser: p, - qFlags: &sql.QueryFlags{}, + ctx: ctx, + cat: cat, + scheduler: es, + parserOpts: sql.LoadSqlMode(ctx).ParserOptions(), + f: &factory{}, + parser: p, + qFlags: &sql.QueryFlags{}, + authEnabled: true, + authQueryState: state, } } @@ -189,6 +197,7 @@ func (b *Builder) Reset() { b.viewCtx = nil b.nesting = 0 b.qFlags = &sql.QueryFlags{} + b.authQueryState = b.cat.AuthorizationHandler().NewQueryState(b.ctx) } type parseErr struct { @@ -283,6 +292,9 @@ func (b *Builder) buildSubquery(inScope *scope, stmt ast.Statement, subQuery str case *ast.ChangeReplicationFilter: return b.buildChangeReplicationFilter(inScope, n) case *ast.StartReplica: + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() startRep := plan.NewStartReplica() if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { @@ -290,6 +302,9 @@ func (b *Builder) buildSubquery(inScope *scope, stmt ast.Statement, subQuery str } outScope.node = startRep case *ast.StopReplica: + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() stopRep := plan.NewStopReplica() if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { @@ -297,6 +312,9 @@ func (b *Builder) buildSubquery(inScope *scope, stmt ast.Statement, subQuery str } outScope.node = stopRep case *ast.ResetReplica: + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() resetRep := plan.NewResetReplica(n.All) if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.HasBinlogReplicaController() { @@ -415,6 +433,9 @@ func (b *Builder) buildVirtualTableScan(db string, tab sql.Table) *plan.VirtualC // buildInjectedStatement returns the sql.Node encapsulated by the injected statement. func (b *Builder) buildInjectedStatement(inScope *scope, n ast.InjectedStatement) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } resolvedChildren := make([]any, len(n.Children)) for i, child := range n.Children { resolvedChildren[i] = b.buildScalar(inScope, child) @@ -462,3 +483,14 @@ func (b *Builder) BuildColumnDefaultValueWithTable(defExpr ast.Expr, tableExpr a } return b.convertDefaultExpression(outscope, defExpr, typ, nullable) } + +// DisableAuth disables all authorization checks. +func (b *Builder) DisableAuth() { + b.authEnabled = false +} + +// EnableAuth enables all authorization checks. Auth is enabled by default, so this only needs to be called when it was +// previously disabled using DisableAuth. +func (b *Builder) EnableAuth() { + b.authEnabled = true +} diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index 714a8de842..0949749f37 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -94,6 +94,9 @@ func (b *Builder) buildAlterTable(inScope *scope, query string, c *ast.AlterTabl b.multiDDL = false }() + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } statements := make([]sql.Node, 0, len(c.Statements)) for i := 0; i < len(c.Statements); i++ { scopes := b.buildAlterTableClause(inScope, c.Statements[i]) @@ -114,6 +117,10 @@ func (b *Builder) buildAlterTable(inScope *scope, query string, c *ast.AlterTabl } func (b *Builder) buildDDL(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } + outScope = inScope.push() switch strings.ToLower(c.Action) { case ast.CreateStr: @@ -456,6 +463,9 @@ func (b *Builder) isUniqueColumn(tableSpec *ast.TableSpec, columnName string) bo } func (b *Builder) buildAlterTableClause(inScope *scope, ddl *ast.DDL) []*scope { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, ddl.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScopes := make([]*scope, 0, 1) // RENAME a to b, c to d .. @@ -1607,6 +1617,9 @@ func (b *Builder) convertDefaultExpression(inScope *scope, defaultExpr ast.Expr, } func (b *Builder) buildDBDDL(inScope *scope, c *ast.DBDDL) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() switch strings.ToLower(c.Action) { case ast.CreateStr: diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index a9d12100dc..d43a6326a4 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -34,6 +34,9 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) { sql.IncrementStatusVariable(b.ctx, "Com_insert", 1) b.qFlags.Set(sql.QFlagInsert) + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, i.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } if i.With != nil { inScope = b.buildWith(inScope, i.With) } diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index 400a1ab041..51283365b3 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -272,6 +272,9 @@ func (b *Builder) buildDataSource(inScope *scope, te ast.TableExpr) (outScope *s // build individual table, collect column definitions switch t := (te).(type) { case *ast.AliasedTableExpr: + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, t.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } switch e := t.Expr.(type) { case ast.TableName: tableName := strings.ToLower(e.Name.String()) @@ -497,6 +500,11 @@ func (b *Builder) buildTableFunc(inScope *scope, t *ast.TableFuncExpr) (outScope b.handleErr(err) } } + if authCheckerNode, ok := newInstance.(sql.AuthorizationCheckerNode); ok { + if err = b.cat.AuthorizationHandler().HandleAuthNode(b.ctx, b.authQueryState, authCheckerNode); err != nil { + b.handleErr(err) + } + } // Table Function must always have an alias, pick function name as alias if none is provided var name string diff --git a/sql/planbuilder/load.go b/sql/planbuilder/load.go index e89821f88b..bfa3ff2321 100644 --- a/sql/planbuilder/load.go +++ b/sql/planbuilder/load.go @@ -27,6 +27,9 @@ import ( ) func (b *Builder) buildLoad(inScope *scope, d *ast.Load) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, d.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } dbName := strings.ToLower(d.Table.DbQualifier.String()) if dbName == "" { dbName = b.ctx.GetCurrentDatabase() diff --git a/sql/planbuilder/parse.go b/sql/planbuilder/parse.go index 8ebd464af1..b964cc1911 100644 --- a/sql/planbuilder/parse.go +++ b/sql/planbuilder/parse.go @@ -65,6 +65,11 @@ func (b *Builder) Parse(query string, qFlags *sql.QueryFlags, multi bool) (ret s span, ctx := b.ctx.Span("parse", otel.WithAttributes(attribute.String("query", query))) defer span.End() + if b.authQueryState != nil { + if err = b.authQueryState.Error(); err != nil { + b.handleErr(err) + } + } stmt, parsed, remainder, err := b.parser.ParseWithOptions(ctx, query, ';', multi, b.parserOpts) if err != nil { if goerrors.Is(err, ast.ErrEmpty) { @@ -95,6 +100,11 @@ func (b *Builder) BindOnly(stmt ast.Statement, s string, queryFlags *sql.QueryFl } } }() + if b.authQueryState != nil { + if err = b.authQueryState.Error(); err != nil { + b.handleErr(err) + } + } if queryFlags != nil { b.qFlags = queryFlags } diff --git a/sql/planbuilder/priv.go b/sql/planbuilder/priv.go index c14e5167e7..f0cc291763 100644 --- a/sql/planbuilder/priv.go +++ b/sql/planbuilder/priv.go @@ -173,6 +173,9 @@ func (b *Builder) buildAuthenticatedUser(user ast.AccountWithAuth) plan.Authenti } func (b *Builder) buildCreateUser(inScope *scope, n *ast.CreateUser) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() authUsers := make([]plan.AuthenticatedUser, len(n.Users)) for i, user := range n.Users { @@ -301,6 +304,9 @@ func (b *Builder) buildCreateUser(inScope *scope, n *ast.CreateUser) (outScope * } func (b *Builder) buildRenameUser(inScope *scope, n *ast.RenameUser) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } oldNames := make([]plan.UserName, len(n.Accounts)) newNames := make([]plan.UserName, len(n.Accounts)) for i, account := range n.Accounts { @@ -350,11 +356,18 @@ func (b *Builder) buildGrantPrivilege(inScope *scope, n *ast.GrantPrivilege) (ou MySQLDb: b.resolveDb("mysql"), Catalog: b.cat, } + n.Auth.Extra = outScope.node + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } return outScope } func (b *Builder) buildShowGrants(inScope *scope, n *ast.ShowGrants) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } var currentUser bool var user *plan.UserName if n.For != nil { @@ -380,6 +393,9 @@ func (b *Builder) buildShowGrants(inScope *scope, n *ast.ShowGrants) (outScope * } func (b *Builder) buildFlush(inScope *scope, f *ast.Flush) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, f.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() var writesToBinlog = true switch strings.ToLower(f.Type) { @@ -414,6 +430,9 @@ func (b *Builder) buildFlush(inScope *scope, f *ast.Flush) (outScope *scope) { } func (b *Builder) buildCreateRole(inScope *scope, n *ast.CreateRole) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() outScope.node = &plan.CreateRole{ IfNotExists: n.IfNotExists, @@ -424,6 +443,9 @@ func (b *Builder) buildCreateRole(inScope *scope, n *ast.CreateRole) (outScope * } func (b *Builder) buildDropRole(inScope *scope, n *ast.DropRole) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() outScope.node = &plan.DropRole{ IfExists: n.IfExists, @@ -434,6 +456,9 @@ func (b *Builder) buildDropRole(inScope *scope, n *ast.DropRole) (outScope *scop } func (b *Builder) buildDropUser(inScope *scope, n *ast.DropUser) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() outScope.node = &plan.DropUser{ IfExists: n.IfExists, @@ -451,6 +476,10 @@ func (b *Builder) buildGrantRole(inScope *scope, n *ast.GrantRole) (outScope *sc WithAdminOption: n.WithAdminOption, MySQLDb: b.resolveDb("mysql"), } + n.Auth.Extra = outScope.node + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } return } @@ -462,6 +491,10 @@ func (b *Builder) buildGrantProxy(inScope *scope, n *ast.GrantProxy) (outScope * convertAccountName(n.To...), n.WithGrantOption, ) + n.Auth.Extra = outScope.node + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } return } @@ -482,12 +515,20 @@ func (b *Builder) buildRevokePrivilege(inScope *scope, n *ast.RevokePrivilege) ( Users: users, MySQLDb: b.resolveDb("mysql"), } + n.Auth.Extra = outScope.node + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } return } func (b *Builder) buildRevokeAllPrivileges(inScope *scope, n *ast.RevokeAllPrivileges) (outScope *scope) { outScope = inScope.push() outScope.node = plan.NewRevokeAll(convertAccountName(n.From...)) + n.Auth.Extra = outScope.node + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } return } @@ -498,16 +539,26 @@ func (b *Builder) buildRevokeRole(inScope *scope, n *ast.RevokeRole) (outScope * TargetUsers: convertAccountName(n.From...), MySQLDb: b.resolveDb("mysql"), } + n.Auth.Extra = outScope.node + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } return } func (b *Builder) buildRevokeProxy(inScope *scope, n *ast.RevokeProxy) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() outScope.node = plan.NewRevokeProxy(convertAccountName(n.On)[0], convertAccountName(n.From...)) return } func (b *Builder) buildShowPrivileges(inScope *scope, n *ast.ShowPrivileges) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() outScope.node = plan.NewShowPrivileges() return diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index eaf154cb29..c8dcdbfb2c 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -225,6 +225,9 @@ func (b *Builder) buildIfConditional(inScope *scope, n ast.IfStatementCondition, } func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, c.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() params := make([]sql.Expression, len(c.Params)) for i, param := range c.Params { diff --git a/sql/planbuilder/process.go b/sql/planbuilder/process.go index d6b1955b54..251f5a3206 100644 --- a/sql/planbuilder/process.go +++ b/sql/planbuilder/process.go @@ -24,6 +24,9 @@ import ( ) func (b *Builder) buildKill(inScope *scope, kill *ast.Kill) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, kill.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() connID64 := b.getInt64Value(inScope, kill.ConnID, "Error parsing KILL, expected int literal") connID32 := uint32(connID64) diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 2ee0ec8c78..7afc30af79 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -257,6 +257,9 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { } return ret case ast.InjectedExpr: + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, v.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } resolvedChildren := make([]any, len(v.Children)) for i, child := range v.Children { resolvedChildren[i] = b.buildScalar(inScope, child) diff --git a/sql/planbuilder/show.go b/sql/planbuilder/show.go index a67e3c29cd..915d49ba92 100644 --- a/sql/planbuilder/show.go +++ b/sql/planbuilder/show.go @@ -31,6 +31,9 @@ import ( ) func (b *Builder) buildShow(inScope *scope, s *ast.Show) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, s.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } showType := strings.ToLower(s.Type) switch showType { case "processlist": diff --git a/sql/planbuilder/spatial.go b/sql/planbuilder/spatial.go index caa3ba949f..7badb3d428 100644 --- a/sql/planbuilder/spatial.go +++ b/sql/planbuilder/spatial.go @@ -25,6 +25,9 @@ import ( ) func (b *Builder) buildCreateSpatialRefSys(inScope *scope, n *ast.CreateSpatialRefSys) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } outScope = inScope.push() srid, err := strconv.ParseInt(string(n.SRID.Val), 10, 16) if err != nil { diff --git a/sql/planbuilder/transactions.go b/sql/planbuilder/transactions.go index 6404245f74..5e7bc45d55 100644 --- a/sql/planbuilder/transactions.go +++ b/sql/planbuilder/transactions.go @@ -26,6 +26,9 @@ import ( ) func (b *Builder) buildUse(inScope *scope, n *ast.Use) (outScope *scope) { + if err := b.cat.AuthorizationHandler().HandleAuth(b.ctx, b.authQueryState, n.Auth); err != nil && b.authEnabled { + b.handleErr(err) + } name := n.DBName.String() ret := plan.NewUse(b.resolveDb(name)) ret.Catalog = b.cat diff --git a/sql/transform/node_test.go b/sql/transform/node_test.go index edcd83744c..8ecc436100 100644 --- a/sql/transform/node_test.go +++ b/sql/transform/node_test.go @@ -335,10 +335,6 @@ func (n *testNode) WithChildren(nodes ...sql.Node) (sql.Node, error) { return &nn, nil } -func (n *testNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return true -} - // CollationCoercibility implements the interface sql.CollationCoercible. func (*testNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 7 diff --git a/test/test_catalog.go b/test/test_catalog.go index b18849aada..1f94f439f1 100644 --- a/test/test_catalog.go +++ b/test/test_catalog.go @@ -213,3 +213,7 @@ func (c *Catalog) DropDbStats(ctx *sql.Context, db string, flush bool) error { //TODO implement me panic("implement me") } + +func (c *Catalog) AuthorizationHandler() sql.AuthorizationHandler { + return sql.GetAuthorizationHandlerFactory().CreateHandler(c) +}