Skip to content

Commit 6a5e2ee

Browse files
committed
Bug fix: need to swap the direction of transform in some cases. Also some naming refactors
1 parent f645d60 commit 6a5e2ee

File tree

3 files changed

+69
-56
lines changed

3 files changed

+69
-56
lines changed

sql/analyzer/apply_foreign_keys.go

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
274274
return nil, err
275275
}
276276

277-
typeTransforms, err := plan.GetChildParentTypeTransforms(parentTbl.Schema(), tblSch, fk)
277+
typeConversions, err := plan.GetForeignKeyTypeConversions(parentTbl.Schema(), tblSch, fk, plan.ChildToParent)
278278
if err != nil {
279279
return nil, err
280280
}
@@ -290,12 +290,12 @@ func getForeignKeyReferences(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
290290
ForeignKey: fk,
291291
SelfCols: selfCols,
292292
RowMapper: plan.ForeignKeyRowMapper{
293-
Index: parentIndex,
294-
Updater: parentUpdater,
295-
SourceSch: tblSch,
296-
TargetTypeTransforms: typeTransforms,
297-
IndexPositions: indexPositions,
298-
AppendTypes: appendTypes,
293+
Index: parentIndex,
294+
Updater: parentUpdater,
295+
SourceSch: tblSch,
296+
TargetTypeConversions: typeConversions,
297+
IndexPositions: indexPositions,
298+
AppendTypes: appendTypes,
299299
},
300300
}
301301
}
@@ -386,7 +386,7 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
386386
return nil, err
387387
}
388388

389-
typeTransforms, err := plan.GetChildParentTypeTransforms(tblSch, childTblSch, fk)
389+
typeConversions, err := plan.GetForeignKeyTypeConversions(tblSch, childTblSch, fk, plan.ParentToChild)
390390
if err != nil {
391391
return nil, err
392392
}
@@ -414,12 +414,12 @@ func getForeignKeyRefActions(ctx *sql.Context, a *Analyzer, tbl sql.ForeignKeyTa
414414
}
415415
fkEditor.RefActions[i] = plan.ForeignKeyRefActionData{
416416
RowMapper: &plan.ForeignKeyRowMapper{
417-
Index: childIndex,
418-
Updater: childUpdater,
419-
SourceSch: tblSch,
420-
TargetTypeTransforms: typeTransforms,
421-
IndexPositions: indexPositions,
422-
AppendTypes: appendTypes,
417+
Index: childIndex,
418+
Updater: childUpdater,
419+
SourceSch: tblSch,
420+
TargetTypeConversions: typeConversions,
421+
IndexPositions: indexPositions,
422+
AppendTypes: appendTypes,
423423
},
424424
Editor: childEditor,
425425
ForeignKey: fk,

sql/plan/alter_foreign_key.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ func ResolveForeignKey(ctx *sql.Context, tbl sql.ForeignKeyTable, refTbl sql.For
200200
}
201201
}
202202

203-
typeTransforms, err := GetChildParentTypeTransforms(refTbl.Schema(), tbl.Schema(), fkDef)
203+
typeConversions, err := GetForeignKeyTypeConversions(refTbl.Schema(), tbl.Schema(), fkDef, ChildToParent)
204204
if err != nil {
205205
return err
206206
}
@@ -209,12 +209,12 @@ func ResolveForeignKey(ctx *sql.Context, tbl sql.ForeignKeyTable, refTbl sql.For
209209
ForeignKey: fkDef,
210210
SelfCols: selfCols,
211211
RowMapper: ForeignKeyRowMapper{
212-
Index: refTblIndex,
213-
Updater: refTbl.GetForeignKeyEditor(ctx),
214-
SourceSch: tbl.Schema(),
215-
TargetTypeTransforms: typeTransforms,
216-
IndexPositions: indexPositions,
217-
AppendTypes: appendTypes,
212+
Index: refTblIndex,
213+
Updater: refTbl.GetForeignKeyEditor(ctx),
214+
SourceSch: tbl.Schema(),
215+
TargetTypeConversions: typeConversions,
216+
IndexPositions: indexPositions,
217+
AppendTypes: appendTypes,
218218
},
219219
}
220220

