@@ -426,6 +426,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
426426 return nil
427427 }
428428 }
429+
429430 return sql .ErrForeignKeyChildViolation .New (reference .ForeignKey .Name , reference .ForeignKey .Table ,
430431 reference .ForeignKey .ParentTable , reference .RowMapper .GetKeyString (row ))
431432}
@@ -458,7 +459,7 @@ type ForeignKeyRowMapper struct {
458459 SourceSch sql.Schema
459460 // TargetTypeTransforms are a set of functions to transform the value in the table to the corresponding value in the
460461 // other table. This is required when the types of the two tables are compatible but different (e.g. INT and BIGINT).
461- TargetTypeTransforms []func (ctx * sql.Context , val any ) (any , error )
462+ TargetTypeTransforms []func (ctx * sql.Context , val any ) (sql. Type , any , error )
462463 // IndexPositions hold the mapping between an index's column position and the source row's column position. Given
463464 // an index (x1, x2) and a source row (y1, y2, y3) and the relation (x1->y3, x2->y1), this slice would contain
464465 // [2, 0]. The first index column "x1" maps to the third source column "y3" (so position 2 since it's zero-based),
@@ -486,17 +487,18 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
486487 return sql .RowsToRowIter (), nil
487488 }
488489
490+ targetType := mapper .SourceSch [rowPos ].Type
489491 // Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
490492 if mapper .TargetTypeTransforms != nil && mapper .TargetTypeTransforms [rowPos ] != nil {
491493 var err error
492- rowVal , err = mapper .TargetTypeTransforms [rowPos ](ctx , rowVal )
494+ targetType , rowVal , err = mapper .TargetTypeTransforms [rowPos ](ctx , rowVal )
493495 // TODO: possible for this to fail without error, which means the value cannot be found in the other table
494496 if err != nil {
495497 return nil , err
496498 }
497499 }
498500
499- rang [rangPosition ] = sql .ClosedRangeColumnExpr (rowVal , rowVal , mapper . SourceSch [ rowPos ]. Type )
501+ rang [rangPosition ] = sql .ClosedRangeColumnExpr (rowVal , rowVal , targetType )
500502 }
501503 for i , appendType := range mapper .AppendTypes {
502504 rang [i + len (mapper .IndexPositions )] = sql .AllRangeColumnExpr (appendType )
@@ -565,7 +567,7 @@ func GetChildParentMapping(parentSch sql.Schema, childSch sql.Schema, fkDef sql.
565567
566568// GetChildParentTypeTransforms returs a set of functions to transform the value in the child table to the
567569// corresponding type in the parent table, if necessary
568- func GetChildParentTypeTransforms (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) ([]func (ctx * sql.Context , val any ) (any , error ), error ) {
570+ func GetChildParentTypeTransforms (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) ([]func (ctx * sql.Context , val any ) (sql. Type , any , error ), error ) {
569571 parentMap := make (map [string ]int )
570572 for i , col := range parentSch {
571573 parentMap [strings .ToLower (col .Name )] = i
@@ -575,7 +577,7 @@ func GetChildParentTypeTransforms(parentSch sql.Schema, childSch sql.Schema, fkD
575577 childMap [strings .ToLower (col .Name )] = i
576578 }
577579
578- var mapping []func (* sql.Context , any ) (any , error )
580+ var mapping []func (* sql.Context , any ) (sql. Type , any , error )
579581
580582 for i := range fkDef .Columns {
581583 childIndex , ok := childMap [strings .ToLower (fkDef .Columns [i ])]
@@ -601,10 +603,11 @@ func GetChildParentTypeTransforms(parentSch sql.Schema, childSch sql.Schema, fkD
601603 if ! childType .Equals (parentType ) {
602604 parentExtendedType , _ := parentType .(types.ExtendedType )
603605 if mapping == nil {
604- mapping = make ([]func (* sql.Context , any ) (any , error ), len (childSch ))
606+ mapping = make ([]func (* sql.Context , any ) (sql. Type , any , error ), len (childSch ))
605607 }
606- mapping [childIndex ] = func (ctx * sql.Context , val any ) (any , error ) {
607- return parentExtendedType .ConvertToType (ctx , childExtendedType , val )
608+ mapping [childIndex ] = func (ctx * sql.Context , val any ) (sql.Type , any , error ) {
609+ convertedVal , err := parentExtendedType .ConvertToType (ctx , childExtendedType , val )
610+ return parentExtendedType , convertedVal , err
608611 }
609612 }
610613 }
0 commit comments