Skip to content

Commit 73b3865

Browse files
authored
Merge pull request #2888 from dolthub/zachmu/foreign-key-types
Allow type conversions for foreign key checks
2 parents e6186e6 + 92db3c8 commit 73b3865

File tree

5 files changed

+138
-36
lines changed

5 files changed

+138
-36
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
273273
if err != nil {
274274
return nil, err
275275
}
276+
277+
typeConversions, err := plan.GetForeignKeyTypeConversions(parentTbl.Schema(), tblSch, fk, plan.ChildToParent)
278+
if err != nil {
279+
return nil, err
280+
}
281+
276282
var selfCols map[string]int
277283
if fk.IsSelfReferential() {
278284
selfCols = make(map[string]int)
@@ -284,11 +290,12 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
284290
ForeignKey: fk,
285291
SelfCols: selfCols,
286292
RowMapper: plan.ForeignKeyRowMapper{
287-
Index: parentIndex,
288-
Updater: parentUpdater,
289-
SourceSch: tblSch,
290-
IndexPositions: indexPositions,
291-
AppendTypes: appendTypes,
293+
Index: parentIndex,
294+
Updater: parentUpdater,
295+
SourceSch: tblSch,
296+
TargetTypeConversions: typeConversions,
297+
IndexPositions: indexPositions,
298+
AppendTypes: appendTypes,
292299
},
293300
}
294301
}
@@ -379,6 +386,11 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
379386
return nil, err
380387
}
381388

389+
typeConversions, err := plan.GetForeignKeyTypeConversions(tblSch, childTblSch, fk, plan.ParentToChild)
390+
if err != nil {
391+
return nil, err
392+
}
393+
382394
childEditor, err := getForeignKeyEditor(ctx, a, childTbl, cache, fkChain.AddForeignKey(fk.Name), checkRows)
383395
if err != nil {
384396
return nil, err
@@ -402,11 +414,12 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
402414
}
403415
fkEditor.RefActions[i] = plan.ForeignKeyRefActionData{
404416
RowMapper: &plan.ForeignKeyRowMapper{
405-
Index: childIndex,
406-
Updater: childUpdater,
407-
SourceSch: tblSch,
408-
IndexPositions: indexPositions,
409-
AppendTypes: appendTypes,
417+
Index: childIndex,
418+
Updater: childUpdater,
419+
SourceSch: tblSch,
420+
TargetTypeConversions: typeConversions,
421+
IndexPositions: indexPositions,
422+
AppendTypes: appendTypes,
410423
},
411424
Editor: childEditor,
412425
ForeignKey: fk,

sql/plan/alter_foreign_key.go

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,22 @@ func ResolveForeignKey(ctx *sql.Context, tbl sql.ForeignKeyTable, refTbl sql.For
199199
selfCols[strings.ToLower(col.Name)] = i
200200
}
201201
}
202+
203+
typeConversions, err := GetForeignKeyTypeConversions(refTbl.Schema(), tbl.Schema(), fkDef, ChildToParent)
204+
if err != nil {
205+
return err
206+
}
207+
202208
reference := &ForeignKeyReferenceHandler{
203209
ForeignKey: fkDef,
204210
SelfCols: selfCols,
205211
RowMapper: ForeignKeyRowMapper{
206-
Index: refTblIndex,
207-
Updater: refTbl.GetForeignKeyEditor(ctx),
208-
SourceSch: tbl.Schema(),
209-
IndexPositions: indexPositions,
210-
AppendTypes: appendTypes,
212+
Index: refTblIndex,
213+
Updater: refTbl.GetForeignKeyEditor(ctx),
214+
SourceSch: tbl.Schema(),
215+
TargetTypeConversions: typeConversions,
216+
IndexPositions: indexPositions,
217+
AppendTypes: appendTypes,
211218
},
212219
}
213220

@@ -531,21 +538,17 @@ func FindForeignKeyColMapping(
531538
localRowPos, ok := localSchPositionMap[colName]
532539
if !ok {
533540
// Will happen if a column is renamed that is referenced by a foreign key
534-
//TODO: enforce that renaming a column referenced by a foreign key updates that foreign key
541+
// TODO: enforce that renaming a column referenced by a foreign key updates that foreign key
535542
return nil, nil, fmt.Errorf("column `%s` in foreign key `%s` cannot be found",
536543
colName, fkName)
537544
}
538-
expectedType := localSchTypeMap[colName]
539545
destFkCol := destTblName + "." + destFKCols[fkIdx]
540546
indexPos, ok := indexColMap[destFkCol]
541547
if !ok {
542548
// Same as above, renaming a referenced column would cause this error
543549
return nil, nil, fmt.Errorf("index column `%s` in foreign key `%s` cannot be found",
544550
destFKCols[fkIdx], fkName)
545551
}
546-
if !foreignKeyComparableTypes(ctx, indexTypeMap[destFkCol], expectedType) {
547-
return nil, nil, sql.ErrForeignKeyColumnTypeMismatch.New(colName, destFkCol)
548-
}
549552
indexPositions[indexPos] = localRowPos
550553
}
551554
return indexPositions, appendTypes, nil

