Skip to content
86 changes: 43 additions & 43 deletions sql/analyzer/apply_foreign_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,20 @@ func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *f
fkParentTbls[i] = nil
continue
}
parentTbl, _, err := a.Catalog.Table(ctx, fkDef.ParentDatabase, fkDef.ParentTable)

parentTbl, _, err := a.Catalog.TableSchema(ctx, fkDef.ParentDatabase, fkDef.ParentSchema, fkDef.ParentTable)
if err != nil {
return nil, transform.SameTree, err
}

// If we are working with a schema-enabled database, alter the foreign key defn to apply the schema name we
// just resolved
dst, ok := parentTbl.(sql.DatabaseSchemaTable)
if ok {
schemaName := dst.DatabaseSchema().SchemaName()
fkDef.ParentSchema = schemaName
}

fkParentTbl, ok := parentTbl.(sql.ForeignKeyTable)
if !ok {
return nil, transform.SameTree, sql.ErrNoForeignKeySupport.New(fkDef.ParentTable)
Expand Down Expand Up @@ -236,11 +246,11 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
return nil, nil
}
// Tables do not include their database. As a workaround, we'll use the first foreign key to tell us the database.
updater, err = cache.AddUpdater(ctx, tbl, fks[0].Database, fks[0].Table)
updater, err = cache.AddUpdater(ctx, tbl, fks[0].Database, fks[0].SchemaName, fks[0].Table)
if err != nil {
return nil, err
}
fkChain = fkChain.AddTable(fks[0].Database, fks[0].Table).AddTableUpdater(fks[0].Database, fks[0].Table, updater)
fkChain = fkChain.AddTable(fks[0].Database, fks[0].SchemaName, fks[0].Table).AddTableUpdater(fks[0].Database, fks[0].SchemaName, fks[0].Table, updater)

tblSch := tbl.Schema()
fkEditor := &plan.ForeignKeyEditor{
Expand All @@ -251,7 +261,7 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
Cyclical: false,
}
for i, fk := range fks {
parentTbl, parentUpdater, err := cache.GetUpdater(ctx, a, fk.ParentDatabase, fk.ParentTable)
parentTbl, parentUpdater, err := cache.GetUpdater(ctx, a, fk.ParentDatabase, fk.ParentSchema, fk.ParentTable)
if err != nil {
return nil, sql.ErrForeignKeyNotResolved.New(fk.Database, fk.Table, fk.Name,
strings.Join(fk.Columns, "`, `"), fk.ParentTable, strings.Join(fk.ParentColumns, "`, `"))
Expand Down Expand Up @@ -316,7 +326,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa

// Check if we already have an editor that we can reuse. If we can, we'll return that instead.
// Tables do not include their database. As a workaround, we'll use the first foreign key to tell us the database.
cachedFkEditor := cache.GetEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentTable)
cachedFkEditor := cache.GetEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable)
if cachedFkEditor != nil {
// Reusing an editor means that we've hit a cycle, so we update the cached editor.
cachedFkEditor.Cyclical = true
Expand All @@ -332,7 +342,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
RefActions: make([]plan.ForeignKeyRefActionData, len(fks)),
Cyclical: false,
}
fkEditor.Editor, err = cache.AddUpdater(ctx, tbl, fks[0].ParentDatabase, fks[0].ParentTable)
fkEditor.Editor, err = cache.AddUpdater(ctx, tbl, fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable)
if err != nil {
return nil, err
}
Expand All @@ -341,20 +351,20 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
fkEditor.RefActions = make([]plan.ForeignKeyRefActionData, len(fks))
}
// Add the editor to the cache
cache.AddEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentTable)
cache.AddEditor(fkEditor, fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable)
// Ensure that the chain has the table and updater
fkChain = fkChain.AddTable(fks[0].ParentDatabase, fks[0].ParentTable).AddTableUpdater(fks[0].ParentDatabase, fks[0].ParentTable, fkEditor.Editor)
fkChain = fkChain.AddTable(fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable).AddTableUpdater(fks[0].ParentDatabase, fks[0].ParentSchema, fks[0].ParentTable, fkEditor.Editor)

