@@ -40,9 +40,7 @@ func applyForeignKeys(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sco
4040// and caching of table editors.
4141func applyForeignKeysToNodes (ctx * sql.Context , a * Analyzer , n sql.Node , cache * foreignKeyCache ) (sql.Node , transform.TreeIdentity , error ) {
4242 var err error
43- fkChain := foreignKeyChain {
44- fkUpdate : make (map [foreignKeyTableName ]sql.ForeignKeyEditor ),
45- }
43+ fkChain := newForeignKeyChain ()
4644
4745 switch n := n .(type ) {
4846 case * plan.CreateTable :
@@ -242,7 +240,7 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
242240 if err != nil {
243241 return nil , err
244242 }
245- fkChain = fkChain .AddTableUpdater (fks [0 ].Database , fks [0 ].Table , updater )
243+ fkChain = fkChain .AddTable ( fks [ 0 ]. Database , fks [ 0 ]. Table ). AddTableUpdater (fks [0 ].Database , fks [0 ].Table , updater )
246244
247245 tblSch := tbl .Schema ()
248246 fkEditor := & plan.ForeignKeyEditor {
@@ -356,6 +354,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
356354 // If either referential action is not equivalent to RESTRICT, then the updater has the possibility of having
357355 // its contents modified, therefore we add it to the chain.
358356 if ! fk .OnUpdate .IsEquivalentToRestrict () || ! fk .OnDelete .IsEquivalentToRestrict () {
357+ // TODO: why would I add the updater without the table here?
359358 fkChain = fkChain .AddTableUpdater (fk .Database , fk .Table , childUpdater )
360359 }
361360
@@ -387,6 +386,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
387386 return nil , err
388387 }
389388
389+ // TODO: does this need to be a deep copy or no???
390390 fkChain = fkChain .AddForeignKey (fk .Name )
391391 childEditor , err := getForeignKeyEditor (ctx , a , childTbl , cache , fkChain , checkRows )
392392 if err != nil {
@@ -431,6 +431,13 @@ type foreignKeyTableName struct {
431431 tblName string
432432}
433433
434+ func newForeignKeyTableName (dbName , tblName string ) foreignKeyTableName {
435+ return foreignKeyTableName {
436+ dbName : strings .ToLower (dbName ),
437+ tblName : strings .ToLower (tblName ),
438+ }
439+ }
440+
434441// foreignKeyTableUpdater is a foreign key table along with its updater.
435442type foreignKeyTableUpdater struct {
436443 tbl sql.ForeignKeyTable
@@ -544,9 +551,17 @@ func (cache *foreignKeyCache) GetEditor(fkEditor *plan.ForeignKeyEditor, dbName
544551// updaters that are not a part of this chain. In addition, any updaters that cannot be modified (such as those
545552// belonging to strictly RESTRICT referential actions) will not appear in the chain.
546553type foreignKeyChain struct {
547- fkNames map [string ]struct {}
548- fkTables map [foreignKeyTableName ]struct {}
549- fkUpdate map [foreignKeyTableName ]sql.ForeignKeyEditor // TODO: why is this even a map?
554+ fkNames map [string ]struct {}
555+ fkTables map [foreignKeyTableName ]struct {}
556+ fkUpdaters map [foreignKeyTableName ]sql.ForeignKeyEditor
557+ }
558+
559+ func newForeignKeyChain () foreignKeyChain {
560+ return foreignKeyChain {
561+ fkNames : make (map [string ]struct {}),
562+ fkTables : make (map [foreignKeyTableName ]struct {}),
563+ fkUpdaters : make (map [foreignKeyTableName ]sql.ForeignKeyEditor ),
564+ }
550565}
551566
552567// AddTable returns a new chain with the added table.
@@ -556,26 +571,17 @@ func (chain foreignKeyChain) AddTable(dbName string, tblName string) foreignKeyC
556571 for fkName := range chain .fkNames {
557572 newFkNames [fkName ] = struct {}{}
558573 }
559- for fkTable := range chain .fkTables {
560- newFkTables [fkTable ] = struct {}{}
561- }
562- newFkTables [foreignKeyTableName {
563- dbName : strings .ToLower (dbName ),
564- tblName : strings .ToLower (tblName ),
565- }] = struct {}{}
574+ newFkTables [newForeignKeyTableName (dbName , tblName )] = struct {}{}
566575 return foreignKeyChain {
567- fkNames : newFkNames ,
568- fkTables : newFkTables ,
569- fkUpdate : chain .fkUpdate ,
576+ fkNames : newFkNames ,
577+ fkTables : newFkTables ,
578+ fkUpdaters : chain .fkUpdaters ,
570579 }
571580}
572581
573582// AddTableUpdater returns a new chain with the added foreign key updater.
574583func (chain foreignKeyChain ) AddTableUpdater (dbName string , tblName string , fkUpdater sql.ForeignKeyEditor ) foreignKeyChain {
575- chain .fkUpdate [foreignKeyTableName {
576- dbName : strings .ToLower (dbName ),
577- tblName : strings .ToLower (tblName ),
578- }] = fkUpdater
584+ chain .fkUpdaters [newForeignKeyTableName (dbName , tblName )] = fkUpdater
579585 return chain
580586}
581587
@@ -593,33 +599,26 @@ func (chain foreignKeyChain) AddForeignKey(fkName string) foreignKeyChain {
593599 return foreignKeyChain {
594600 fkNames : newFkNames ,
595601 fkTables : newFkTables ,
596- fkUpdate : chain .fkUpdate ,
602+ fkUpdaters : chain .fkUpdaters ,
597603 }
598604}
599605
600606// HasTable returns whether the chain contains the given table. Case-insensitive.
601607func (chain foreignKeyChain ) HasTable (dbName string , tblName string ) bool {
602- if _ , ok := chain .fkTables [foreignKeyTableName {
603- dbName : strings .ToLower (dbName ),
604- tblName : strings .ToLower (tblName ),
605- }]; ok {
606- return true
607- }
608- return false
608+ _ , ok := chain .fkTables [newForeignKeyTableName (dbName , tblName )]
609+ return ok
609610}
610611
611612// HasForeignKey returns whether the chain contains the given foreign key. Case-insensitive.
612613func (chain foreignKeyChain ) HasForeignKey (fkName string ) bool {
613- if _ , ok := chain .fkNames [strings .ToLower (fkName )]; ok {
614- return true
615- }
616- return false
614+ _ , ok := chain .fkNames [strings .ToLower (fkName )]
615+ return ok
617616}
618617
619618// GetUpdaters returns all foreign key updaters that have been added to the chain.
620619func (chain foreignKeyChain ) GetUpdaters () []sql.ForeignKeyEditor {
621- updaters := make ([]sql.ForeignKeyEditor , 0 , len (chain .fkUpdate ))
622- for _ , updater := range chain .fkUpdate {
620+ updaters := make ([]sql.ForeignKeyEditor , 0 , len (chain .fkUpdaters ))
621+ for _ , updater := range chain .fkUpdaters {
623622 updaters = append (updaters , updater )
624623 }
625624 return updaters
0 commit comments