Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,10 +866,6 @@ func (s SimpleTableFunction) WithChildren(_ ...sql.Node) (sql.Node, error) {
return s, nil
}

func (s SimpleTableFunction) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool {
return true
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (SimpleTableFunction) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 7
Expand Down
3 changes: 3 additions & 0 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,9 @@ func newMergableIndex(dbs []sql.Database, tableName string, exprs ...sql.Express
if db == nil {
return nil
}
if tableRevision, ok := table.(*memory.TableRevision); ok {
table = tableRevision.Table
}
return &memory.Index{
DB: db.Name(),
DriverName: memory.IndexDriverId,
Expand Down
16 changes: 8 additions & 8 deletions enginetest/queries/priv_auth_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1780,10 +1780,10 @@ var UserPrivTests = []UserPrivilegeTest{
},
},
{
User: "rand_user1",
Host: "54.244.85.252",
Query: "SELECT * FROM mydb.test;",
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
User: "rand_user1",
Host: "54.244.85.252",
Query: "SELECT * FROM mydb.test;",
ExpectedErrStr: "Access denied for user 'rand_user1' (errno 1045) (sqlstate 28000)",
},
{
User: "rand_user2",
Expand All @@ -1804,10 +1804,10 @@ var UserPrivTests = []UserPrivilegeTest{
},
},
{
User: "rand_user2",
Host: "54.244.85.252",
Query: "SELECT * FROM mydb.test2;",
ExpectedErr: sql.ErrDatabaseAccessDeniedForUser,
User: "rand_user2",
Host: "54.244.85.252",
Query: "SELECT * FROM mydb.test2;",
ExpectedErrStr: "Access denied for user 'rand_user2' (errno 1045) (sqlstate 28000)",
},
},
},
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20241028204000-267861bc75a0
github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
github.com/gocraft/dbr/v2 v2.7.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20241028204000-267861bc75a0 h1:eeKypNsi1nQmjWxSAAWT6tvRsDWdmll03BozAUUIE4E=
github.com/dolthub/vitess v0.0.0-20241028204000-267861bc75a0/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683 h1:2/RJeUfNAXS7mbBnEr9C36htiCJHk5XldDPzhxtEsME=
github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
Expand Down
4 changes: 0 additions & 4 deletions memory/exponential_dist_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,6 @@ func (s ExponentialDistTable) WithChildren(_ ...sql.Node) (sql.Node, error) {
return s, nil
}

