Skip to content

Commit ee04255

Browse files
author
James Cor
committed
fix and tests
1 parent f5ba0ac commit ee04255

File tree

2 files changed

+102
-35
lines changed

2 files changed

+102
-35
lines changed

enginetest/queries/foreign_key_queries.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2616,6 +2616,74 @@ var ForeignKeyTests = []ScriptTest{
26162616
},
26172617
},
26182618
},
2619+
{
2620+
Name: "multiple foreign key refs",
2621+
SetUpScript: []string{
2622+
"create table parent1 (i int primary key);",
2623+
"create table child1 (j int, k int, foreign key (j) references parent1(i) on delete cascade on update cascade, foreign key (k) references parent1 (i) on delete cascade on update cascade);",
2624+
"insert into parent1 values (1), (2), (3);",
2625+
"insert into child1 values (1, 2), (2, 3), (3, 1);",
2626+
},
2627+
Assertions: []ScriptTestAssertion{
2628+
{
2629+
Query: "select * from parent1;",
2630+
Expected: []sql.Row{
2631+
{1},
2632+
{2},
2633+
{3},
2634+
},
2635+
},
2636+
{
2637+
Query: "select * from child1 order by j, k;",
2638+
Expected: []sql.Row{
2639+
{1, 2},
2640+
{2, 3},
2641+
{3, 1},
2642+
},
2643+
},
2644+
{
2645+
Query: "update parent1 set i = 20 where i = 2;",
2646+
Expected: []sql.Row{
2647+
{types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}},
2648+
},
2649+
},
2650+
{
2651+
Query: "select * from parent1 order by i;",
2652+
Expected: []sql.Row{
2653+
{1},
2654+
{3},
2655+
{20},
2656+
},
2657+
},
2658+
{
2659+
Query: "select * from child1 order by j, k;",
2660+
Expected: []sql.Row{
2661+
{1, 20},
2662+
{3, 1},
2663+
{20, 3},
2664+
},
2665+
},
2666+
{
2667+
Query: "delete from parent1 where i = 1;",
2668+
Expected: []sql.Row{
2669+
{types.OkResult{RowsAffected: 1}},
2670+
},
2671+
},
2672+
{
2673+
Query: "select * from parent1;",
2674+
Expected: []sql.Row{
2675+
{3},
2676+
{20},
2677+
},
2678+
},
2679+
{
2680+
Query: "select * from child1 order by j, k;",
2681+
Expected: []sql.Row{
2682+
{20, 3},
2683+
},
2684+
},
2685+
},
2686+
},
26192687
}
26202688

