@@ -493,6 +493,7 @@ func (reference *ForeignKeyReferenceHandler) IsInitialized() bool {
493493}
494494
495495// CheckReference checks that the given row has an index entry in the referenced table.
496+ // Performs MySQL-compatible foreign key constraint validation with type-specific checks.
496497func (reference * ForeignKeyReferenceHandler ) CheckReference (ctx * sql.Context , row sql.Row ) error {
497498 // If even one of the values are NULL then we don't check the parent
498499 for _ , pos := range reference .RowMapper .IndexPositions {
@@ -507,7 +508,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
507508 }
508509 defer rowIter .Close (ctx )
509510
510- _ , err = rowIter .Next (ctx )
511+ parentRow , err : = rowIter .Next (ctx )
511512 if err != nil && err != io .EOF {
512513 // For SET types, conversion failures during foreign key validation should be treated as foreign key violations
513514 if sql .ErrConvertingToSet .Is (err ) || sql .ErrInvalidSetValue .Is (err ) {
@@ -518,12 +519,10 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
518519 }
519520 if err == nil {
520521 // We have a parent row, but check for type-specific validation
521- if validationErr := reference .validateDecimalConstraints (row ); validationErr != nil {
522- return validationErr
523- }
524- if validationErr := reference .validateTimeConstraints (row ); validationErr != nil {
522+ if validationErr := reference .validateColumnTypeConstraints (ctx , row , parentRow ); validationErr != nil {
525523 return validationErr
526524 }
525+
527526 // We have a parent row so throw no error
528527 return nil
529528 }
@@ -551,46 +550,14 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
551550 reference .ForeignKey .ParentTable , reference .RowMapper .GetKeyString (row ))
552551}
553552
554- // validateDecimalConstraints checks that decimal foreign key columns have compatible scales.
555- func (reference * ForeignKeyReferenceHandler ) validateDecimalConstraints (row sql.Row ) error {
556- if reference .RowMapper .Index == nil {
557- return nil
558- }
559- indexColumnTypes := reference .RowMapper .Index .ColumnExpressionTypes ()
560- for parentIdx , parentCol := range indexColumnTypes {
561- if parentIdx >= len (reference .RowMapper .IndexPositions ) {
562- break
563- }
564- parentType := parentCol .Type
565- childColIdx := reference .RowMapper .IndexPositions [parentIdx ]
566- childType := reference .RowMapper .SourceSch [childColIdx ].Type
567- childDecimal , ok := childType .(sql.DecimalType )
568- if ! ok {
569- continue
570- }
571- parentDecimal , ok := parentType .(sql.DecimalType )
572- if ! ok {
573- continue
574- }
575- if childDecimal .Scale () != parentDecimal .Scale () {
576- return sql .ErrForeignKeyChildViolation .New (
577- reference .ForeignKey .Name ,
578- reference .ForeignKey .Table ,
579- reference .ForeignKey .ParentTable ,
580- reference .RowMapper .GetKeyString (row ),
581- )
582- }
583- }
584- return nil
585- }
586553
587- // validateTimeConstraints checks that time-related foreign key columns have exact type and precision matches.
588- // MySQL requires strict matching for time types in foreign keys - even logically equivalent values
589- // like '2001-02-03 12:34:56' vs '2001-02-03 12:34:56.000000' are rejected if precision differs.
590- func (reference * ForeignKeyReferenceHandler ) validateTimeConstraints (row sql.Row ) error {
554+ // validateColumnTypeConstraints validates that column types meet MySQL foreign key requirements.
555+ // Centralizes type validation for decimal scale matching and exact time type precision matching.
556+ func (reference * ForeignKeyReferenceHandler ) validateColumnTypeConstraints (ctx * sql.Context , childRow sql.Row , parentRow sql.Row ) error {
591557 if reference .RowMapper .Index == nil {
592558 return nil
593559 }
560+
594561 indexColumnTypes := reference .RowMapper .Index .ColumnExpressionTypes ()
595562 for parentIdx , parentCol := range indexColumnTypes {
596563 if parentIdx >= len (reference .RowMapper .IndexPositions ) {
@@ -600,27 +567,40 @@ func (reference *ForeignKeyReferenceHandler) validateTimeConstraints(row sql.Row
600567 childColIdx := reference .RowMapper .IndexPositions [parentIdx ]
601568 childType := reference .RowMapper .SourceSch [childColIdx ].Type
602569
603- // Check if both types are time-related
604- isChildTime := types .IsTime (childType ) || types .IsTimespan (childType )
605- isParentTime := types .IsTime (parentType ) || types .IsTimespan (parentType )
606-
607- if ! isChildTime || ! isParentTime {
608- continue
570+ // Check decimal constraints
571+ childDecimal , isChildDecimal := childType .(sql.DecimalType )
572+ parentDecimal , isParentDecimal := parentType .(sql.DecimalType )
573+ if isChildDecimal && isParentDecimal {
574+ if childDecimal .Scale () != parentDecimal .Scale () {
575+ return sql .ErrForeignKeyChildViolation .New (
576+ reference .ForeignKey .Name ,
577+ reference .ForeignKey .Table ,
578+ reference .ForeignKey .ParentTable ,
579+ reference .RowMapper .GetKeyString (childRow ),
580+ )
581+ }
609582 }
610583
611- // MySQL requires exact type matching for time types in foreign key validation
612- if ! childType .Equals (parentType ) {
613- return sql .ErrForeignKeyChildViolation .New (
614- reference .ForeignKey .Name ,
615- reference .ForeignKey .Table ,
616- reference .ForeignKey .ParentTable ,
617- reference .RowMapper .GetKeyString (row ),
618- )
584+ // Check time type constraints - MySQL requires exact type matching
585+ isChildTime := types .IsTime (childType ) || types .IsTimespan (childType )
586+ isParentTime := types .IsTime (parentType ) || types .IsTimespan (parentType )
587+ if isChildTime && isParentTime {
588+ // Different precisions of the same base type (e.g., DATETIME vs DATETIME(6)) are rejected
589+ // Cross-type references (e.g., DATETIME vs TIMESTAMP) are also rejected
590+ if ! childType .Equals (parentType ) {
591+ return sql .ErrForeignKeyChildViolation .New (
592+ reference .ForeignKey .Name ,
593+ reference .ForeignKey .Table ,
594+ reference .ForeignKey .ParentTable ,
595+ reference .RowMapper .GetKeyString (childRow ),
596+ )
597+ }
619598 }
620599 }
621600 return nil
622601}
623602
603+
624604// CheckTable checks that every row in the table has an index entry in the referenced table.
625605func (reference * ForeignKeyReferenceHandler ) CheckTable (ctx * sql.Context , tbl sql.ForeignKeyTable ) error {
626606 partIter , err := tbl .Partitions (ctx )
@@ -678,6 +658,7 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
678658 }
679659
680660 targetType := mapper .SourceSch [rowPos ].Type
661+
681662 // Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
682663 if mapper .TargetTypeConversions != nil && mapper .TargetTypeConversions [rowPos ] != nil {
683664 var err error
0 commit comments