diff --git a/sql/analyzer/apply_foreign_keys.go b/sql/analyzer/apply_foreign_keys.go index 8ec65010c9..8210fcc3c3 100644 --- a/sql/analyzer/apply_foreign_keys.go +++ b/sql/analyzer/apply_foreign_keys.go @@ -58,10 +58,20 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f fkParentTbls[i] = nil continue } - parentTbl, _, err := a.Catalog.Table(ctx, fkDef.ParentDatabase, fkDef.ParentTable) + + parentTbl, _, err := a.Catalog.TableSchema(ctx, fkDef.ParentDatabase, fkDef.ParentSchema, fkDef.ParentTable) if err != nil { return nil, transform.SameTree, err } + + // If we are working with a schema-enabled database, alter the foreign key defn to apply the schema name we + // just resolved + dst, ok := parentTbl.(sql.DatabaseSchemaTable) + if ok { + schemaName := dst.DatabaseSchema().SchemaName() + fkDef.ParentSchema = schemaName + } + fkParentTbl, ok := parentTbl.(sql.ForeignKeyTable) if !ok { return nil, transform.SameTree, sql.ErrNoForeignKeySupport.New(fkDef.ParentTable) @@ -236,11 +246,11 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa return nil, nil } // Tables do not include their database. As a workaround, we'll use the first foreign key to tell us the database. - updater, err = cache.AddUpdater(ctx, tbl, fks[0].Database, fks[0].Table) + updater, err = cache.AddUpdater(ctx, tbl, fks[0].Database, fks[0].SchemaName, fks[0].Table) if err != nil { return nil, err } - fkChain = fkChain.AddTable(fks[0].Database, fks[0].Table).AddTableUpdater(fks[0].Database, fks[0].Table, updater) + fkChain = fkChain.AddTable(fks[0].Database, fks[0].SchemaName, fks[0].Table).AddTableUpdater(fks[0].Database, fks[0].SchemaName, fks[0].Table, updater) tblSch := tbl.Schema() fkEditor := &plan.ForeignKeyEditor{ @@ -251,7 +261,7 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa Cyclical: false, } for i, fk := range fks { - parentTbl, parentUpdater, err := cache.GetUpdater(ctx, a, fk.ParentDatabase, fk.ParentTable) + parentTbl, parentUpdater, err := cache.GetUpdater(ctx, a, fk.ParentDatabase, fk.ParentSchema, fk.ParentTable) if err != nil { return nil, sql.ErrForeignKeyNotResolved.New(fk.Database, fk.Table, fk.Name, strings.Join(fk.Columns, "`, `"), fk.ParentTable, strings.Join(fk.ParentColumns, "`, `")) @@ -316,7 +326,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa // Check if we already have an editor that we can reuse. If we can, we'll return that instead. // Tables do not include their database. As a workaround, we'll use the first foreign key to tell us the database. - cachedFkEditor := cache.GetEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentTable) + cachedFkEditor := cache.GetEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable) if cachedFkEditor != nil { // Reusing an editor means that we've hit a cycle, so we update the cached editor. cachedFkEditor.Cyclical = true @@ -332,7 +342,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa RefActions: make([]plan.ForeignKeyRefActionData, len(fks)), Cyclical: false, } - fkEditor.Editor, err = cache.AddUpdater(ctx, tbl, fks[0].ParentDatabase, fks[0].ParentTable) + fkEditor.Editor, err = cache.AddUpdater(ctx, tbl, fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable) if err != nil { return nil, err } @@ -341,12 +351,12 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa fkEditor.RefActions = make([]plan.ForeignKeyRefActionData, len(fks)) } // Add the editor to the cache - cache.AddEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentTable) + cache.AddEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable) // Ensure that the chain has the table and updater - fkChain = fkChain.AddTable(fks[0].ParentDatabase, fks[0].ParentTable).AddTableUpdater(fks[0].ParentDatabase, fks[0].ParentTable, fkEditor.Editor) + fkChain = fkChain.AddTable(fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable).AddTableUpdater(fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable, fkEditor.Editor) for i, fk := range fks { - childTbl, childUpdater, err := cache.GetUpdater(ctx, a, fk.Database, fk.Table) + childTbl, childUpdater, err := cache.GetUpdater(ctx, a, fk.Database, fk.SchemaName, fk.Table) if err != nil { return nil, sql.ErrForeignKeyNotResolved.New(fk.Database, fk.Table, fk.Name, strings.Join(fk.Columns, "`, `"), fk.ParentTable, strings.Join(fk.ParentColumns, "`, `")) @@ -354,7 +364,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa // If either referential action is not equivalent to RESTRICT, then the updater has the possibility of having // its contents modified, therefore we add it to the chain. if !fk.OnUpdate.IsEquivalentToRestrict() || !fk.OnDelete.IsEquivalentToRestrict() { - fkChain = fkChain.AddTableUpdater(fk.Database, fk.Table, childUpdater) + fkChain = fkChain.AddTableUpdater(fk.Database, fk.SchemaName, fk.Table, childUpdater) } // Resolve the foreign key if it has not been resolved yet @@ -403,7 +413,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa fkEditor.Cyclical = fkEditor.Cyclical || childEditor.Cyclical // If "ON UPDATE CASCADE" or "ON UPDATE SET NULL" recurses onto the same table that has been previously updated // in the same cascade then it's treated like a RESTRICT (does not apply to "ON DELETE") - if fkChain.HasTable(fk.Database, fk.Table) { + if fkChain.HasTable(fk.Database, fk.SchemaName, fk.Table) { fk.OnUpdate = sql.ForeignKeyReferentialAction_Restrict } fkEditor.RefActions[i] = plan.ForeignKeyRefActionData{ @@ -424,14 +434,16 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa // foreignKeyTableName is the combination of a table's database along with their name, both lowercased. type foreignKeyTableName struct { - dbName string - tblName string + dbName string + schemaName string + tblName string } -func newForeignKeyTableName(dbName, tblName string) foreignKeyTableName { +func newForeignKeyTableName(dbName, schemaName, tblName string) foreignKeyTableName { return foreignKeyTableName{ - dbName: strings.ToLower(dbName), - tblName: strings.ToLower(tblName), + dbName: strings.ToLower(dbName), + schemaName: strings.ToLower(schemaName), + tblName: strings.ToLower(tblName), } } @@ -458,11 +470,8 @@ func newForeignKeyCache() *foreignKeyCache { // AddUpdater will add the given foreign key table (and updater) to the cache and returns its updater. If it already // exists, it is not added, and instead the cached updater is returned. This is so that the same updater is referenced // by all foreign key instances. -func (cache *foreignKeyCache) AddUpdater(ctx *sql.Context, tbl sql.ForeignKeyTable, dbName string, tblName string) (sql.ForeignKeyEditor, error) { - fkTableName := foreignKeyTableName{ - dbName: strings.ToLower(dbName), - tblName: strings.ToLower(tblName), - } +func (cache *foreignKeyCache) AddUpdater(ctx *sql.Context, tbl sql.ForeignKeyTable, dbName, schemaName, tblName string) (sql.ForeignKeyEditor, error) { + fkTableName := newForeignKeyTableName(dbName, schemaName, tblName) if cachedEditor, ok := cache.updaterCache[fkTableName]; ok { return cachedEditor.updater, nil } @@ -476,27 +485,21 @@ func (cache *foreignKeyCache) AddUpdater(ctx *sql.Context, tbl sql.ForeignKeyTab // AddEditor will add the given foreign key editor to the cache. Does not validate that the editor is unique, therefore // GetEditor should be called before this function. -func (cache *foreignKeyCache) AddEditor(editor *plan.ForeignKeyEditor, dbName string, tblName string) { +func (cache *foreignKeyCache) AddEditor(editor *plan.ForeignKeyEditor, dbName, schemaName, tblName string) { if editor == nil { panic("cannot pass in nil editor") // Should never be hit } - fkTableName := foreignKeyTableName{ - dbName: strings.ToLower(dbName), - tblName: strings.ToLower(tblName), - } + fkTableName := newForeignKeyTableName(dbName, schemaName, tblName) cache.editorsCache[fkTableName] = append(cache.editorsCache[fkTableName], editor) } // GetUpdater returns the given foreign key table updater. -func (cache *foreignKeyCache) GetUpdater(ctx *sql.Context, a *Analyzer, dbName string, tblName string) (sql.ForeignKeyTable, sql.ForeignKeyEditor, error) { - fkTableName := foreignKeyTableName{ - dbName: strings.ToLower(dbName), - tblName: strings.ToLower(tblName), - } +func (cache *foreignKeyCache) GetUpdater(ctx *sql.Context, a *Analyzer, dbName, schemaName, tblName string) (sql.ForeignKeyTable, sql.ForeignKeyEditor, error) { + fkTableName := newForeignKeyTableName(dbName, schemaName, tblName) if fkTblEditor, ok := cache.updaterCache[fkTableName]; ok { return fkTblEditor.tbl, fkTblEditor.updater, nil } - tbl, _, err := a.Catalog.Table(ctx, dbName, tblName) + tbl, _, err := a.Catalog.TableSchema(ctx, dbName, schemaName, tblName) if err != nil { return nil, nil, err } @@ -514,11 +517,8 @@ func (cache *foreignKeyCache) GetUpdater(ctx *sql.Context, a *Analyzer, dbName s // GetEditor returns a foreign key editor that matches the given editor in all ways except for the referential actions. // Returns nil if no such editors have been cached. -func (cache *foreignKeyCache) GetEditor(fkEditor *plan.ForeignKeyEditor, dbName string, tblName string) *plan.ForeignKeyEditor { - fkTableName := foreignKeyTableName{ - dbName: strings.ToLower(dbName), - tblName: strings.ToLower(tblName), - } +func (cache *foreignKeyCache) GetEditor(fkEditor *plan.ForeignKeyEditor, dbName, schemaName, tblName string) *plan.ForeignKeyEditor { + fkTableName := newForeignKeyTableName(dbName, schemaName, tblName) // It is safe to assume that the index and schema will match for a table that has the same name on the same database, // so we only need to check that the references match. As long as they refer to the same foreign key, they should // match, so we only need to check the names. @@ -562,7 +562,7 @@ func newForeignKeyChain() foreignKeyChain { } // AddTable returns a new chain with the added table. -func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyChain { +func (chain foreignKeyChain) AddTable(dbName string, schemaName, tblName string) foreignKeyChain { newFkNames := make(map[string]struct{}) newFkTables := make(map[foreignKeyTableName]struct{}) for fkName := range chain.fkNames { @@ -571,7 +571,7 @@ func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyC for fkTable := range chain.fkTables { newFkTables[fkTable] = struct{}{} } - newFkTables[newForeignKeyTableName(dbName, tblName)] = struct{}{} + newFkTables[newForeignKeyTableName(dbName, schemaName, tblName)] = struct{}{} return foreignKeyChain{ fkNames: newFkNames, fkTables: newFkTables, @@ -580,8 +580,8 @@ func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyC } // AddTableUpdater returns a new chain with the added foreign key updater. -func (chain foreignKeyChain) AddTableUpdater(dbName string, tblName string, fkUpdater sql.ForeignKeyEditor) foreignKeyChain { - chain.fkUpdaters[newForeignKeyTableName(dbName, tblName)] = fkUpdater +func (chain foreignKeyChain) AddTableUpdater(dbName, schemaName, tblName string, fkUpdater sql.ForeignKeyEditor) foreignKeyChain { + chain.fkUpdaters[newForeignKeyTableName(dbName, schemaName, tblName)] = fkUpdater return chain } @@ -604,8 +604,8 @@ func (chain foreignKeyChain) AddForeignKey(fkName string) foreignKeyChain { } // HasTable returns whether the chain contains the given table. Case-insensitive. -func (chain foreignKeyChain) HasTable(dbName string, tblName string) bool { - _, ok := chain.fkTables[newForeignKeyTableName(dbName, tblName)] +func (chain foreignKeyChain) HasTable(dbName, schemaName, tblName string) bool { + _, ok := chain.fkTables[newForeignKeyTableName(dbName, schemaName, tblName)] return ok } diff --git a/sql/analyzer/catalog.go b/sql/analyzer/catalog.go index 81f8ed155c..c8e458790a 100644 --- a/sql/analyzer/catalog.go +++ b/sql/analyzer/catalog.go @@ -247,6 +247,34 @@ func (c *Catalog) Table(ctx *sql.Context, dbName, tableName string) (sql.Table, return c.DatabaseTable(ctx, db, tableName) } +// TableSchema returns the table in the given database with the given name, in the given schema name +func (c *Catalog) TableSchema(ctx *sql.Context, dbName, schemaName, tableName string) (sql.Table, sql.Database, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + db, err := c.Database(ctx, dbName) + if err != nil { + return nil, nil, err + } + + if schemaName != "" { + sdb, ok := db.(sql.SchemaDatabase) + if !ok { + return nil, nil, sql.ErrDatabaseSchemasNotSupported.New(db.Name()) + } + + db, ok, err = sdb.GetSchema(ctx, schemaName) + if err != nil { + return nil, nil, err + } + if !ok { + return nil, nil, sql.ErrDatabaseSchemaNotFound.New(schemaName) + } + } + + return c.DatabaseTable(ctx, db, tableName) +} + func (c *Catalog) DatabaseTable(ctx *sql.Context, db sql.Database, tableName string) (sql.Table, sql.Database, error) { _, ok := db.(sql.UnresolvedDatabase) if ok { diff --git a/sql/constraints.go b/sql/constraints.go index 1ff4115f77..4c84871960 100644 --- a/sql/constraints.go +++ b/sql/constraints.go @@ -45,21 +45,36 @@ func (f ForeignKeyReferentialAction) IsEquivalentToRestrict() bool { // ForeignKeyConstraint declares a constraint between the columns of two tables. type ForeignKeyConstraint struct { - Name string - Database string - Table string - Columns []string + // Name is the name of the foreign key constraint + Name string + // Database is the name of the database of the table with the constraint + Database string + // SchemaName is the name of the schema of the table, for databases that support schemas. + SchemaName string + // Table is the name of the table with the constraint + Table string + // Columns is the list of columns in the table that are part of the foreign key + Columns []string + // ParentDatabase is the name of the database of the parent table ParentDatabase string - ParentTable string - ParentColumns []string - OnUpdate ForeignKeyReferentialAction - OnDelete ForeignKeyReferentialAction - IsResolved bool + // ParentSchema is the name of the schema of the parent table, for databases that support schemas. + ParentSchema string + // ParentTable is the name of the parent table + ParentTable string + // ParentColumns is the list of columns in the parent table that are part of the foreign key + ParentColumns []string + // OnUpdate is the action to take when the constraint is violated when a row in the parent table is updated + OnUpdate ForeignKeyReferentialAction + // OnDelete is the action to take when the constraint is violated when a row in the parent table is deleted + OnDelete ForeignKeyReferentialAction + // IsResolved is true if the foreign key has been resolved, false otherwise + IsResolved bool } // IsSelfReferential returns whether this foreign key represents a self-referential foreign key. func (f *ForeignKeyConstraint) IsSelfReferential() bool { return strings.EqualFold(f.Database, f.ParentDatabase) && + strings.EqualFold(f.SchemaName, f.ParentSchema) && strings.EqualFold(f.Table, f.ParentTable) } diff --git a/sql/databases.go b/sql/databases.go index f1891878c1..230efbd6a8 100644 --- a/sql/databases.go +++ b/sql/databases.go @@ -86,10 +86,6 @@ type SchemaDatabase interface { CreateSchema(ctx *Context, schemaName string) error // AllSchemas returns all schemas in the database. AllSchemas(ctx *Context) ([]DatabaseSchema, error) - // // GetTable returns the table with the name given in the schema given. The schema name may be empty. - // GetTable(ctx *Context, schemaName, tableName string) (Table, bool, error) - // // GetTableAsOf returns the table with the name given in the schema given. The schema name may be empty. - // GetTableAsOf(ctx *Context, schemaName, tableName string, asOf interface{}) (Table, bool, error) } // DatabaseSchema is a schema that can be queried for tables. It is functionally equivalent to a Database diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index a9bb6cee01..dd4de53cf1 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -590,6 +590,12 @@ func (b *Builder) buildAlterConstraint(inScope *scope, ddl *ast.DDL, table *plan case *sql.ForeignKeyConstraint: c.Database = table.SqlDatabase.Name() c.Table = table.Name() + + ds, ok := table.SqlDatabase.(sql.DatabaseSchema) + if ok { + c.SchemaName = ds.SchemaName() + } + alterFk := plan.NewAlterAddForeignKey(c) alterFk.DbProvider = b.cat outScope.node = alterFk @@ -631,6 +637,12 @@ func (b *Builder) buildAlterConstraint(inScope *scope, ddl *ast.DDL, table *plan b.handleErr(err) } database := table.SqlDatabase.Name() + + ds, ok := table.SqlDatabase.(sql.DatabaseSchema) + if ok { + c.SchemaName = ds.SchemaName() + } + dropFk := plan.NewAlterRenameForeignKey(database, table.Name(), c.Name, cc.Name) dropFk.DbProvider = b.cat outScope.node = dropFk @@ -760,6 +772,7 @@ func (b *Builder) convertConstraintDefinition(inScope *scope, cd *ast.Constraint Name: cd.Name, Columns: columns, ParentDatabase: refDatabase, + ParentSchema: fkConstraint.ReferencedTable.SchemaQualifier.String(), ParentTable: fkConstraint.ReferencedTable.Name.String(), ParentColumns: refColumns, OnUpdate: b.buildReferentialAction(fkConstraint.OnUpdate), diff --git a/sql/rowexec/ddl.go b/sql/rowexec/ddl.go index 8d3bfb8ca4..a06e14fdae 100644 --- a/sql/rowexec/ddl.go +++ b/sql/rowexec/ddl.go @@ -1187,6 +1187,22 @@ func (b *BaseBuilder) buildCreateForeignKey(ctx *sql.Context, n *plan.CreateFore if err != nil { return nil, err } + + if n.FkDef.SchemaName != "" { + sdb, ok := db.(sql.SchemaDatabase) + if !ok { + return nil, sql.ErrDatabaseSchemasNotSupported.New(n.FkDef.Database) + } + sch, schemaExists, err := sdb.GetSchema(ctx, n.FkDef.SchemaName) + if err != nil { + return nil, err + } + if !schemaExists { + return nil, sql.ErrDatabaseSchemaNotFound.New(n.FkDef.SchemaName) + } + db = sch + } + tbl, ok, err := db.GetTableInsensitive(ctx, n.FkDef.Table) if err != nil { return nil, err @@ -1199,6 +1215,22 @@ func (b *BaseBuilder) buildCreateForeignKey(ctx *sql.Context, n *plan.CreateFore if err != nil { return nil, err } + + if n.FkDef.ParentSchema != "" { + sdb, ok := refDb.(sql.SchemaDatabase) + if !ok { + return nil, sql.ErrDatabaseSchemasNotSupported.New(n.FkDef.ParentDatabase) + } + sch, schemaExists, err := sdb.GetSchema(ctx, n.FkDef.ParentSchema) + if err != nil { + return nil, err + } + if !schemaExists { + return nil, sql.ErrDatabaseSchemaNotFound.New(n.FkDef.ParentSchema) + } + refDb = sch + } + refTbl, ok, err := refDb.GetTableInsensitive(ctx, n.FkDef.ParentTable) if err != nil { return nil, err @@ -1207,6 +1239,14 @@ func (b *BaseBuilder) buildCreateForeignKey(ctx *sql.Context, n *plan.CreateFore return nil, sql.ErrTableNotFound.New(n.FkDef.ParentTable) } + // If we didn't have an explicit schema, fill in the resolved schema for the fk table defn + if n.FkDef.ParentSchema == "" { + dst, ok := refTbl.(sql.DatabaseSchemaTable) + if ok { + n.FkDef.ParentSchema = dst.DatabaseSchema().SchemaName() + } + } + fkTbl, ok := tbl.(sql.ForeignKeyTable) if !ok { return nil, sql.ErrNoForeignKeySupport.New(n.FkDef.Table) diff --git a/sql/tables.go b/sql/tables.go index 127552fae6..7b203aab25 100644 --- a/sql/tables.go +++ b/sql/tables.go @@ -32,6 +32,14 @@ type Table interface { PartitionRows(*Context, Partition) (RowIter, error) } +// DatabaseSchemaTable is a table that can return the database schema it belongs to. This interface must be implemented +// for correct function of some DDL in databases that implement SchemaDatabase. +type DatabaseSchemaTable interface { + Table + // DatabaseSchema returns the database schema that this table belongs to. + DatabaseSchema() DatabaseSchema +} + // TableFunction is a node that is generated by a function and can be used as a table factor in many SQL queries. type TableFunction interface { Node