for i, fk := range fks {
childTbl, childUpdater, err := cache.GetUpdater(ctx, a, fk.Database, fk.Table)
childTbl, childUpdater, err := cache.GetUpdater(ctx, a, fk.Database, fk.SchemaName, fk.Table)
if err != nil {
return nil, sql.ErrForeignKeyNotResolved.New(fk.Database, fk.Table, fk.Name,
strings.Join(fk.Columns, "`, `"), fk.ParentTable, strings.Join(fk.ParentColumns, "`, `"))
}
// If either referential action is not equivalent to RESTRICT, then the updater has the possibility of having
// its contents modified, therefore we add it to the chain.
if !fk.OnUpdate.IsEquivalentToRestrict() || !fk.OnDelete.IsEquivalentToRestrict() {
fkChain = fkChain.AddTableUpdater(fk.Database, fk.Table, childUpdater)
fkChain = fkChain.AddTableUpdater(fk.Database, fk.SchemaName, fk.Table, childUpdater)
}

// Resolve the foreign key if it has not been resolved yet
Expand Down Expand Up @@ -403,7 +413,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
fkEditor.Cyclical = fkEditor.Cyclical || childEditor.Cyclical
// If "ON UPDATE CASCADE" or "ON UPDATE SET NULL" recurses onto the same table that has been previously updated
// in the same cascade then it's treated like a RESTRICT (does not apply to "ON DELETE")
if fkChain.HasTable(fk.Database, fk.Table) {
if fkChain.HasTable(fk.Database, fk.SchemaName, fk.Table) {
fk.OnUpdate = sql.ForeignKeyReferentialAction_Restrict
}
fkEditor.RefActions[i] = plan.ForeignKeyRefActionData{
Expand All @@ -424,14 +434,16 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa

// foreignKeyTableName is the combination of a table's database along with their name, both lowercased.
type foreignKeyTableName struct {
dbName string
tblName string
dbName string
schemaName string
tblName string
}

func newForeignKeyTableName(dbName, tblName string) foreignKeyTableName {
func newForeignKeyTableName(dbName, schemaName, tblName string) foreignKeyTableName {
return foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
dbName: strings.ToLower(dbName),
schemaName: strings.ToLower(schemaName),
tblName: strings.ToLower(tblName),
}
}

Expand All @@ -458,11 +470,8 @@ func newForeignKeyCache() *foreignKeyCache {
// AddUpdater will add the given foreign key table (and updater) to the cache and returns its updater. If it already
// exists, it is not added, and instead the cached updater is returned. This is so that the same updater is referenced
// by all foreign key instances.
func (cache *foreignKeyCache) AddUpdater(ctx *sql.Context, tbl sql.ForeignKeyTable, dbName string, tblName string) (sql.ForeignKeyEditor, error) {
fkTableName := foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}
func (cache *foreignKeyCache) AddUpdater(ctx *sql.Context, tbl sql.ForeignKeyTable, dbName, schemaName, tblName string) (sql.ForeignKeyEditor, error) {
fkTableName := newForeignKeyTableName(dbName, schemaName, tblName)
if cachedEditor, ok := cache.updaterCache[fkTableName]; ok {
return cachedEditor.updater, nil
}
Expand All @@ -476,27 +485,21 @@ func (cache *foreignKeyCache) AddUpdater(ctx *sql.Context, tbl sql.ForeignKeyTab

// AddEditor will add the given foreign key editor to the cache. Does not validate that the editor is unique, therefore
// GetEditor should be called before this function.
func (cache *foreignKeyCache) AddEditor(editor *plan.ForeignKeyEditor, dbName string, tblName string) {
func (cache *foreignKeyCache) AddEditor(editor *plan.ForeignKeyEditor, dbName, schemaName, tblName string) {
if editor == nil {
panic("cannot pass in nil editor") // Should never be hit
}
fkTableName := foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}
fkTableName := newForeignKeyTableName(dbName, schemaName, tblName)
cache.editorsCache[fkTableName] = append(cache.editorsCache[fkTableName], editor)
}

// GetUpdater returns the given foreign key table updater.
func (cache *foreignKeyCache) GetUpdater(ctx *sql.Context, a *Analyzer, dbName string, tblName string) (sql.ForeignKeyTable, sql.ForeignKeyEditor, error) {
fkTableName := foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}
func (cache *foreignKeyCache) GetUpdater(ctx *sql.Context, a *Analyzer, dbName, schemaName, tblName string) (sql.ForeignKeyTable, sql.ForeignKeyEditor, error) {
fkTableName := newForeignKeyTableName(dbName, schemaName, tblName)
if fkTblEditor, ok := cache.updaterCache[fkTableName]; ok {
return fkTblEditor.tbl, fkTblEditor.updater, nil
}
tbl, _, err := a.Catalog.Table(ctx, dbName, tblName)
tbl, _, err := a.Catalog.TableSchema(ctx, dbName, schemaName, tblName)
if err != nil {
return nil, nil, err
}
Expand All @@ -514,11 +517,8 @@ func (cache *foreignKeyCache) GetUpdater(ctx *sql.Context, a *Analyzer, dbName s

// GetEditor returns a foreign key editor that matches the given editor in all ways except for the referential actions.
// Returns nil if no such editors have been cached.
func (cache *foreignKeyCache) GetEditor(fkEditor *plan.ForeignKeyEditor, dbName string, tblName string) *plan.ForeignKeyEditor {
fkTableName := foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}
func (cache *foreignKeyCache) GetEditor(fkEditor *plan.ForeignKeyEditor, dbName, schemaName, tblName string) *plan.ForeignKeyEditor {
fkTableName := newForeignKeyTableName(dbName, schemaName, tblName)
// It is safe to assume that the index and schema will match for a table that has the same name on the same database,
// so we only need to check that the references match. As long as they refer to the same foreign key, they should
// match, so we only need to check the names.
Expand Down Expand Up @@ -562,7 +562,7 @@ func newForeignKeyChain() foreignKeyChain {
}

// AddTable returns a new chain with the added table.
func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyChain {
func (chain foreignKeyChain) AddTable(dbName string, schemaName, tblName string) foreignKeyChain {
newFkNames := make(map[string]struct{})
newFkTables := make(map[foreignKeyTableName]struct{})
for fkName := range chain.fkNames {
Expand All @@ -571,7 +571,7 @@ func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyC
for fkTable := range chain.fkTables {
newFkTables[fkTable] = struct{}{}
}
newFkTables[newForeignKeyTableName(dbName, tblName)] = struct{}{}
newFkTables[newForeignKeyTableName(dbName, schemaName, tblName)] = struct{}{}
return foreignKeyChain{
fkNames: newFkNames,
fkTables: newFkTables,
Expand All @@ -580,8 +580,8 @@ func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyC
}

// AddTableUpdater returns a new chain with the added foreign key updater.
func (chain foreignKeyChain) AddTableUpdater(dbName string, tblName string, fkUpdater sql.ForeignKeyEditor) foreignKeyChain {
chain.fkUpdaters[newForeignKeyTableName(dbName, tblName)] = fkUpdater
func (chain foreignKeyChain) AddTableUpdater(dbName, schemaName, tblName string, fkUpdater sql.ForeignKeyEditor) foreignKeyChain {
chain.fkUpdaters[newForeignKeyTableName(dbName, schemaName, tblName)] = fkUpdater
return chain
}

Expand All @@ -604,8 +604,8 @@ func (chain foreignKeyChain) AddForeignKey(fkName string) foreignKeyChain {
}

// HasTable returns whether the chain contains the given table. Case-insensitive.
func (chain foreignKeyChain) HasTable(dbName string, tblName string) bool {
_, ok := chain.fkTables[newForeignKeyTableName(dbName, tblName)]
func (chain foreignKeyChain) HasTable(dbName, schemaName, tblName string) bool {
_, ok := chain.fkTables[newForeignKeyTableName(dbName, schemaName, tblName)]
return ok
}

Expand Down
28 changes: 28 additions & 0 deletions sql/analyzer/catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,34 @@ func (c *Catalog) Table(ctx *sql.Context, dbName, tableName string) (sql.Table,
return c.DatabaseTable(ctx, db, tableName)
}

// TableSchema returns the table in the given database with the given name, in the given schema name
func (c *Catalog) TableSchema(ctx *sql.Context, dbName, schemaName, tableName string) (sql.Table, sql.Database, error) {
c.mu.RLock()
defer c.mu.RUnlock()

db, err := c.Database(ctx, dbName)
if err != nil {
return nil, nil, err
}

if schemaName != "" {
sdb, ok := db.(sql.SchemaDatabase)
if !ok {
return nil, nil, sql.ErrDatabaseSchemasNotSupported.New(db.Name())
}

db, ok, err = sdb.GetSchema(ctx, schemaName)
if err != nil {
return nil, nil, err
}
if !ok {
return nil, nil, sql.ErrDatabaseSchemaNotFound.New(schemaName)
}
}

return c.DatabaseTable(ctx, db, tableName)
}

func (c *Catalog) DatabaseTable(ctx *sql.Context, db sql.Database, tableName string) (sql.Table, sql.Database, error) {
_, ok := db.(sql.UnresolvedDatabase)
if ok {
Expand Down
33 changes: 24 additions & 9 deletions sql/constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,36 @@ func (f ForeignKeyReferentialAction) IsEquivalentToRestrict() bool {

// ForeignKeyConstraint declares a constraint between the columns of two tables.
type ForeignKeyConstraint struct {
Name string
Database string
Table string
Columns []string
// Name is the name of the foreign key constraint
Name string
// Database is the name of the database of the table with the constraint
Database string
// SchemaName is the name of the schema of the table, for databases that support schemas.
SchemaName string
// Table is the name of the table with the constraint
Table string
// Columns is the list of columns in the table that are part of the foreign key
Columns []string
// ParentDatabase is the name of the database of the parent table
ParentDatabase string
ParentTable string
ParentColumns []string
OnUpdate ForeignKeyReferentialAction
OnDelete ForeignKeyReferentialAction
IsResolved bool
// ParentSchema is the name of the schema of the parent table, for databases that support schemas.
ParentSchema string
// ParentTable is the name of the parent table
ParentTable string
// ParentColumns is the list of columns in the parent table that are part of the foreign key
ParentColumns []string
// OnUpdate is the action to take when the constraint is violated when a row in the parent table is updated
OnUpdate ForeignKeyReferentialAction
// OnDelete is the action to take when the constraint is violated when a row in the parent table is deleted
OnDelete ForeignKeyReferentialAction
// IsResolved is true if the foreign key has been resolved, false otherwise
IsResolved bool
}

// IsSelfReferential returns whether this foreign key represents a self-referential foreign key.
func (f *ForeignKeyConstraint) IsSelfReferential() bool {
return strings.EqualFold(f.Database, f.ParentDatabase) &&
strings.EqualFold(f.SchemaName, f.ParentSchema) &&
strings.EqualFold(f.Table, f.ParentTable)
}

Expand Down
4 changes: 0 additions & 4 deletions sql/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ type SchemaDatabase interface {
CreateSchema(ctx *Context, schemaName string) error
// AllSchemas returns all schemas in the database.
AllSchemas(ctx *Context) ([]DatabaseSchema, error)
// // GetTable returns the table with the name given in the schema given. The schema name may be empty.
// GetTable(ctx *Context, schemaName, tableName string) (Table, bool, error)
// // GetTableAsOf returns the table with the name given in the schema given. The schema name may be empty.
// GetTableAsOf(ctx *Context, schemaName, tableName string, asOf interface{}) (Table, bool, error)
}

// DatabaseSchema is a schema that can be queried for tables. It is functionally equivalent to a Database
Expand Down
13 changes: 13 additions & 0 deletions sql/planbuilder/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,12 @@ func (b *Builder) buildAlterConstraint(inScope *scope, ddl *ast.DDL, table *plan
case *sql.ForeignKeyConstraint:
c.Database = table.SqlDatabase.Name()
c.Table = table.Name()

ds, ok := table.SqlDatabase.(sql.DatabaseSchema)
if ok {
c.SchemaName = ds.SchemaName()
}

alterFk := plan.NewAlterAddForeignKey(c)
alterFk.DbProvider = b.cat
outScope.node = alterFk
Expand Down Expand Up @@ -631,6 +637,12 @@ func (b *Builder) buildAlterConstraint(inScope *scope, ddl *ast.DDL, table *plan
b.handleErr(err)
}
database := table.SqlDatabase.Name()

ds, ok := table.SqlDatabase.(sql.DatabaseSchema)
if ok {
c.SchemaName = ds.SchemaName()
}

dropFk := plan.NewAlterRenameForeignKey(database, table.Name(), c.Name, cc.Name)
dropFk.DbProvider = b.cat
outScope.node = dropFk
Expand Down Expand Up @@ -760,6 +772,7 @@ func (b *Builder) convertConstraintDefinition(inScope *scope, cd *ast.Constraint
Name: cd.Name,
Columns: columns,
ParentDatabase: refDatabase,
ParentSchema: fkConstraint.ReferencedTable.SchemaQualifier.String(),
ParentTable: fkConstraint.ReferencedTable.Name.String(),
ParentColumns: refColumns,
OnUpdate: b.buildReferentialAction(fkConstraint.OnUpdate),
Expand Down
Loading