@@ -34,6 +34,7 @@ type Catalog struct {
3434 StatsProvider sql.StatsProvider
3535
3636 DbProvider sql.DatabaseProvider
37+ AuthHandler sql.AuthorizationHandler
3738 builtInFunctions function.Registry
3839
3940 // BinlogReplicaController holds an optional controller that receives forwarded binlog
@@ -64,14 +65,16 @@ type sessionLocks map[uint32]dbLocks
6465
6566// NewCatalog returns a new empty Catalog with the given provider
6667func NewCatalog (provider sql.DatabaseProvider ) * Catalog {
67- return & Catalog {
68+ c := & Catalog {
6869 MySQLDb : mysql_db .CreateEmptyMySQLDb (),
6970 InfoSchema : information_schema .NewInformationSchemaDatabase (),
7071 DbProvider : provider ,
7172 builtInFunctions : function .NewRegistry (),
7273 StatsProvider : memory .NewStatsProv (),
7374 locks : make (sessionLocks ),
7475 }
76+ c .AuthHandler = sql .GetAuthorizationHandlerFactory ().CreateHandler (c )
77+ return c
7578}
7679
7780func (c * Catalog ) HasBinlogReplicaController () bool {
@@ -109,7 +112,7 @@ func (c *Catalog) AllDatabases(ctx *sql.Context) []sql.Database {
109112 dbs = append (dbs , c .InfoSchema )
110113
111114 if c .MySQLDb .Enabled () {
112- dbs = append (dbs , mysql_db .NewPrivilegedDatabaseProvider (c .MySQLDb , c .DbProvider ).AllDatabases (ctx )... )
115+ dbs = append (dbs , mysql_db .NewPrivilegedDatabaseProvider (c .MySQLDb , c .DbProvider , c . AuthHandler ).AllDatabases (ctx )... )
113116 } else {
114117 dbs = append (dbs , c .DbProvider .AllDatabases (ctx )... )
115118 }
@@ -162,7 +165,7 @@ func (c *Catalog) HasDatabase(ctx *sql.Context, db string) bool {
162165 if db == "information_schema" {
163166 return true
164167 } else if c .MySQLDb .Enabled () {
165- return mysql_db .NewPrivilegedDatabaseProvider (c .MySQLDb , c .DbProvider ).HasDatabase (ctx , db )
168+ return mysql_db .NewPrivilegedDatabaseProvider (c .MySQLDb , c .DbProvider , c . AuthHandler ).HasDatabase (ctx , db )
166169 } else {
167170 return c .DbProvider .HasDatabase (ctx , db )
168171 }
@@ -173,7 +176,7 @@ func (c *Catalog) Database(ctx *sql.Context, db string) (sql.Database, error) {
173176 if strings .ToLower (db ) == "information_schema" {
174177 return c .InfoSchema , nil
175178 } else if c .MySQLDb .Enabled () {
176- return mysql_db .NewPrivilegedDatabaseProvider (c .MySQLDb , c .DbProvider ).Database (ctx , db )
179+ return mysql_db .NewPrivilegedDatabaseProvider (c .MySQLDb , c .DbProvider , c . AuthHandler ).Database (ctx , db )
177180 } else {
178181 return c .DbProvider .Database (ctx , db )
179182 }
@@ -440,6 +443,10 @@ func (c *Catalog) DataLength(ctx *sql.Context, db string, table sql.Table) (uint
440443 return st .DataLength (ctx )
441444}
442445
446+ func (c * Catalog ) AuthorizationHandler () sql.AuthorizationHandler {
447+ return c .AuthHandler
448+ }
449+
443450func getStatisticsTable (table sql.Table , prevTable sql.Table ) (sql.StatisticsTable , bool ) {
444451 // Some TableNodes return themselves for UnderlyingTable, so we need to check for that
445452 if table == prevTable {
0 commit comments