Skip to content

Commit af9f681

Browse files
committed
modified GeneralizeType to match rules for Case statement, need to test
1 parent 6ecae23 commit af9f681

File tree

6 files changed

+171
-77
lines changed

6 files changed

+171
-77
lines changed

enginetest/queries/script_queries.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8727,7 +8727,8 @@ where
87278727
},
87288728
},
87298729
{
8730-
Query: "select if(t0.c0 = 1, t0.c0, 128) as ref0 from t0",
8730+
Query: "select if(t0.c0 = 1, t0.c0, 128) as ref0 from t0",
8731+
Expected: []sql.Row{{128}},
87318732
},
87328733
},
87338734
},

sql/expression/case.go

Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -43,71 +43,14 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression
4343
return &Case{expr, branches, elseExpr}
4444
}
4545

46-
// From the description of operator typing here:
47-
// https://dev.mysql.com/doc/refman/8.0/en/flow-control-functions.html#operator_case
48-
func combinedCaseBranchType(left, right sql.Type) sql.Type {
49-
if left == types.Null {
50-
return right
51-
}
52-
if right == types.Null {
53-
return left
54-
}
55-
56-
// Our current implementation of StringType.Convert(enum), does not match MySQL's behavior.
57-
// So, we make sure to return Enums in this particular case.
58-
// More details: https://github.com/dolthub/dolt/issues/8598
59-
if types.IsEnum(left) && types.IsEnum(right) {
60-
return right
61-
}
62-
if types.IsSet(left) && types.IsSet(right) {
63-
return right
64-
}
65-
if types.IsTextOnly(left) && types.IsTextOnly(right) {
66-
return types.LongText
67-
}
68-
if types.IsTextBlob(left) && types.IsTextBlob(right) {
69-
return types.LongBlob
70-
}
71-
if types.IsTime(left) && types.IsTime(right) {
72-
if left == right {
73-
return left
74-
}
75-
return types.DatetimeMaxPrecision
76-
}
77-
if types.IsNumber(left) && types.IsNumber(right) {
78-
if left == types.Float64 || right == types.Float64 {
79-
return types.Float64
80-
}
81-
if left == types.Float32 || right == types.Float32 {
82-
return types.Float32
83-
}
84-
if types.IsDecimal(left) || types.IsDecimal(right) {
85-
return types.MustCreateDecimalType(65, 10)
86-
}
87-
if left == types.Uint64 && types.IsSigned(right) ||
88-
right == types.Uint64 && types.IsSigned(left) {
89-
return types.MustCreateDecimalType(65, 10)
90-
}
91-
if !types.IsSigned(left) && !types.IsSigned(right) {
92-
return types.Uint64
93-
} else {
94-
return types.Int64
95-
}
96-
}
97-
if types.IsJSON(left) && types.IsJSON(right) {
98-
return types.JSON
99-
}
100-
return types.LongText
101-
}
102-
10346
// Type implements the sql.Expression interface.
10447
func (c *Case) Type() sql.Type {
10548
curr := types.Null
10649
for _, b := range c.Branches {
107-
curr = combinedCaseBranchType(curr, b.Value.Type())
50+
curr = types.GeneralizeTypes(curr, b.Value.Type())
10851
}
10952
if c.Else != nil {
110-
curr = combinedCaseBranchType(curr, c.Else.Type())
53+
curr = types.GeneralizeTypes(curr, c.Else.Type())
11154
}
11255
return curr
11356
}

sql/expression/function/if.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
8989
return nil, err
9090
}
9191
}
92-
eval, _, err = f.Type().Convert(ctx, eval)
92+
if ret, _, err := f.Type().Convert(ctx, eval); err == nil {
93+
return ret, nil
94+
}
9395
return eval, err
9496
}
9597

sql/expression/function/ifnull.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,26 @@ func (f *IfNull) Description() string {
5252

5353
// Eval implements the Expression interface.
5454
func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
55+
t := f.Type()
56+
5557
left, err := f.LeftChild.Eval(ctx, row)
5658
if err != nil {
5759
return nil, err
5860
}
5961
if left != nil {
60-
left, _, err = f.Type().Convert(ctx, left)
62+
if ret, _, err := t.Convert(ctx, left); err == nil {
63+
return ret, nil
64+
}
6165
return left, err
6266
}
6367

6468
right, err := f.RightChild.Eval(ctx, row)
6569
if err != nil {
6670
return nil, err
6771
}
68-
right, _, err = f.Type().Convert(ctx, right)
72+
if ret, _, err := t.Convert(ctx, right); err == nil {
73+
return ret, nil
74+
}
6975
return right, err
7076
}
7177

sql/types/conversion.go

Lines changed: 154 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -555,24 +555,166 @@ func TypesEqual(a, b sql.Type) bool {
555555
}
556556
}
557557

