Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions enginetest/queries/foreign_key_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2616,6 +2616,74 @@ var ForeignKeyTests = []ScriptTest{
},
},
},
{
Name: "multiple foreign key refs",
SetUpScript: []string{
"create table parent1 (i int primary key);",
"create table child1 (j int, k int, foreign key (j) references parent1(i) on delete cascade on update cascade, foreign key (k) references parent1 (i) on delete cascade on update cascade);",
"insert into parent1 values (1), (2), (3);",
"insert into child1 values (1, 2), (2, 3), (3, 1);",
},
Assertions: []ScriptTestAssertion{
{
Query: "select * from parent1;",
Expected: []sql.Row{
{1},
{2},
{3},
},
},
{
Query: "select * from child1 order by j, k;",
Expected: []sql.Row{
{1, 2},
{2, 3},
{3, 1},
},
},
{
Query: "update parent1 set i = 20 where i = 2;",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}},
},
},
{
Query: "select * from parent1 order by i;",
Expected: []sql.Row{
{1},
{3},
{20},
},
},
{
Query: "select * from child1 order by j, k;",
Expected: []sql.Row{
{1, 20},
{3, 1},
{20, 3},
},
},
{
Query: "delete from parent1 where i = 1;",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 1}},
},
},
{
Query: "select * from parent1;",
Expected: []sql.Row{
{3},
{20},
},
},
{
Query: "select * from child1 order by j, k;",
Expected: []sql.Row{
{20, 3},
},
},
},
},
}

