@@ -20,6 +20,7 @@ import (
20
20
"strings"
21
21
22
22
"github.com/dolthub/go-mysql-server/sql"
23
+ "github.com/dolthub/go-mysql-server/sql/types"
23
24
)
24
25
25
26
// ChildParentMapping is a mapping from the foreign key columns of a child schema to the parent schema. The position
@@ -455,6 +456,9 @@ type ForeignKeyRowMapper struct {
455
456
Index sql.Index
456
457
Updater sql.ForeignKeyEditor
457
458
SourceSch sql.Schema
459
+ // TargetTypeTransforms are a set of functions to transform the value in the table to the corresponding value in the
460
+ // other table. This is required when the types of the two tables are compatible but different (e.g. INT and BIGINT).
461
+ TargetTypeTransforms []func (ctx * sql.Context , val any ) (any , error )
458
462
// IndexPositions hold the mapping between an index's column position and the source row's column position. Given
459
463
// an index (x1, x2) and a source row (y1, y2, y3) and the relation (x1->y3, x2->y1), this slice would contain
460
464
// [2, 0]. The first index column "x1" maps to the third source column "y3" (so position 2 since it's zero-based),
@@ -481,6 +485,17 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
481
485
if rowVal == nil {
482
486
return sql .RowsToRowIter (), nil
483
487
}
488
+
489
+ // Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
490
+ if mapper .TargetTypeTransforms != nil && mapper .TargetTypeTransforms [rowPos ] != nil {
491
+ var err error
492
+ rowVal , err = mapper .TargetTypeTransforms [rowPos ](ctx , rowVal )
493
+ // TODO: possible for this to fail without error, which means the value cannot be found in the other table
494
+ if err != nil {
495
+ return nil , err
496
+ }
497
+ }
498
+
484
499
rang [rangPosition ] = sql .ClosedRangeColumnExpr (rowVal , rowVal , mapper .SourceSch [rowPos ].Type )
485
500
}
486
501
for i , appendType := range mapper .AppendTypes {
@@ -547,3 +562,53 @@ func GetChildParentMapping(parentSch sql.Schema, childSch sql.Schema, fkDef sql.
547
562
}
548
563
return mapping , nil
549
564
}
565
+
566
+ // GetChildParentTypeTransforms returs a set of functions to transform the value in the child table to the
567
+ // corresponding type in the parent table, if necessary
568
+ func GetChildParentTypeTransforms (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) ([]func (ctx * sql.Context , val any ) (any , error ), error ) {
569
+
570
+ parentMap := make (map [string ]int )
571
+ for i , col := range parentSch {
572
+ parentMap [strings .ToLower (col .Name )] = i
573
+ }
574
+ childMap := make (map [string ]int )
575
+ for i , col := range childSch {
576
+ childMap [strings .ToLower (col .Name )] = i
577
+ }
578
+
579
+ var mapping []func (* sql.Context , any ) (any , error )
580
+
581
+ for i := range fkDef .Columns {
582
+ childIndex , ok := childMap [strings .ToLower (fkDef .Columns [i ])]
583
+ if ! ok {
584
+ return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on table `%s` but it could not be found" ,
585
+ fkDef .Name , fkDef .Columns [i ], fkDef .Table )
586
+ }
587
+ parentIndex , ok := parentMap [strings .ToLower (fkDef .ParentColumns [i ])]
588
+ if ! ok {
589
+ return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found" ,
590
+ fkDef .Name , fkDef .ParentColumns [i ], fkDef .ParentTable )
591
+ }
592
+
593
+ childType := childSch [childIndex ].Type
594
+ parentType := parentSch [parentIndex ].Type
595
+
596
+ childExtendedType , _ := childType .(types.ExtendedType )
597
+ // if even one of the types is not an extended type, then we can't transform any values
598
+ if childExtendedType == nil {
599
+ return nil , nil
600
+ }
601
+
602
+ if ! childType .Equals (parentType ) {
603
+ parentExtendedType , _ := parentType .(types.ExtendedType )
604
+ if mapping == nil {
605
+ mapping = make ([]func (* sql.Context , any ) (any , error ), len (childSch ))
606
+ }
607
+ mapping [childIndex ] = func (ctx * sql.Context , val any ) (any , error ) {
608
+ return parentExtendedType .ConvertToType (ctx , childExtendedType , val )
609
+ }
610
+ }
611
+ }
612
+
613
+ return mapping , nil
614
+ }
0 commit comments