26212689
var CreateForeignKeyTests = []ScriptTest{

sql/analyzer/apply_foreign_keys.go

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ func applyForeignKeys(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sco
4040
// and caching of table editors.
4141
func applyForeignKeysToNodes(ctx *sql.Context, a *Analyzer, n sql.Node, cache *foreignKeyCache) (sql.Node, transform.TreeIdentity, error) {
4242
var err error
43-
fkChain := foreignKeyChain{
44-
fkUpdate: make(map[foreignKeyTableName]sql.ForeignKeyEditor),
45-
}
43+
fkChain := newForeignKeyChain()
4644

4745
switch n := n.(type) {
4846
case *plan.CreateTable:
@@ -242,7 +240,7 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
242240
if err != nil {
243241
return nil, err
244242
}
245-
fkChain = fkChain.AddTableUpdater(fks[0].Database, fks[0].Table, updater)
243+
fkChain = fkChain.AddTable(fks[0].Database, fks[0].Table).AddTableUpdater(fks[0].Database, fks[0].Table, updater)
246244

247245
tblSch := tbl.Schema()
248246
fkEditor := &plan.ForeignKeyEditor{
@@ -356,6 +354,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
356354
// If either referential action is not equivalent to RESTRICT, then the updater has the possibility of having
357355
// its contents modified, therefore we add it to the chain.
358356
if !fk.OnUpdate.IsEquivalentToRestrict() || !fk.OnDelete.IsEquivalentToRestrict() {
357+
// TODO: why would I add the updater without the table here?
359358
fkChain = fkChain.AddTableUpdater(fk.Database, fk.Table, childUpdater)
360359
}
361360

@@ -387,6 +386,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
387386
return nil, err
388387
}
389388

389+
// TODO: does this need to be a deep copy or no???
390390
fkChain = fkChain.AddForeignKey(fk.Name)
391391
childEditor, err := getForeignKeyEditor(ctx, a, childTbl, cache, fkChain, checkRows)
392392
if err != nil {
@@ -431,6 +431,13 @@ type foreignKeyTableName struct {
431431
tblName string
432432
}
433433

434+
func newForeignKeyTableName(dbName, tblName string) foreignKeyTableName {
435+
return foreignKeyTableName{
436+
dbName: strings.ToLower(dbName),
437+
tblName: strings.ToLower(tblName),
438+
}
439+
}
440+
434441
// foreignKeyTableUpdater is a foreign key table along with its updater.
435442
type foreignKeyTableUpdater struct {
436443
tbl sql.ForeignKeyTable
@@ -544,9 +551,17 @@ func (cache *foreignKeyCache) GetEditor(fkEditor *plan.ForeignKeyEditor, dbName
544551
// updaters that are not a part of this chain. In addition, any updaters that cannot be modified (such as those
545552
// belonging to strictly RESTRICT referential actions) will not appear in the chain.
546553
type foreignKeyChain struct {
547-
fkNames map[string]struct{}
548-
fkTables map[foreignKeyTableName]struct{}
549-
fkUpdate map[foreignKeyTableName]sql.ForeignKeyEditor // TODO: why is this even a map?
554+
fkNames map[string]struct{}
555+
fkTables map[foreignKeyTableName]struct{}
556+
fkUpdaters map[foreignKeyTableName]sql.ForeignKeyEditor
557+
}
558+
559+
func newForeignKeyChain() foreignKeyChain {
560+
return foreignKeyChain{
561+
fkNames: make(map[string]struct{}),
562+
fkTables: make(map[foreignKeyTableName]struct{}),
563+
fkUpdaters: make(map[foreignKeyTableName]sql.ForeignKeyEditor),
564+
}
550565
}
551566

552567
// AddTable returns a new chain with the added table.
@@ -556,26 +571,17 @@ func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyC
556571
for fkName := range chain.fkNames {
557572
newFkNames[fkName] = struct{}{}
558573
}
559-
for fkTable := range chain.fkTables {
560-
newFkTables[fkTable] = struct{}{}
561-
}
562-
newFkTables[foreignKeyTableName{
563-
dbName: strings.ToLower(dbName),
564-
tblName: strings.ToLower(tblName),
565-
}] = struct{}{}
574+
newFkTables[newForeignKeyTableName(dbName, tblName)] = struct{}{}
566575
return foreignKeyChain{
567-
fkNames: newFkNames,
568-
fkTables: newFkTables,
569-
fkUpdate: chain.fkUpdate,
576+
fkNames: newFkNames,
577+
fkTables: newFkTables,
578+
fkUpdaters: chain.fkUpdaters,
570579
}
571580
}
572581

573582
// AddTableUpdater returns a new chain with the added foreign key updater.
574583
func (chain foreignKeyChain) AddTableUpdater(dbName string, tblName string, fkUpdater sql.ForeignKeyEditor) foreignKeyChain {
575-
chain.fkUpdate[foreignKeyTableName{
576-
dbName: strings.ToLower(dbName),
577-
tblName: strings.ToLower(tblName),
578-
}] = fkUpdater
584+
chain.fkUpdaters[newForeignKeyTableName(dbName, tblName)] = fkUpdater
579585
return chain
580586
}
581587

@@ -593,33 +599,26 @@ func (chain foreignKeyChain) AddForeignKey(fkName string) foreignKeyChain {
593599
return foreignKeyChain{
594600
fkNames: newFkNames,
595601
fkTables: newFkTables,
596-
fkUpdate: chain.fkUpdate,
602+
fkUpdaters: chain.fkUpdaters,
597603
}
598604
}
599605

600606
// HasTable returns whether the chain contains the given table. Case-insensitive.
601607
func (chain foreignKeyChain) HasTable(dbName string, tblName string) bool {
602-
if _, ok := chain.fkTables[foreignKeyTableName{
603-
dbName: strings.ToLower(dbName),
604-
tblName: strings.ToLower(tblName),
605-
}]; ok {
606-
return true
607-
}
608-
return false
608+
_, ok := chain.fkTables[newForeignKeyTableName(dbName, tblName)]
609+
return ok
609610
}
610611

611612
// HasForeignKey returns whether the chain contains the given foreign key. Case-insensitive.
612613
func (chain foreignKeyChain) HasForeignKey(fkName string) bool {
613-
if _, ok := chain.fkNames[strings.ToLower(fkName)]; ok {
614-
return true
615-
}
616-
return false
614+
_, ok := chain.fkNames[strings.ToLower(fkName)]
615+
return ok
617616
}
618617

619618
// GetUpdaters returns all foreign key updaters that have been added to the chain.
620619
func (chain foreignKeyChain) GetUpdaters() []sql.ForeignKeyEditor {
621-
updaters := make([]sql.ForeignKeyEditor, 0, len(chain.fkUpdate))
622-
for _, updater := range chain.fkUpdate {
620+
updaters := make([]sql.ForeignKeyEditor, 0, len(chain.fkUpdaters))
621+
for _, updater := range chain.fkUpdaters {
623622
updaters = append(updaters, updater)
624623
}
625624
return updaters

0 commit comments

Comments
 (0)