func (s ExponentialDistTable) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool {
return true
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (ExponentialDistTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
Expand Down
4 changes: 0 additions & 4 deletions memory/normal_dist_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,6 @@ func (s NormalDistTable) WithChildren(_ ...sql.Node) (sql.Node, error) {
return s, nil
}

func (s NormalDistTable) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool {
return true
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (NormalDistTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
Expand Down
4 changes: 0 additions & 4 deletions memory/sequence_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,6 @@ func (s IntSequenceTable) WithChildren(_ ...sql.Node) (sql.Node, error) {
return s, nil
}

func (s IntSequenceTable) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool {
return true
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (IntSequenceTable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
Expand Down
4 changes: 0 additions & 4 deletions memory/table_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ func (s TableFunc) WithChildren(_ ...sql.Node) (sql.Node, error) {
return s, nil
}

func (s TableFunc) CheckPrivileges(_ *sql.Context, _ sql.PrivilegedOperationChecker) bool {
return true
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (TableFunc) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
Expand Down
15 changes: 11 additions & 4 deletions sql/analyzer/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type Catalog struct {
StatsProvider sql.StatsProvider

DbProvider sql.DatabaseProvider
AuthHandler sql.AuthorizationHandler
builtInFunctions function.Registry

// BinlogReplicaController holds an optional controller that receives forwarded binlog
Expand Down Expand Up @@ -64,14 +65,16 @@ type sessionLocks map[uint32]dbLocks

// NewCatalog returns a new empty Catalog with the given provider
func NewCatalog(provider sql.DatabaseProvider) *Catalog {
return &Catalog{
c := &Catalog{
MySQLDb: mysql_db.CreateEmptyMySQLDb(),
InfoSchema: information_schema.NewInformationSchemaDatabase(),
DbProvider: provider,
builtInFunctions: function.NewRegistry(),
StatsProvider: memory.NewStatsProv(),
locks: make(sessionLocks),
}
c.AuthHandler = sql.GetAuthorizationHandlerFactory().CreateHandler(c)
return c
}

func (c *Catalog) HasBinlogReplicaController() bool {
Expand Down Expand Up @@ -109,7 +112,7 @@ func (c *Catalog) AllDatabases(ctx *sql.Context) []sql.Database {
dbs = append(dbs, c.InfoSchema)

if c.MySQLDb.Enabled() {
dbs = append(dbs, mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider).AllDatabases(ctx)...)
dbs = append(dbs, mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider, c.AuthHandler).AllDatabases(ctx)...)
} else {
dbs = append(dbs, c.DbProvider.AllDatabases(ctx)...)
}
Expand Down Expand Up @@ -162,7 +165,7 @@ func (c *Catalog) HasDatabase(ctx *sql.Context, db string) bool {
if db == "information_schema" {
return true
} else if c.MySQLDb.Enabled() {
return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider).HasDatabase(ctx, db)
return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider, c.AuthHandler).HasDatabase(ctx, db)
} else {
return c.DbProvider.HasDatabase(ctx, db)
}
Expand All @@ -173,7 +176,7 @@ func (c *Catalog) Database(ctx *sql.Context, db string) (sql.Database, error) {
if strings.ToLower(db) == "information_schema" {
return c.InfoSchema, nil
} else if c.MySQLDb.Enabled() {
return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider).Database(ctx, db)
return mysql_db.NewPrivilegedDatabaseProvider(c.MySQLDb, c.DbProvider, c.AuthHandler).Database(ctx, db)
} else {
return c.DbProvider.Database(ctx, db)
}
Expand Down Expand Up @@ -440,6 +443,10 @@ func (c *Catalog) DataLength(ctx *sql.Context, db string, table sql.Table) (uint
return st.DataLength(ctx)
}

func (c *Catalog) AuthorizationHandler() sql.AuthorizationHandler {
return c.AuthHandler
}

func getStatisticsTable(table sql.Table, prevTable sql.Table) (sql.StatisticsTable, bool) {
// Some TableNodes return themselves for UnderlyingTable, so we need to check for that
if table == prevTable {
Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/load_triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func loadTriggersFromDb(ctx *sql.Context, a *Analyzer, db sql.Database) ([]*plan
for _, trigger := range triggers {
var parsedTrigger sql.Node
sqlMode := sql.NewSqlModeFromString(trigger.SqlMode)
// TODO: should perhaps add the auth query handler to the analyzer? does this even use auth?
parsedTrigger, _, err = planbuilder.ParseWithOptions(ctx, a.Catalog, trigger.CreateStatement, sqlMode.ParserOptions())
if err != nil {
return nil, err
Expand Down
12 changes: 0 additions & 12 deletions sql/analyzer/node_batches.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ func getBatchesForNode(node sql.Node) ([]*Batch, bool) {
Id: applyForeignKeysId,
Apply: applyForeignKeys,
},
{
Id: validatePrivilegesId,
Apply: validatePrivileges,
},
{
Id: validateReadOnlyDatabaseId,
Apply: validateReadOnlyDatabase,
Expand Down Expand Up @@ -74,10 +70,6 @@ func getBatchesForNode(node sql.Node) ([]*Batch, bool) {
Id: applyForeignKeysId,
Apply: applyForeignKeys,
},
{
Id: validatePrivilegesId,
Apply: validatePrivileges,
},
{
Id: optimizeJoinsId,
Apply: optimizeJoins,
Expand Down Expand Up @@ -123,10 +115,6 @@ func getBatchesForNode(node sql.Node) ([]*Batch, bool) {
Id: applyForeignKeysId,
Apply: applyForeignKeys,
},
{
Id: validatePrivilegesId,
Apply: validatePrivileges,
},
{
Id: optimizeJoinsId,
Apply: optimizeJoins,
Expand Down
60 changes: 0 additions & 60 deletions sql/analyzer/privileges.go

This file was deleted.

5 changes: 2 additions & 3 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ var OnceBeforeDefault = []Rule{
{validateReadOnlyTransactionId, validateReadOnlyTransaction},
{validateDatabaseSetId, validateDatabaseSet},
{validateDeleteFromId, validateDeleteFrom},
{validatePrivilegesId, validatePrivileges}, // Ensure that checking privileges happens after db, table & table function resolution
{simplifyFiltersId, simplifyFilters}, //TODO inline?
{pushNotFiltersId, pushNotFilters}, //TODO inline?
{simplifyFiltersId, simplifyFilters}, //TODO inline?
{pushNotFiltersId, pushNotFilters}, //TODO inline?
{hoistOutOfScopeFiltersId, hoistOutOfScopeFilters},
}

Expand Down
2 changes: 2 additions & 0 deletions sql/analyzer/stored_procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan
var procToRegister *plan.Procedure
var parsedProcedure sql.Node
b := planbuilder.New(ctx, a.Catalog, nil, nil)
b.DisableAuth()
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
if err != nil {
Expand Down Expand Up @@ -290,6 +291,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
}
var parsedProcedure sql.Node
b := planbuilder.New(ctx, a.Catalog, nil, nil)
b.DisableAuth()
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
if call.AsOf() != nil {
asOf, err := call.AsOf().Eval(ctx, nil)
Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
}

b := planbuilder.New(ctx, a.Catalog, nil, nil)
b.DisableAuth()
prevActive := b.TriggerCtx().Active
b.TriggerCtx().Active = true
defer func() {
Expand Down
3 changes: 0 additions & 3 deletions sql/analyzer/validation_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,6 @@ func (dummyNode) Schema() sql.Schema { return ni
func (dummyNode) Children() []sql.Node { return nil }
func (dummyNode) RowIter(*sql.Context, sql.Row) (sql.RowIter, error) { return nil, nil }
func (dummyNode) WithChildren(...sql.Node) (sql.Node, error) { return nil, nil }
func (dummyNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return true
}
func (dummyNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 7
}
Expand Down
Loading