sql/plan/foreign_key_editor.go

Lines changed: 98 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"strings"
2121

2222
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/types"
2324
)
2425

2526
// ChildParentMapping is a mapping from the foreign key columns of a child schema to the parent schema. The position
@@ -227,7 +228,7 @@ func (fkEditor *ForeignKeyEditor) OnUpdateSetNull(ctx *sql.Context, refActionDat
227228

228229
// Delete handles both the standard DELETE statement and propagated referential actions from a parent table's ON DELETE.
229230
func (fkEditor *ForeignKeyEditor) Delete(ctx *sql.Context, row sql.Row, depth int) error {
230-
//TODO: may need to process some cascades after the update to avoid recursive violations, write some tests on this
231+
// TODO: may need to process some cascades after the update to avoid recursive violations, write some tests on this
231232
for _, refActionData := range fkEditor.RefActions {
232233
switch refActionData.ForeignKey.OnDelete {
233234
default: // RESTRICT and friends
@@ -425,6 +426,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
425426
return nil
426427
}
427428
}
429+
428430
return sql.ErrForeignKeyChildViolation.New(reference.ForeignKey.Name, reference.ForeignKey.Table,
429431
reference.ForeignKey.ParentTable, reference.RowMapper.GetKeyString(row))
430432
}
@@ -455,6 +457,9 @@ type ForeignKeyRowMapper struct {
455457
Index sql.Index
456458
Updater sql.ForeignKeyEditor
457459
SourceSch sql.Schema
460+
// TargetTypeConversions are a set of functions to transform the value in the table to the corresponding value in the
461+
// other table. This is required when the types of the two tables are compatible but different (e.g. INT and BIGINT).
462+
TargetTypeConversions []ForeignKeyTypeConversionFn
458463
// IndexPositions hold the mapping between an index's column position and the source row's column position. Given
459464
// an index (x1, x2) and a source row (y1, y2, y3) and the relation (x1->y3, x2->y1), this slice would contain
460465
// [2, 0]. The first index column "x1" maps to the third source column "y3" (so position 2 since it's zero-based),
@@ -481,7 +486,21 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
481486
if rowVal == nil {
482487
return sql.RowsToRowIter(), nil
483488
}
484-
rang[rangPosition] = sql.ClosedRangeColumnExpr(rowVal, rowVal, mapper.SourceSch[rowPos].Type)
489+
490+
targetType := mapper.SourceSch[rowPos].Type
491+
// Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
492+
if mapper.TargetTypeConversions != nil && mapper.TargetTypeConversions[rowPos] != nil {
493+
var err error
494+
targetType, rowVal, err = mapper.TargetTypeConversions[rowPos](ctx, rowVal)
495+
// An error means the type conversion failed, which typically means there's no way to convert the value given to
496+
// the target value because of e.g. range constraints (trying to assign an INT to a TINYINT column). We treat
497+
// this as an empty result for this iterator, since this value cannot possibly be present in the other table.
498+
if err != nil {
499+
return sql.RowsToRowIter(), nil
500+
}
501+
}
502+
503+
rang[rangPosition] = sql.ClosedRangeColumnExpr(rowVal, rowVal, targetType)
485504
}
486505
for i, appendType := range mapper.AppendTypes {
487506
rang[i+len(mapper.IndexPositions)] = sql.AllRangeColumnExpr(appendType)
@@ -490,7 +509,7 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
490509
if !mapper.Index.CanSupport(rang) {
491510
return nil, ErrInvalidLookupForIndexedTable.New(rang.DebugString())
492511
}
493-
//TODO: profile this, may need to redesign this or add a fast path
512+
// TODO: profile this, may need to redesign this or add a fast path
494513
lookup := sql.IndexLookup{Ranges: sql.MySQLRangeCollection{rang}, Index: mapper.Index}
495514

496515
editorData := mapper.Updater.IndexedAccess(lookup)
@@ -520,30 +539,94 @@ func (mapper *ForeignKeyRowMapper) GetKeyString(row sql.Row) string {
520539

521540
// GetChildParentMapping returns a mapping from the foreign key columns of a child schema to the parent schema.
522541
func GetChildParentMapping(parentSch sql.Schema, childSch sql.Schema, fkDef sql.ForeignKeyConstraint) (ChildParentMapping, error) {
523-
parentMap := make(map[string]int)
524-
for i, col := range parentSch {
525-
parentMap[strings.ToLower(col.Name)] = i
526-
}
527-
childMap := make(map[string]int)
528-
for i, col := range childSch {
529-
childMap[strings.ToLower(col.Name)] = i
530-
}
531542
mapping := make(ChildParentMapping, len(childSch))
532543
for i := range mapping {
533544
mapping[i] = -1
534545
}
535546
for i := range fkDef.Columns {
536-
childIndex, ok := childMap[strings.ToLower(fkDef.Columns[i])]
537-
if !ok {
547+
childIndex := childSch.IndexOfColName(fkDef.Columns[i])
548+
if childIndex < 0 {
538549
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on table `%s` but it could not be found",
539550
fkDef.Name, fkDef.Columns[i], fkDef.Table)
540551
}
541-
parentIndex, ok := parentMap[strings.ToLower(fkDef.ParentColumns[i])]
542-
if !ok {
552+
parentIndex := parentSch.IndexOfColName(fkDef.ParentColumns[i])
553+
if parentIndex < 0 {
543554
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found",
544555
fkDef.Name, fkDef.ParentColumns[i], fkDef.ParentTable)
545556
}
546557
mapping[childIndex] = parentIndex
547558
}
548559
return mapping, nil
549560
}
561+
562+
// ForeignKeyTypeConversionDirection specifies whether a child column type is being converted to its parent type for
563+
// constraint enforcement, or vice versa.
564+
type ForeignKeyTypeConversionDirection byte
565+
566+
const (
567+
ChildToParent ForeignKeyTypeConversionDirection = iota
568+
ParentToChild
569+
)
570+
571+
// ForeignKeyTypeConversionFn is a function that transforms a value from one type to another for foreign key constraint
572+
// enforcement. The target type is returned along with the transformed value, or an error if the transformation fails.
573+
type ForeignKeyTypeConversionFn func(ctx *sql.Context, val any) (sql.Type, any, error)
574+
575+
// GetForeignKeyTypeConversions returns a set of functions to convert a type in a one foreign key column table to the
576+
// type in the corresponding table. Specify the schema of both child and parent tables, as well as whether the
577+
// transformation is from child to parent or vice versa.
578+
func GetForeignKeyTypeConversions(
579+
parentSch sql.Schema,
580+
childSch sql.Schema,
581+
fkDef sql.ForeignKeyConstraint,
582+
direction ForeignKeyTypeConversionDirection,
583+
) ([]ForeignKeyTypeConversionFn, error) {
584+
var convFns []ForeignKeyTypeConversionFn
585+
586+
for i := range fkDef.Columns {
587+
childIndex := childSch.IndexOfColName(fkDef.Columns[i])
588+
if childIndex < 0 {
589+
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on table `%s` but it could not be found",
590+
fkDef.Name, fkDef.Columns[i], fkDef.Table)
591+
}
592+
parentIndex := parentSch.IndexOfColName(fkDef.ParentColumns[i])
593+
if parentIndex < 0 {
594+
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found",
595+
fkDef.Name, fkDef.ParentColumns[i], fkDef.ParentTable)
596+
}
597+
598+
childType := childSch[childIndex].Type
599+
parentType := parentSch[parentIndex].Type
600+
601+
childExtendedType, ok := childType.(types.ExtendedType)
602+
// if even one of the types is not an extended type, then we can't transform any values
603+
if !ok {
604+
return nil, nil
605+
}
606+
607+
if !childType.Equals(parentType) {
608+
parentExtendedType, ok := parentType.(types.ExtendedType)
609+
if !ok {
610+
// this should be impossible (child and parent should both be extended types), but just in case
611+
return nil, nil
612+
}
613+
614+
fromType := childExtendedType
615+
toType := parentExtendedType
616+
if direction == ParentToChild {
617+
fromType = parentExtendedType
618+
toType = childExtendedType
619+
}
620+
621+
if convFns == nil {
622+
convFns = make([]ForeignKeyTypeConversionFn, len(childSch))
623+
}
624+
convFns[childIndex] = func(ctx *sql.Context, val any) (sql.Type, any, error) {
625+
convertedVal, err := toType.ConvertToType(ctx, fromType, val)
626+
return toType, convertedVal, err
627+
}
628+
}
629+
}
630+
631+
return convFns, nil
632+
}

sql/tables.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ type ForeignKeyTable interface {
192192
CreateIndexForForeignKey(ctx *Context, indexDef IndexDef) error
193193
// GetDeclaredForeignKeys returns the foreign key constraints that are declared by this table.
194194
GetDeclaredForeignKeys(ctx *Context) ([]ForeignKeyConstraint, error)
195-
// GetReferencedForeignKeys returns the foreign key constraints that are referenced by this table.
195+
// GetReferencedForeignKeys returns the foreign key constraints that reference this table as the parent
196196
GetReferencedForeignKeys(ctx *Context) ([]ForeignKeyConstraint, error)
197197
// AddForeignKey adds the given foreign key constraint to the table. Returns an error if the foreign key name
198198
// already exists on any other table within the database.

sql/types/extended.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ type ExtendedType interface {
3636
FormatValue(val any) (string, error)
3737
// MaxSerializedWidth returns the maximum size that the serialized value may represent.
3838
MaxSerializedWidth() ExtendedTypeSerializedWidth
39+
// ConvertToType converts the given value of the given type to this type, or returns an error if
40+
// no conversion is possible.
41+
ConvertToType(ctx *sql.Context, typ ExtendedType, val any) (any, error)
3942
}
4043

4144
type ExtendedTypeSerializedWidth uint8

0 commit comments

Comments
 (0)