sql/plan/foreign_key_editor.go

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,9 @@ type ForeignKeyRowMapper struct {
457457
Index sql.Index
458458
Updater sql.ForeignKeyEditor
459459
SourceSch sql.Schema
460-
// TargetTypeTransforms are a set of functions to transform the value in the table to the corresponding value in the
460+
// TargetTypeConversions are a set of functions to transform the value in the table to the corresponding value in the
461461
// other table. This is required when the types of the two tables are compatible but different (e.g. INT and BIGINT).
462-
TargetTypeTransforms []func(ctx *sql.Context, val any) (sql.Type, any, error)
462+
TargetTypeConversions []ForeignKeyTypeConversionFn
463463
// IndexPositions hold the mapping between an index's column position and the source row's column position. Given
464464
// an index (x1, x2) and a source row (y1, y2, y3) and the relation (x1->y3, x2->y1), this slice would contain
465465
// [2, 0]. The first index column "x1" maps to the third source column "y3" (so position 2 since it's zero-based),
@@ -489,9 +489,9 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
489489

490490
targetType := mapper.SourceSch[rowPos].Type
491491
// Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
492-
if mapper.TargetTypeTransforms != nil && mapper.TargetTypeTransforms[rowPos] != nil {
492+
if mapper.TargetTypeConversions != nil && mapper.TargetTypeConversions[rowPos] != nil {
493493
var err error
494-
targetType, rowVal, err = mapper.TargetTypeTransforms[rowPos](ctx, rowVal)
494+
targetType, rowVal, err = mapper.TargetTypeConversions[rowPos](ctx, rowVal)
495495
// TODO: possible for this to fail without error, which means the value cannot be found in the other table
496496
if err != nil {
497497
return nil, err
@@ -537,26 +537,18 @@ func (mapper *ForeignKeyRowMapper) GetKeyString(row sql.Row) string {
537537

538538
// GetChildParentMapping returns a mapping from the foreign key columns of a child schema to the parent schema.
539539
func GetChildParentMapping(parentSch sql.Schema, childSch sql.Schema, fkDef sql.ForeignKeyConstraint) (ChildParentMapping, error) {
540-
parentMap := make(map[string]int)
541-
for i, col := range parentSch {
542-
parentMap[strings.ToLower(col.Name)] = i
543-
}
544-
childMap := make(map[string]int)
545-
for i, col := range childSch {
546-
childMap[strings.ToLower(col.Name)] = i
547-
}
548540
mapping := make(ChildParentMapping, len(childSch))
549541
for i := range mapping {
550542
mapping[i] = -1
551543
}
552544
for i := range fkDef.Columns {
553-
childIndex, ok := childMap[strings.ToLower(fkDef.Columns[i])]
554-
if !ok {
545+
childIndex := childSch.IndexOfColName(fkDef.Columns[i])
546+
if childIndex < 0 {
555547
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on table `%s` but it could not be found",
556548
fkDef.Name, fkDef.Columns[i], fkDef.Table)
557549
}
558-
parentIndex, ok := parentMap[strings.ToLower(fkDef.ParentColumns[i])]
559-
if !ok {
550+
parentIndex := parentSch.IndexOfColName(fkDef.ParentColumns[i])
551+
if parentIndex < 0 {
560552
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found",
561553
fkDef.Name, fkDef.ParentColumns[i], fkDef.ParentTable)
562554
}
@@ -565,28 +557,37 @@ func GetChildParentMapping(parentSch sql.Schema, childSch sql.Schema, fkDef sql.
565557
return mapping, nil
566558
}
567559

568-
// GetChildParentTypeTransforms returs a set of functions to transform the value in the child table to the
569-
// corresponding type in the parent table, if necessary
570-
func GetChildParentTypeTransforms(parentSch sql.Schema, childSch sql.Schema, fkDef sql.ForeignKeyConstraint) ([]func(ctx *sql.Context, val any) (sql.Type, any, error), error) {
571-
parentMap := make(map[string]int)
572-
for i, col := range parentSch {
573-
parentMap[strings.ToLower(col.Name)] = i
574-
}
575-
childMap := make(map[string]int)
576-
for i, col := range childSch {
577-
childMap[strings.ToLower(col.Name)] = i
578-
}
560+
// ForeignKeyTypeConversionDirection specifies whether a child column type is being converted to its parent type for
561+
// constraint enforcement, or vice versa.
562+
type ForeignKeyTypeConversionDirection byte
563+
const (
564+
ChildToParent ForeignKeyTypeConversionDirection = iota
565+
ParentToChild
566+
)
579567

580-
var mapping []func(*sql.Context, any) (sql.Type, any, error)
568+
// ForeignKeyTypeConversionFn is a function that transforms a value from one type to another for foreign key constraint
569+
// enforcement. The target type is returned along with the transformed value, or an error if the transformation fails.
570+
type ForeignKeyTypeConversionFn func(ctx *sql.Context, val any) (sql.Type, any, error)
571+
572+
// GetForeignKeyTypeConversions returns a set of functions to convert a type in a one foreign key column table to the
573+
// type in the corresponding table. Specify the schema of both child and parent tables, as well as whether the
574+
// transformation is from child to parent or vice versa.
575+
func GetForeignKeyTypeConversions(
576+
parentSch sql.Schema,
577+
childSch sql.Schema,
578+
fkDef sql.ForeignKeyConstraint,
579+
direction ForeignKeyTypeConversionDirection,
580+
) ([]ForeignKeyTypeConversionFn, error) {
581+
var mapping []ForeignKeyTypeConversionFn
581582

582583
for i := range fkDef.Columns {
583-
childIndex, ok := childMap[strings.ToLower(fkDef.Columns[i])]
584-
if !ok {
584+
childIndex := childSch.IndexOfColName(fkDef.Columns[i])
585+
if childIndex < 0 {
585586
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on table `%s` but it could not be found",
586587
fkDef.Name, fkDef.Columns[i], fkDef.Table)
587588
}
588-
parentIndex, ok := parentMap[strings.ToLower(fkDef.ParentColumns[i])]
589-
if !ok {
589+
parentIndex := parentSch.IndexOfColName(fkDef.ParentColumns[i])
590+
if parentIndex < 0 {
590591
return nil, fmt.Errorf("foreign key `%s` refers to column `%s` on referenced table `%s` but it could not be found",
591592
fkDef.Name, fkDef.ParentColumns[i], fkDef.ParentTable)
592593
}
@@ -602,12 +603,24 @@ func GetChildParentTypeTransforms(parentSch sql.Schema, childSch sql.Schema, fkD
602603

603604
if !childType.Equals(parentType) {
604605
parentExtendedType, _ := parentType.(types.ExtendedType)
606+
if parentExtendedType == nil {
607+
// this should be impossible (child and parent should both be extended types), but just in case
608+
return nil, nil
609+
}
610+
611+
fromType := childExtendedType
612+
toType := parentExtendedType
613+
if direction == ParentToChild {
614+
fromType = parentExtendedType
615+
toType = childExtendedType
616+
}
617+
605618
if mapping == nil {
606-
mapping = make([]func(*sql.Context, any) (sql.Type, any, error), len(childSch))
619+
mapping = make([]ForeignKeyTypeConversionFn, len(childSch))
607620
}
608621
mapping[childIndex] = func(ctx *sql.Context, val any) (sql.Type, any, error) {
609-
convertedVal, err := parentExtendedType.ConvertToType(ctx, childExtendedType, val)
610-
return parentExtendedType, convertedVal, err
622+
convertedVal, err := toType.ConvertToType(ctx, fromType, val)
623+
return toType, convertedVal, err
611624
}
612625
}
613626
}

0 commit comments

Comments
 (0)