@@ -21,7 +21,6 @@ import (
21
21
22
22
"github.com/dolthub/go-mysql-server/sql"
23
23
"github.com/dolthub/go-mysql-server/sql/types"
24
- "github.com/dolthub/vitess/go/sqltypes"
25
24
)
26
25
27
26
// ChildParentMapping is a mapping from the foreign key columns of a child schema to the parent schema. The position
@@ -514,7 +513,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
514
513
}
515
514
if err == nil {
516
515
// We have a parent row, but for DECIMAL types we need to be strict about precision/scale
517
- if shouldReject := reference .shouldRejectDecimalMatch (ctx , row ); shouldReject {
516
+ if shouldReject := reference .validateDecimalMatch (ctx , row ); shouldReject {
518
517
return sql .ErrForeignKeyChildViolation .New (reference .ForeignKey .Name , reference .ForeignKey .Table ,
519
518
reference .ForeignKey .ParentTable , reference .RowMapper .GetKeyString (row ))
520
519
}
@@ -545,36 +544,25 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
545
544
reference .ForeignKey .ParentTable , reference .RowMapper .GetKeyString (row ))
546
545
}
547
546
548
- func (reference * ForeignKeyReferenceHandler ) shouldRejectDecimalMatch (ctx * sql.Context , row sql.Row ) bool {
547
+ func (reference * ForeignKeyReferenceHandler ) validateDecimalMatch (ctx * sql.Context , row sql.Row ) bool {
548
+ if reference .RowMapper .Index == nil {
549
+ return false
550
+ }
551
+ indexColumnTypes := reference .RowMapper .Index .ColumnExpressionTypes ()
549
552
for i := range reference .ForeignKey .Columns {
553
+ if i >= len (indexColumnTypes ) {
554
+ continue
555
+ }
550
556
childColIdx := reference .RowMapper .IndexPositions [i ]
551
- childType := reference .RowMapper .SourceSch [childColIdx ].Type
552
-
553
- if childType .Type () == sqltypes .Decimal {
554
- childDecimal , ok := childType .(sql.DecimalType )
555
- if ! ok {
556
- continue
557
- }
558
-
559
- if reference .RowMapper .Index != nil {
560
- indexColumnTypes := reference .RowMapper .Index .ColumnExpressionTypes ()
561
- if len (indexColumnTypes ) > i {
562
- parentType := indexColumnTypes [i ].Type
563
- if parentType .Type () == sqltypes .Decimal {
564
- parentDecimal , ok := parentType .(sql.DecimalType )
565
- if ok && childDecimal .Scale () != parentDecimal .Scale () {
566
- return true
567
- }
568
- }
569
- }
570
- }
557
+ childDecimal , childOk := reference .RowMapper .SourceSch [childColIdx ].Type .(sql.DecimalType )
558
+ parentDecimal , parentOk := indexColumnTypes [i ].Type .(sql.DecimalType )
559
+ if childOk && parentOk && childDecimal .Scale () != parentDecimal .Scale () {
560
+ return true
571
561
}
572
562
}
573
563
return false
574
564
}
575
565
576
-
577
-
578
566
// CheckTable checks that every row in the table has an index entry in the referenced table.
579
567
func (reference * ForeignKeyReferenceHandler ) CheckTable (ctx * sql.Context , tbl sql.ForeignKeyTable ) error {
580
568
partIter , err := tbl .Partitions (ctx )
@@ -633,16 +621,6 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
633
621
634
622
targetType := mapper .SourceSch [rowPos ].Type
635
623
636
- // Special handling for DECIMAL types: check if this is a foreign key reference
637
- // with different precision/scale. If so, we need to be strict like MySQL
638
- if targetType .Type () == sqltypes .Decimal && refCheck {
639
- // For DECIMAL foreign key lookups, we need to ensure exact type matching
640
- // This is a simplified approach - we'll return an empty iterator for now
641
- // to prevent matches when precision/scale differs
642
- // TODO: This should be refined to only block when types actually differ
643
- // For now, let's continue with normal processing and handle it later
644
- }
645
-
646
624
// Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
647
625
if mapper .TargetTypeConversions != nil && mapper .TargetTypeConversions [rowPos ] != nil {
648
626
var err error
@@ -766,17 +744,6 @@ func GetForeignKeyTypeConversions(
766
744
return nil , nil
767
745
}
768
746
769
- // Special handling for DECIMAL types: when precision/scale differs,
770
- // don't allow type conversion to ensure strict constraint checking
771
- if childType .Type () == sqltypes .Decimal && parentType .Type () == sqltypes .Decimal {
772
- // For DECIMAL foreign keys with different precision/scale, we don't allow
773
- // automatic type conversion. This ensures constraint checking matches MySQL's
774
- // strict behavior where 78.9 (4,1) != 78.90 (4,2)
775
- // Note: childType.Equals(parentType) already returned false above, so we know they differ
776
- // However, since DECIMAL is not ExtendedType, this logic won't be triggered
777
- // The actual constraint validation is handled in CheckReference
778
- return nil , nil
779
- }
780
747
781
748
fromType := childExtendedType
782
749
toType := parentExtendedType
0 commit comments