@@ -426,6 +426,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
426
426
return nil
427
427
}
428
428
}
429
+
429
430
return sql .ErrForeignKeyChildViolation .New (reference .ForeignKey .Name , reference .ForeignKey .Table ,
430
431
reference .ForeignKey .ParentTable , reference .RowMapper .GetKeyString (row ))
431
432
}
@@ -458,7 +459,7 @@ type ForeignKeyRowMapper struct {
458
459
SourceSch sql.Schema
459
460
// TargetTypeTransforms are a set of functions to transform the value in the table to the corresponding value in the
460
461
// other table. This is required when the types of the two tables are compatible but different (e.g. INT and BIGINT).
461
- TargetTypeTransforms []func (ctx * sql.Context , val any ) (any , error )
462
+ TargetTypeTransforms []func (ctx * sql.Context , val any ) (sql. Type , any , error )
462
463
// IndexPositions hold the mapping between an index's column position and the source row's column position. Given
463
464
// an index (x1, x2) and a source row (y1, y2, y3) and the relation (x1->y3, x2->y1), this slice would contain
464
465
// [2, 0]. The first index column "x1" maps to the third source column "y3" (so position 2 since it's zero-based),
@@ -486,17 +487,18 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
486
487
return sql .RowsToRowIter (), nil
487
488
}
488
489
490
+ targetType := mapper .SourceSch [rowPos ].Type
489
491
// Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
490
492
if mapper .TargetTypeTransforms != nil && mapper .TargetTypeTransforms [rowPos ] != nil {
491
493
var err error
492
- rowVal , err = mapper .TargetTypeTransforms [rowPos ](ctx , rowVal )
494
+ targetType , rowVal , err = mapper .TargetTypeTransforms [rowPos ](ctx , rowVal )
493
495
// TODO: possible for this to fail without error, which means the value cannot be found in the other table
494
496
if err != nil {
495
497
return nil , err
496
498
}
497
499
}
498
500
499
- rang [rangPosition ] = sql .ClosedRangeColumnExpr (rowVal , rowVal , mapper . SourceSch [ rowPos ]. Type )
501
+ rang [rangPosition ] = sql .ClosedRangeColumnExpr (rowVal , rowVal , targetType )
500
502
}
501
503
for i , appendType := range mapper .AppendTypes {
502
504
rang [i + len (mapper .IndexPositions )] = sql .AllRangeColumnExpr (appendType )
@@ -565,7 +567,7 @@ func GetChildParentMapping(parentSch sql.Schema, childSch sql.Schema, fkDef sql.
565
567
566
568
// GetChildParentTypeTransforms returs a set of functions to transform the value in the child table to the
567
569
// corresponding type in the parent table, if necessary
568
- func GetChildParentTypeTransforms (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) ([]func (ctx * sql.Context , val any ) (any , error ), error ) {
570
+ func GetChildParentTypeTransforms (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) ([]func (ctx * sql.Context , val any ) (sql. Type , any , error ), error ) {
569
571
parentMap := make (map [string ]int )
570
572
for i , col := range parentSch {
571
573
parentMap [strings .ToLower (col .Name )] = i
@@ -575,7 +577,7 @@ func GetChildParentTypeTransforms(parentSch sql.Schema, childSch sql.Schema, fkD
575
577
childMap [strings .ToLower (col .Name )] = i
576
578
}
577
579
578
- var mapping []func (* sql.Context , any ) (any , error )
580
+ var mapping []func (* sql.Context , any ) (sql. Type , any , error )
579
581
580
582
for i := range fkDef .Columns {
581
583
childIndex , ok := childMap [strings .ToLower (fkDef .Columns [i ])]
@@ -601,10 +603,11 @@ func GetChildParentTypeTransforms(parentSch sql.Schema, childSch sql.Schema, fkD
601
603
if ! childType .Equals (parentType ) {
602
604
parentExtendedType , _ := parentType .(types.ExtendedType )
603
605
if mapping == nil {
604
- mapping = make ([]func (* sql.Context , any ) (any , error ), len (childSch ))
606
+ mapping = make ([]func (* sql.Context , any ) (sql. Type , any , error ), len (childSch ))
605
607
}
606
- mapping [childIndex ] = func (ctx * sql.Context , val any ) (any , error ) {
607
- return parentExtendedType .ConvertToType (ctx , childExtendedType , val )
608
+ mapping [childIndex ] = func (ctx * sql.Context , val any ) (sql.Type , any , error ) {
609
+ convertedVal , err := parentExtendedType .ConvertToType (ctx , childExtendedType , val )
610
+ return parentExtendedType , convertedVal , err
608
611
}
609
612
}
610
613
}
0 commit comments