@@ -20,6 +20,7 @@ import (
20
20
"strings"
21
21
22
22
"github.com/dolthub/go-mysql-server/sql"
23
+ "github.com/dolthub/go-mysql-server/sql/types"
23
24
)
24
25
25
26
// ChildParentMapping is a mapping from the foreign key columns of a child schema to the parent schema. The position
@@ -227,7 +228,7 @@ func (fkEditor *ForeignKeyEditor) OnUpdateSetNull(ctx *sql.Context, refActionDat
227
228
228
229
// Delete handles both the standard DELETE statement and propagated referential actions from a parent table's ON DELETE.
229
230
func (fkEditor * ForeignKeyEditor ) Delete (ctx * sql.Context , row sql.Row , depth int ) error {
230
- //TODO: may need to process some cascades after the update to avoid recursive violations, write some tests on this
231
+ // TODO: may need to process some cascades after the update to avoid recursive violations, write some tests on this
231
232
for _ , refActionData := range fkEditor .RefActions {
232
233
switch refActionData .ForeignKey .OnDelete {
233
234
default : // RESTRICT and friends
@@ -425,6 +426,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
425
426
return nil
426
427
}
427
428
}
429
+
428
430
return sql .ErrForeignKeyChildViolation .New (reference .ForeignKey .Name , reference .ForeignKey .Table ,
429
431
reference .ForeignKey .ParentTable , reference .RowMapper .GetKeyString (row ))
430
432
}
@@ -455,6 +457,9 @@ type ForeignKeyRowMapper struct {
455
457
Index sql.Index
456
458
Updater sql.ForeignKeyEditor
457
459
SourceSch sql.Schema
460
+ // TargetTypeConversions are a set of functions to transform the value in the table to the corresponding value in the
461
+ // other table. This is required when the types of the two tables are compatible but different (e.g. INT and BIGINT).
462
+ TargetTypeConversions []ForeignKeyTypeConversionFn
458
463
// IndexPositions hold the mapping between an index's column position and the source row's column position. Given
459
464
// an index (x1, x2) and a source row (y1, y2, y3) and the relation (x1->y3, x2->y1), this slice would contain
460
465
// [2, 0]. The first index column "x1" maps to the third source column "y3" (so position 2 since it's zero-based),
@@ -481,7 +486,21 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
481
486
if rowVal == nil {
482
487
return sql .RowsToRowIter (), nil
483
488
}
484
- rang [rangPosition ] = sql .ClosedRangeColumnExpr (rowVal , rowVal , mapper .SourceSch [rowPos ].Type )
489
+
490
+ targetType := mapper .SourceSch [rowPos ].Type
491
+ // Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
492
+ if mapper .TargetTypeConversions != nil && mapper .TargetTypeConversions [rowPos ] != nil {
493
+ var err error
494
+ targetType , rowVal , err = mapper .TargetTypeConversions [rowPos ](ctx , rowVal )
495
+ // An error means the type conversion failed, which typically means there's no way to convert the value given to
496
+ // the target value because of e.g. range constraints (trying to assign an INT to a TINYINT column). We treat
497
+ // this as an empty result for this iterator, since this value cannot possibly be present in the other table.
498
+ if err != nil {
499
+ return sql .RowsToRowIter (), nil
500
+ }
501
+ }
502
+
503
+ rang [rangPosition ] = sql .ClosedRangeColumnExpr (rowVal , rowVal , targetType )
485
504
}
486
505
for i , appendType := range mapper .AppendTypes {
487
506
rang [i + len (mapper .IndexPositions )] = sql .AllRangeColumnExpr (appendType )
@@ -490,7 +509,7 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
490
509
if ! mapper .Index .CanSupport (rang ) {
491
510
return nil , ErrInvalidLookupForIndexedTable .New (rang .DebugString ())
492
511
}
493
- //TODO: profile this, may need to redesign this or add a fast path
512
+ // TODO: profile this, may need to redesign this or add a fast path
494
513
lookup := sql.IndexLookup {Ranges : sql.MySQLRangeCollection {rang }, Index : mapper .Index }
495
514
496
515
editorData := mapper .Updater .IndexedAccess (lookup )
@@ -520,30 +539,94 @@ func (mapper *ForeignKeyRowMapper) GetKeyString(row sql.Row) string {
520
539
521
540
// GetChildParentMapping returns a mapping from the foreign key columns of a child schema to the parent schema.
522
541
func GetChildParentMapping (parentSch sql.Schema , childSch sql.Schema , fkDef sql.ForeignKeyConstraint ) (ChildParentMapping , error ) {
523
- parentMap := make (map [string ]int )
524
- for i , col := range parentSch {
525
- parentMap [strings .ToLower (col .Name )] = i
526
- }
527
- childMap := make (map [string ]int )
528
- for i , col := range childSch {
529
- childMap [strings .ToLower (col .Name )] = i
530
- }
531
542
mapping := make (ChildParentMapping , len (childSch ))
532
543
for i := range mapping {
533
544
mapping [i ] = - 1
534
545
}
535
546
for i := range fkDef .Columns {
536
- childIndex , ok := childMap [ strings . ToLower (fkDef .Columns [i ])]
537
- if ! ok {
547
+ childIndex := childSch . IndexOfColName (fkDef .Columns [i ])
548
+ if childIndex < 0 {
538
549
return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on table `%s` but it could not be found" ,
539
550
fkDef .Name , fkDef .Columns [i ], fkDef .Table )
540
551
}
541
- parentIndex , ok := parentMap [ strings . ToLower (fkDef .ParentColumns [i ])]
542
- if ! ok {
552
+ parentIndex := parentSch . IndexOfColName (fkDef .ParentColumns [i ])
553
+ if parentIndex < 0 {
543
554
return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found" ,
544
555
fkDef .Name , fkDef .ParentColumns [i ], fkDef .ParentTable )
545
556
}
546
557
mapping [childIndex ] = parentIndex
547
558
}
548
559
return mapping , nil
549
560
}
561
+
562
+ // ForeignKeyTypeConversionDirection specifies whether a child column type is being converted to its parent type for
563
+ // constraint enforcement, or vice versa.
564
+ type ForeignKeyTypeConversionDirection byte
565
+
566
+ const (
567
+ ChildToParent ForeignKeyTypeConversionDirection = iota
568
+ ParentToChild
569
+ )
570
+
571
+ // ForeignKeyTypeConversionFn is a function that transforms a value from one type to another for foreign key constraint
572
+ // enforcement. The target type is returned along with the transformed value, or an error if the transformation fails.
573
+ type ForeignKeyTypeConversionFn func (ctx * sql.Context , val any ) (sql.Type , any , error )
574
+
575
+ // GetForeignKeyTypeConversions returns a set of functions to convert a type in a one foreign key column table to the
576
+ // type in the corresponding table. Specify the schema of both child and parent tables, as well as whether the
577
+ // transformation is from child to parent or vice versa.
578
+ func GetForeignKeyTypeConversions (
579
+ parentSch sql.Schema ,
580
+ childSch sql.Schema ,
581
+ fkDef sql.ForeignKeyConstraint ,
582
+ direction ForeignKeyTypeConversionDirection ,
583
+ ) ([]ForeignKeyTypeConversionFn , error ) {
584
+ var convFns []ForeignKeyTypeConversionFn
585
+
586
+ for i := range fkDef .Columns {
587
+ childIndex := childSch .IndexOfColName (fkDef .Columns [i ])
588
+ if childIndex < 0 {
589
+ return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on table `%s` but it could not be found" ,
590
+ fkDef .Name , fkDef .Columns [i ], fkDef .Table )
591
+ }
592
+ parentIndex := parentSch .IndexOfColName (fkDef .ParentColumns [i ])
593
+ if parentIndex < 0 {
594
+ return nil , fmt .Errorf ("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found" ,
595
+ fkDef .Name , fkDef .ParentColumns [i ], fkDef .ParentTable )
596
+ }
597
+
598
+ childType := childSch [childIndex ].Type
599
+ parentType := parentSch [parentIndex ].Type
600
+
601
+ childExtendedType , ok := childType .(types.ExtendedType )
602
+ // if even one of the types is not an extended type, then we can't transform any values
603
+ if ! ok {
604
+ return nil , nil
605
+ }
606
+
607
+ if ! childType .Equals (parentType ) {
608
+ parentExtendedType , ok := parentType .(types.ExtendedType )
609
+ if ! ok {
610
+ // this should be impossible (child and parent should both be extended types), but just in case
611
+ return nil , nil
612
+ }
613
+
614
+ fromType := childExtendedType
615
+ toType := parentExtendedType
616
+ if direction == ParentToChild {
617
+ fromType = parentExtendedType
618
+ toType = childExtendedType
619
+ }
620
+
621
+ if convFns == nil {
622
+ convFns = make ([]ForeignKeyTypeConversionFn , len (childSch ))
623
+ }
624
+ convFns [childIndex ] = func (ctx * sql.Context , val any ) (sql.Type , any , error ) {
625
+ convertedVal , err := toType .ConvertToType (ctx , fromType , val )
626
+ return toType , convertedVal , err
627
+ }
628
+ }
629
+ }
630
+
631
+ return convFns , nil
632
+ }
0 commit comments