@@ -457,9 +457,9 @@ type ForeignKeyRowMapper struct {
457457 Index sql.Index
458458 Updater sql.ForeignKeyEditor
459459 SourceSch sql.Schema
460- // TargetTypeTransforms are a set of functions to transform the value in the table to the corresponding value in the
460+ // TargetTypeConversions are a set of functions to transform the value in the table to the corresponding value in the
461461 // other table. This is required when the types of the two tables are compatible but different (e.g. INT and BIGINT).
462- TargetTypeTransforms []func ( ctx * sql. Context , val any ) (sql. Type , any , error )
462+ TargetTypeConversions []ForeignKeyTypeConversionFn
463463 // IndexPositions hold the mapping between an index's column position and the source row's column position. Given
464464 // an index (x1, x2) and a source row (y1, y2, y3) and the relation (x1->y3, x2->y1), this slice would contain
465465 // [2, 0]. The first index column "x1" maps to the third source column "y3" (so position 2 since it's zero-based),
@@ -489,9 +489,9 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
489489
490490 targetType := mapper .SourceSch [rowPos ].Type
491491 // Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
492- if mapper .TargetTypeTransforms != nil && mapper .TargetTypeTransforms [rowPos ] != nil {
492+ if mapper .TargetTypeConversions != nil && mapper .TargetTypeConversions [rowPos ] != nil {
493493 var err error
494- targetType , rowVal , err = mapper .TargetTypeTransforms [rowPos ](ctx , rowVal )
494+ targetType , rowVal , err = mapper .TargetTypeConversions [rowPos ](ctx , rowVal )
495495 // TODO: possible for this to fail without error, which means the value cannot be found in the other table
496496 if err != nil {
497497 return nil , err
@@ -537,26 +537,18 @@ func (mapper *ForeignKeyRowMapper) GetKeyString(row sql.Row) string {
537537
538538// GetChildParentMapping returns a mapping from the foreign key columns of a child schema to the parent schema.
539539func GetChildParentMapping (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) (ChildParentMapping , error ) {
540- parentMap := make (map [string ]int )
541- for i , col := range parentSch {
542- parentMap [strings .ToLower (col .Name )] = i
543- }
544- childMap := make (map [string ]int )
545- for i , col := range childSch {
546- childMap [strings .ToLower (col .Name )] = i
547- }
548540 mapping := make (ChildParentMapping , len (childSch ))
549541 for i := range mapping {
550542 mapping [i ] = - 1
551543 }
552544 for i := range fkDef .Columns {
553- childIndex , ok := childMap [ strings . ToLower (fkDef .Columns [i ])]
554- if ! ok {
545+ childIndex := childSch . IndexOfColName (fkDef .Columns [i ])
546+ if childIndex < 0 {
555547 return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on table `%s` but it could not be found" ,
556548 fkDef .Name , fkDef .Columns [i ], fkDef .Table )
557549 }
558- parentIndex , ok := parentMap [ strings . ToLower (fkDef .ParentColumns [i ])]
559- if ! ok {
550+ parentIndex := parentSch . IndexOfColName (fkDef .ParentColumns [i ])
551+ if parentIndex < 0 {
560552 return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found" ,
561553 fkDef .Name , fkDef .ParentColumns [i ], fkDef .ParentTable )
562554 }
@@ -565,28 +557,37 @@ func GetChildParentMapping(parentSch sql.Schema, childSch sql.Schema, fkDef sql.
565557 return mapping , nil
566558}
567559
568- // GetChildParentTypeTransforms returs a set of functions to transform the value in the child table to the
569- // corresponding type in the parent table, if necessary
570- func GetChildParentTypeTransforms (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) ([]func (ctx * sql.Context , val any ) (sql.Type , any , error ), error ) {
571- parentMap := make (map [string ]int )
572- for i , col := range parentSch {
573- parentMap [strings .ToLower (col .Name )] = i
574- }
575- childMap := make (map [string ]int )
576- for i , col := range childSch {
577- childMap [strings .ToLower (col .Name )] = i
578- }
560+ // ForeignKeyTypeConversionDirection specifies whether a child column type is being converted to its parent type for
561+ // constraint enforcement, or vice versa.
562+ type ForeignKeyTypeConversionDirection byte
563+ const (
564+ ChildToParent ForeignKeyTypeConversionDirection = iota
565+ ParentToChild
566+ )
579567
580- var mapping []func (* sql.Context , any ) (sql.Type , any , error )
568+ // ForeignKeyTypeConversionFn is a function that transforms a value from one type to another for foreign key constraint
569+ // enforcement. The target type is returned along with the transformed value, or an error if the transformation fails.
570+ type ForeignKeyTypeConversionFn func (ctx * sql.Context , val any ) (sql.Type , any , error )
571+
572+ // GetForeignKeyTypeConversions returns a set of functions to convert a type in a one foreign key column table to the
573+ // type in the corresponding table. Specify the schema of both child and parent tables, as well as whether the
574+ // transformation is from child to parent or vice versa.
575+ func GetForeignKeyTypeConversions (
576+ parentSch sql.Schema ,
577+ childSch sql.Schema ,
578+ fkDef sql.ForeignKeyConstraint ,
579+ direction ForeignKeyTypeConversionDirection ,
580+ ) ([]ForeignKeyTypeConversionFn , error ) {
581+ var mapping []ForeignKeyTypeConversionFn
581582
582583 for i := range fkDef .Columns {
583- childIndex , ok := childMap [ strings . ToLower (fkDef .Columns [i ])]
584- if ! ok {
584+ childIndex := childSch . IndexOfColName (fkDef .Columns [i ])
585+ if childIndex < 0 {
585586 return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on table `%s` but it could not be found" ,
586587 fkDef .Name , fkDef .Columns [i ], fkDef .Table )
587588 }
588- parentIndex , ok := parentMap [ strings . ToLower (fkDef .ParentColumns [i ])]
589- if ! ok {
589+ parentIndex := parentSch . IndexOfColName (fkDef .ParentColumns [i ])
590+ if parentIndex < 0 {
590591 return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found" ,
591592 fkDef .Name , fkDef .ParentColumns [i ], fkDef .ParentTable )
592593 }
@@ -602,12 +603,24 @@ func GetChildParentTypeTransforms(parentSch sql.Schema, childSch sql.Schema, fkD
602603
603604 if ! childType .Equals (parentType ) {
604605 parentExtendedType , _ := parentType .(types.ExtendedType )
606+ if parentExtendedType == nil {
607+ // this should be impossible (child and parent should both be extended types), but just in case
608+ return nil , nil
609+ }
610+
611+ fromType := childExtendedType
612+ toType := parentExtendedType
613+ if direction == ParentToChild {
614+ fromType = parentExtendedType
615+ toType = childExtendedType
616+ }
617+
605618 if mapping == nil {
606- mapping = make ([]func ( * sql. Context , any ) (sql. Type , any , error ) , len (childSch ))
619+ mapping = make ([]ForeignKeyTypeConversionFn , len (childSch ))
607620 }
608621 mapping [childIndex ] = func (ctx * sql.Context , val any ) (sql.Type , any , error ) {
609- convertedVal , err := parentExtendedType .ConvertToType (ctx , childExtendedType , val )
610- return parentExtendedType , convertedVal , err
622+ convertedVal , err := toType .ConvertToType (ctx , fromType , val )
623+ return toType , convertedVal , err
611624 }
612625 }
613626 }
0 commit comments