var CreateForeignKeyTests = []ScriptTest{
Expand Down
6 changes: 3 additions & 3 deletions enginetest/queries/update_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,9 @@ var UpdateIgnoreScripts = []ScriptTest{
Name: "UPDATE IGNORE with foreign keys",
SetUpScript: []string{
"CREATE TABLE colors ( id INT NOT NULL, color VARCHAR(32) NOT NULL, PRIMARY KEY (id), INDEX color_index(color));",
"CREATE TABLE objects (id INT NOT NULL, name VARCHAR(64) NOT NULL,color VARCHAR(32), PRIMARY KEY(id),FOREIGN KEY (color) REFERENCES colors(color))",
"INSERT INTO colors (id,color) VALUES (1,'red'),(2,'green'),(3,'blue'),(4,'purple')",
"INSERT INTO objects (id,name,color) VALUES (1,'truck','red'),(2,'ball','green'),(3,'shoe','blue')",
"CREATE TABLE objects (id INT NOT NULL, name VARCHAR(64) NOT NULL,color VARCHAR(32), PRIMARY KEY(id),FOREIGN KEY (color) REFERENCES colors(color));",
"INSERT INTO colors (id,color) VALUES (1,'red'),(2,'green'),(3,'blue'),(4,'purple');",
"INSERT INTO objects (id,name,color) VALUES (1,'truck','red'),(2,'ball','green'),(3,'shoe','blue');",
},
Assertions: []ScriptTestAssertion{
{
Expand Down
80 changes: 41 additions & 39 deletions sql/analyzer/apply_foreign_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ func applyForeignKeys(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sco
// and caching of table editors.
func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *foreignKeyCache) (sql.Node, transform.TreeIdentity, error) {
var err error
fkChain := foreignKeyChain{
fkUpdate: make(map[foreignKeyTableName]sql.ForeignKeyEditor),
}
fkChain := newForeignKeyChain()

switch n := n.(type) {
case *plan.CreateTable:
Expand Down Expand Up @@ -209,7 +207,11 @@ func getForeignKeyEditor(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTable,
if err != nil {
return nil, err
}
return getForeignKeyRefActions(ctx, a, tbl, cache, fkChain, fkEditor, checkRows)
fkEditor, err = getForeignKeyRefActions(ctx, a, tbl, cache, fkChain, fkEditor, checkRows)
if err != nil {
return nil, err
}
return fkEditor, err
}

// getForeignKeyReferences returns an editor containing only the references for the given table.
Expand Down Expand Up @@ -238,7 +240,7 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
if err != nil {
return nil, err
}
fkChain = fkChain.AddTable(fks[0].ParentDatabase, fks[0].ParentTable).AddTableUpdater(fks[0].ParentDatabase, fks[0].ParentTable, updater)
fkChain = fkChain.AddTable(fks[0].Database, fks[0].Table).AddTableUpdater(fks[0].Database, fks[0].Table, updater)

tblSch := tbl.Schema()
fkEditor := &plan.ForeignKeyEditor{
Expand Down Expand Up @@ -383,7 +385,8 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
return nil, err
}

childEditor, err := getForeignKeyEditor(ctx, a, childTbl, cache, fkChain.AddForeignKey(fk.Name), checkRows)
fkChain = fkChain.AddForeignKey(fk.Name)
childEditor, err := getForeignKeyEditor(ctx, a, childTbl, cache, fkChain, checkRows)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -426,6 +429,13 @@ type foreignKeyTableName struct {
tblName string
}

func newForeignKeyTableName(dbName, tblName string) foreignKeyTableName {
return foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}
}

// foreignKeyTableUpdater is a foreign key table along with its updater.
type foreignKeyTableUpdater struct {
tbl sql.ForeignKeyTable
Expand Down Expand Up @@ -539,9 +549,17 @@ func (cache *foreignKeyCache) GetEditor(fkEditor *plan.ForeignKeyEditor, dbName
// updaters that are not a part of this chain. In addition, any updaters that cannot be modified (such as those
// belonging to strictly RESTRICT referential actions) will not appear in the chain.
type foreignKeyChain struct {
fkNames map[string]struct{}
fkTables map[foreignKeyTableName]struct{}
fkUpdate map[foreignKeyTableName]sql.ForeignKeyEditor
fkNames map[string]struct{}
fkTables map[foreignKeyTableName]struct{}
fkUpdaters map[foreignKeyTableName]sql.ForeignKeyEditor
}

func newForeignKeyChain() foreignKeyChain {
return foreignKeyChain{
fkNames: make(map[string]struct{}),
fkTables: make(map[foreignKeyTableName]struct{}),
fkUpdaters: make(map[foreignKeyTableName]sql.ForeignKeyEditor),
}
}

// AddTable returns a new chain with the added table.
Expand All @@ -551,26 +569,17 @@ func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyC
for fkName := range chain.fkNames {
newFkNames[fkName] = struct{}{}
}
for fkTable := range chain.fkTables {
newFkTables[fkTable] = struct{}{}
}
newFkTables[foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}] = struct{}{}
newFkTables[newForeignKeyTableName(dbName, tblName)] = struct{}{}
return foreignKeyChain{
fkNames: newFkNames,
fkTables: newFkTables,
fkUpdate: chain.fkUpdate,
fkNames: newFkNames,
fkTables: newFkTables,
fkUpdaters: chain.fkUpdaters,
}
}

// AddTableUpdater returns a new chain with the added foreign key updater.
func (chain foreignKeyChain) AddTableUpdater(dbName string, tblName string, fkUpdater sql.ForeignKeyEditor) foreignKeyChain {
chain.fkUpdate[foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}] = fkUpdater
chain.fkUpdaters[newForeignKeyTableName(dbName, tblName)] = fkUpdater
return chain
}

Expand All @@ -586,35 +595,28 @@ func (chain foreignKeyChain) AddForeignKey(fkName string) foreignKeyChain {
}
newFkNames[strings.ToLower(fkName)] = struct{}{}
return foreignKeyChain{
fkNames: newFkNames,
fkTables: newFkTables,
fkUpdate: chain.fkUpdate,
fkNames: newFkNames,
fkTables: newFkTables,
fkUpdaters: chain.fkUpdaters,
}
}

// HasTable returns whether the chain contains the given table. Case-insensitive.
func (chain foreignKeyChain) HasTable(dbName string, tblName string) bool {
if _, ok := chain.fkTables[foreignKeyTableName{
dbName: strings.ToLower(dbName),
tblName: strings.ToLower(tblName),
}]; ok {
return true
}
return false
_, ok := chain.fkTables[newForeignKeyTableName(dbName, tblName)]
return ok
}

// HasForeignKey returns whether the chain contains the given foreign key. Case-insensitive.
func (chain foreignKeyChain) HasForeignKey(fkName string) bool {
if _, ok := chain.fkNames[strings.ToLower(fkName)]; ok {
return true
}
return false
_, ok := chain.fkNames[strings.ToLower(fkName)]
return ok
}

// GetUpdaters returns all foreign key updaters that have been added to the chain.
func (chain foreignKeyChain) GetUpdaters() []sql.ForeignKeyEditor {
updaters := make([]sql.ForeignKeyEditor, 0, len(chain.fkUpdate))
for _, updater := range chain.fkUpdate {
updaters := make([]sql.ForeignKeyEditor, 0, len(chain.fkUpdaters))
for _, updater := range chain.fkUpdaters {
updaters = append(updaters, updater)
}
return updaters
Expand Down