Skip to content

Commit d8ad856

Browse files
elianddbclaude
andcommitted
Clean up DECIMAL foreign key validation code
- Remove verbose comments and dead code - Reduce nesting in shouldRejectDecimalMatch function - Use existing patterns from codebase - Remove unnecessary vitess import - Simplify type checking logic 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 98b82b6 commit d8ad856

File tree

1 file changed

+13
-46
lines changed

1 file changed

+13
-46
lines changed

sql/plan/foreign_key_editor.go

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121

2222
"github.com/dolthub/go-mysql-server/sql"
2323
"github.com/dolthub/go-mysql-server/sql/types"
24-
"github.com/dolthub/vitess/go/sqltypes"
2524
)
2625

2726
// ChildParentMapping is a mapping from the foreign key columns of a child schema to the parent schema. The position
@@ -514,7 +513,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
514513
}
515514
if err == nil {
516515
// We have a parent row, but for DECIMAL types we need to be strict about precision/scale
517-
if shouldReject := reference.shouldRejectDecimalMatch(ctx, row); shouldReject {
516+
if shouldReject := reference.validateDecimalMatch(ctx, row); shouldReject {
518517
return sql.ErrForeignKeyChildViolation.New(reference.ForeignKey.Name, reference.ForeignKey.Table,
519518
reference.ForeignKey.ParentTable, reference.RowMapper.GetKeyString(row))
520519
}
@@ -545,36 +544,25 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
545544
reference.ForeignKey.ParentTable, reference.RowMapper.GetKeyString(row))
546545
}
547546

548-
func (reference *ForeignKeyReferenceHandler) shouldRejectDecimalMatch(ctx *sql.Context, row sql.Row) bool {
547+
func (reference *ForeignKeyReferenceHandler) validateDecimalMatch(ctx *sql.Context, row sql.Row) bool {
548+
if reference.RowMapper.Index == nil {
549+
return false
550+
}
551+
indexColumnTypes := reference.RowMapper.Index.ColumnExpressionTypes()
549552
for i := range reference.ForeignKey.Columns {
553+
if i >= len(indexColumnTypes) {
554+
continue
555+
}
550556
childColIdx := reference.RowMapper.IndexPositions[i]
551-
childType := reference.RowMapper.SourceSch[childColIdx].Type
552-
553-
if childType.Type() == sqltypes.Decimal {
554-
childDecimal, ok := childType.(sql.DecimalType)
555-
if !ok {
556-
continue
557-
}
558-
559-
if reference.RowMapper.Index != nil {
560-
indexColumnTypes := reference.RowMapper.Index.ColumnExpressionTypes()
561-
if len(indexColumnTypes) > i {
562-
parentType := indexColumnTypes[i].Type
563-
if parentType.Type() == sqltypes.Decimal {
564-
parentDecimal, ok := parentType.(sql.DecimalType)
565-
if ok && childDecimal.Scale() != parentDecimal.Scale() {
566-
return true
567-
}
568-
}
569-
}
570-
}
557+
childDecimal, childOk := reference.RowMapper.SourceSch[childColIdx].Type.(sql.DecimalType)
558+
parentDecimal, parentOk := indexColumnTypes[i].Type.(sql.DecimalType)
559+
if childOk && parentOk && childDecimal.Scale() != parentDecimal.Scale() {
560+
return true
571561
}
572562
}
573563
return false
574564
}
575565

576-
577-
578566
// CheckTable checks that every row in the table has an index entry in the referenced table.
579567
func (reference *ForeignKeyReferenceHandler) CheckTable(ctx *sql.Context, tbl sql.ForeignKeyTable) error {
580568
partIter, err := tbl.Partitions(ctx)
@@ -633,16 +621,6 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
633621

634622
targetType := mapper.SourceSch[rowPos].Type
635623

636-
// Special handling for DECIMAL types: check if this is a foreign key reference
637-
// with different precision/scale. If so, we need to be strict like MySQL
638-
if targetType.Type() == sqltypes.Decimal && refCheck {
639-
// For DECIMAL foreign key lookups, we need to ensure exact type matching
640-
// This is a simplified approach - we'll return an empty iterator for now
641-
// to prevent matches when precision/scale differs
642-
// TODO: This should be refined to only block when types actually differ
643-
// For now, let's continue with normal processing and handle it later
644-
}
645-
646624
// Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
647625
if mapper.TargetTypeConversions != nil && mapper.TargetTypeConversions[rowPos] != nil {
648626
var err error
@@ -766,17 +744,6 @@ func GetForeignKeyTypeConversions(
766744
return nil, nil
767745
}
768746

769-
// Special handling for DECIMAL types: when precision/scale differs,
770-
// don't allow type conversion to ensure strict constraint checking
771-
if childType.Type() == sqltypes.Decimal && parentType.Type() == sqltypes.Decimal {
772-
// For DECIMAL foreign keys with different precision/scale, we don't allow
773-
// automatic type conversion. This ensures constraint checking matches MySQL's
774-
// strict behavior where 78.9 (4,1) != 78.90 (4,2)
775-
// Note: childType.Equals(parentType) already returned false above, so we know they differ
776-
// However, since DECIMAL is not ExtendedType, this logic won't be triggered
777-
// The actual constraint validation is handled in CheckReference
778-
return nil, nil
779-
}
780747

781748
fromType := childExtendedType
782749
toType := parentExtendedType

0 commit comments

Comments
 (0)