558+
// generalizeNumberTypes assumes both inputs return true for IsNumber
559+
func generalizeNumberTypes(a, b sql.Type) sql.Type {
560+
if a == Float64 || b == Float64 {
561+
return Float64
562+
}
563+
if a == Float32 || b == Float32 {
564+
return Float32
565+
}
566+
567+
if IsDecimal(a) || IsDecimal(b) {
568+
// TODO: match precision and scale to that of the decimal type, check if defines column
569+
return MustCreateDecimalType(DecimalTypeMaxPrecision, DecimalTypeMaxScale)
570+
}
571+
572+
aIsSigned := IsSigned(a)
573+
bIsSigned := IsSigned(b)
574+
575+
if a == Uint64 || b == Uint64 {
576+
if aIsSigned || bIsSigned {
577+
return MustCreateDecimalType(DecimalTypeMaxPrecision, 0)
578+
}
579+
return Uint64
580+
}
581+
582+
if a == Int64 || b == Int64 {
583+
return Int64
584+
}
585+
586+
if a == Uint32 || b == Uint32 {
587+
if aIsSigned || bIsSigned {
588+
return Int64
589+
}
590+
return Uint32
591+
}
592+
593+
if a == Int32 || b == Int32 {
594+
return Int32
595+
}
596+
597+
if a == Uint24 || b == Uint24 {
598+
if aIsSigned || bIsSigned {
599+
return Int32
600+
}
601+
}
602+
603+
if a == Int24 || b == Int24 {
604+
return Int24
605+
}
606+
607+
if a == Uint16 || b == Uint16 {
608+
if aIsSigned || bIsSigned {
609+
return Int24
610+
}
611+
return Uint16
612+
}
613+
614+
if a == Int16 || b == Int16 {
615+
return Int16
616+
}
617+
618+
if a == Uint8 || b == Uint8 {
619+
if aIsSigned || bIsSigned {
620+
return Int16
621+
}
622+
return Uint8
623+
}
624+
625+
if a == Int8 || b == Int8 {
626+
return Int8
627+
}
628+
629+
if IsBoolean(a) && IsBoolean(b) {
630+
return Boolean
631+
}
632+
633+
return Int64
634+
}
635+
558636
// GeneralizeTypes returns the more "general" of two types as defined by
559-
// https://dev.mysql.com/doc/refman/8.4/en/flow-control-functions.html#function_if and
560-
// https://dev.mysql.com/doc/refman/8.4/en/flow-control-functions.html#function_ifnull
561-
// TODO: Currently returns the most general type via Promote(). Update to match MySQL (pick the more general of the two
562-
// given types)
637+
// https://dev.mysql.com/doc/refman/8.4/en/flow-control-functions.html
638+
// TODO: Create and handle "Illegal mix of collations" error
563639
func GeneralizeTypes(a, b sql.Type) sql.Type {
564-
if IsText(a) || IsText(b) {
565-
// TODO: handle case-sensitive strings
566-
return LongText
640+
if a == Null {
641+
return b
642+
}
643+
if b == Null {
644+
return a
567645
}
568646

569-
if IsFloat(a) || IsFloat(b) {
570-
return Float64
647+
if IsJSON(a) && IsJSON(b) {
648+
return JSON
571649
}
572650

573-
if a == Null {
574-
return b.Promote()
651+
if IsGeometry(a) && IsGeometry(b) {
652+
return a
575653
}
576654

577-
return a.Promote()
655+
if IsEnum(a) && IsEnum(b) {
656+
return a
657+
}
658+
659+
if IsSet(a) && IsSet(b) {
660+
return a
661+
}
662+
663+
aIsTimespan := IsTimespan(a)
664+
bIsTimespan := IsTimespan(b)
665+
if aIsTimespan && bIsTimespan {
666+
return a
667+
}
668+
if (IsTime(a) || aIsTimespan) && (IsTime(b) || bIsTimespan) {
669+
if IsDateType(a) && IsDateType(b) {
670+
return Date
671+
}
672+
if IsTimestampType(a) && IsTimestampType(b) {
673+
// TODO: match precision to max precision of the two timestamps
674+
return TimestampMaxPrecision
675+
}
676+
// TODO: match precision to max precision of the two time types
677+
return DatetimeMaxPrecision
678+
}
679+
680+
if IsBlobType(a) || IsBlobType(b) {
681+
return Blob
682+
}
683+
684+
aIsBit := IsBit(a)
685+
bIsBit := IsBit(b)
686+
if aIsBit && bIsBit {
687+
// TODO: match max bits to max of max bits between a and b
688+
return a.Promote()
689+
}
690+
if aIsBit {
691+
a = Int64
692+
}
693+
if bIsBit {
694+
b = Int64
695+
}
696+
697+
aIsYear := IsYear(a)
698+
bIsYear := IsYear(b)
699+
if aIsYear && bIsYear {
700+
return a
701+
}
702+
if aIsYear {
703+
a = Int32
704+
}
705+
if bIsYear {
706+
b = Int32
707+
}
708+
709+
if IsNumber(a) && IsNumber(b) {
710+
if svt, ok := a.(sql.SystemVariableType); ok {
711+
a = svt.UnderlyingType()
712+
}
713+
if svt, ok := a.(sql.SystemVariableType); ok {
714+
b = svt.UnderlyingType()
715+
}
716+
return generalizeNumberTypes(a, b)
717+
}
718+
// TODO: decide if we want to make this VarChar to match MySQL, match VarChar length to max of two types
719+
return LongText
578720
}

sql/types/conversion_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ func TestGeneralizeTypes(t *testing.T) {
168168
{Text, Text, LongText},
169169
{Text, Float64, LongText},
170170
{Int64, Text, LongText},
171-
{Float32, Float32, Float64},
171+
{Float32, Float32, Float32},
172172
{Int64, Float64, Float64},
173-
{Int32, Int32, Int64},
173+
{Int32, Int32, Int32},
174174
{Null, Null, Null},
175175
}
176176
for _, test := range tests {

0 commit comments

Comments
 (0)