Skip to content

Commit 260b407

Browse files
authored
Properly cast ENUMs to TEXT for CASE and CONVERT statements (#2791)
1 parent 09a7e80 commit 260b407

File tree

6 files changed

+177
-27
lines changed

6 files changed

+177
-27
lines changed

enginetest/queries/script_queries.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7641,14 +7641,14 @@ where
76417641
},
76427642
{
76437643
// https://github.com/dolthub/dolt/issues/8598
7644-
Name: "enum cast to int and string",
7644+
Name: "enum cast to int and string",
7645+
Dialect: "mysql",
76457646
SetUpScript: []string{
76467647
"create table t (i int primary key, e enum('abc', 'def', 'ghi'));",
76477648
"insert into t values (1, 'abc'), (2, 'def'), (3, 'ghi');",
76487649
},
76497650
Assertions: []ScriptTestAssertion{
76507651
{
7651-
Skip: true,
76527652
Query: "select i, cast(e as signed) from t;",
76537653
Expected: []sql.Row{
76547654
{1, 1},
@@ -7657,14 +7657,46 @@ where
76577657
},
76587658
},
76597659
{
7660-
Skip: true,
76617660
Query: "select i, cast(e as char) from t;",
76627661
Expected: []sql.Row{
76637662
{1, "abc"},
76647663
{2, "def"},
76657664
{3, "ghi"},
76667665
},
76677666
},
7667+
{
7668+
Query: "select i, cast(e as binary) from t;",
7669+
Expected: []sql.Row{
7670+
{1, []uint8("abc")},
7671+
{2, []uint8("def")},
7672+
{3, []uint8("ghi")},
7673+
},
7674+
},
7675+
{
7676+
Query: "select case when e = 'abc' then 'abc' when e = 'def' then 123 else e end from t",
7677+
Expected: []sql.Row{
7678+
{"abc"},
7679+
{"123"},
7680+
{"ghi"},
7681+
},
7682+
},
7683+
},
7684+
},
7685+
{
7686+
Name: "enum errors",
7687+
Dialect: "mysql",
7688+
SetUpScript: []string{
7689+
"create table t (i int primary key, e enum('abc', 'def', 'ghi'));",
7690+
},
7691+
Assertions: []ScriptTestAssertion{
7692+
{
7693+
Query: "insert into t values (1, 500)",
7694+
ExpectedErrStr: "value 500 is not valid for this Enum",
7695+
},
7696+
{
7697+
Query: "insert into t values (1, -1)",
7698+
ExpectedErrStr: "value -1 is not valid for this Enum",
7699+
},
76687700
},
76697701
},
76707702
}

sql/expression/enum.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
package expression
15+
16+
import (
17+
"github.com/dolthub/go-mysql-server/sql"
18+
"github.com/dolthub/go-mysql-server/sql/types"
19+
)
20+
21+
// EnumToString is an expression that converts an enum value to a string.
22+
type EnumToString struct {
23+
Enum sql.Expression
24+
}
25+
26+
var _ sql.Expression = (*EnumToString)(nil)
27+
var _ sql.CollationCoercible = (*EnumToString)(nil)
28+
29+
func NewEnumToString(enum sql.Expression) *EnumToString {
30+
return &EnumToString{Enum: enum}
31+
}
32+
33+
// Type implements the sql.Expression interface.
34+
func (e *EnumToString) Type() sql.Type {
35+
return types.Text
36+
}
37+
38+
// CollationCoercibility implements the interface sql.CollationCoercible.
39+
func (e *EnumToString) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
40+
return e.Type().CollationCoercibility(ctx)
41+
}
42+
43+
// IsNullable implements the sql.Expression interface.
44+
func (e *EnumToString) IsNullable() bool {
45+
return e.Enum.IsNullable()
46+
}
47+
48+
// Resolved implements the sql.Expression interface.
49+
func (e *EnumToString) Resolved() bool {
50+
return e.Enum.Resolved()
51+
}
52+
53+
// Children implements the sql.Expression interface.
54+
func (e *EnumToString) Children() []sql.Expression {
55+
return []sql.Expression{e.Enum}
56+
}
57+
58+
// Eval implements the sql.Expression interface.
59+
func (e *EnumToString) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
60+
span, ctx := ctx.Span("expression.EnumToString")
61+
defer span.End()
62+
63+
val, err := e.Enum.Eval(ctx, row)
64+
if err != nil {
65+
return nil, err
66+
}
67+
if val == nil {
68+
return nil, nil
69+
}
70+
71+
enumType := e.Enum.Type().(types.EnumType)
72+
str, _ := enumType.At(int(val.(uint16)))
73+
return str, nil
74+
}
75+
76+
// WithChildren implements the Expression interface.
77+
func (e *EnumToString) WithChildren(children ...sql.Expression) (sql.Expression, error) {
78+
if len(children) != 1 {
79+
return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1)
80+
}
81+
82+
return NewEnumToString(children[0]), nil
83+
}
84+
85+
// String implements the sql.Expression interface.
86+
func (e *EnumToString) String() string {
87+
return e.Enum.String()
88+
}
89+
90+
// DebugString implements the sql.Expression interface.
91+
func (e *EnumToString) DebugString() string {
92+
return sql.DebugString(e.Enum)
93+
}

sql/planbuilder/factory.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ package planbuilder
1717
import (
1818
"strings"
1919

20+
"github.com/dolthub/go-mysql-server/sql/types"
21+
2022
"github.com/dolthub/go-mysql-server/sql"
2123
"github.com/dolthub/go-mysql-server/sql/expression"
2224
"github.com/dolthub/go-mysql-server/sql/plan"
@@ -132,6 +134,13 @@ func (f *factory) buildConvert(expr sql.Expression, castToType string, typeLengt
132134
return expr, nil
133135
}
134136
}
137+
if types.IsText(n.Type()) && types.IsEnum(expr.Type()) {
138+
newNode, err := n.WithChildren(expression.NewEnumToString(expr))
139+
if err != nil {
140+
return nil, err
141+
}
142+
return newNode, nil
143+
}
135144
return n, nil
136145
}
137146

sql/planbuilder/scalar.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,20 @@ func (b *Builder) caseExprToExpression(inScope *scope, e *ast.CaseExpr) (sql.Exp
794794
elseExpr = b.buildScalar(inScope, e.Else)
795795
}
796796

797-
return expression.NewCase(expr, branches, elseExpr), nil
797+
newCase := expression.NewCase(expr, branches, elseExpr)
798+
if types.IsText(newCase.Type()) {
799+
for _, branch := range branches {
800+
if types.IsEnum(branch.Value.Type()) {
801+
branch.Value = expression.NewEnumToString(branch.Value)
802+
}
803+
}
804+
if elseExpr != nil && types.IsEnum(elseExpr.Type()) {
805+
elseExpr = expression.NewEnumToString(elseExpr)
806+
}
807+
newCase = expression.NewCase(expr, branches, elseExpr)
808+
}
809+
810+
return newCase, nil
798811
}
799812

800813
func (b *Builder) intervalExprToExpression(inScope *scope, e *ast.IntervalExpr) *expression.Interval {

sql/types/conversion.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -517,11 +517,11 @@ func TypesEqual(a, b sql.Type) bool {
517517
case EnumType:
518518
aEnumType := at
519519
bEnumType := b.(EnumType)
520-
if len(aEnumType.indexToVal) != len(bEnumType.indexToVal) {
520+
if len(aEnumType.idxToVal) != len(bEnumType.idxToVal) {
521521
return false
522522
}
523-
for i := 0; i < len(aEnumType.indexToVal); i++ {
524-
if aEnumType.indexToVal[i] != bEnumType.indexToVal[i] {
523+
for i := 0; i < len(aEnumType.idxToVal); i++ {
524+
if aEnumType.idxToVal[i] != bEnumType.idxToVal[i] {
525525
return false
526526
}
527527
}

sql/types/enum.go

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ var (
4848

4949
type EnumType struct {
5050
collation sql.CollationID
51-
hashedValToIndex map[uint64]int
51+
hashedValToIdx map[uint64]int
5252
valToIdx map[string]int
53-
indexToVal []string
53+
idxToVal []string
5454
maxResponseByteLength uint32
5555
}
5656

@@ -98,8 +98,8 @@ func CreateEnumType(values []string, collation sql.CollationID) (sql.EnumType, e
9898
}
9999
return EnumType{
100100
collation: collation,
101-
hashedValToIndex: hashedValToIndex,
102-
indexToVal: values,
101+
hashedValToIdx: hashedValToIndex,
102+
idxToVal: values,
103103
valToIdx: valToIdx,
104104
maxResponseByteLength: maxResponseByteLength,
105105
}, nil
@@ -217,9 +217,9 @@ func (t EnumType) MustConvert(v interface{}) interface{} {
217217

218218
// Equals implements the Type interface.
219219
func (t EnumType) Equals(otherType sql.Type) bool {
220-
if ot, ok := otherType.(EnumType); ok && t.collation.Equals(ot.collation) && len(t.indexToVal) == len(ot.indexToVal) {
221-
for i, val := range t.indexToVal {
222-
if ot.indexToVal[i] != val {
220+
if ot, ok := otherType.(EnumType); ok && t.collation.Equals(ot.collation) && len(t.idxToVal) == len(ot.idxToVal) {
221+
for i, val := range t.idxToVal {
222+
if ot.idxToVal[i] != val {
223223
return false
224224
}
225225
}
@@ -289,16 +289,19 @@ func (t EnumType) Zero() interface{} {
289289
}
290290

291291
// At implements EnumType interface.
292-
func (t EnumType) At(index int) (string, bool) {
293-
// The elements listed in the column specification are assigned index numbers, beginning with 1.
294-
index -= 1
295-
if index <= -1 {
296-
// for index zero, the value is empty. It's used for insert ignore.
292+
func (t EnumType) At(idx int) (string, bool) {
293+
// for index zero, the value is empty. It's used for insert ignore.
294+
if idx < 0 {
295+
return "", false
296+
}
297+
if idx == 0 {
297298
return "", true
298-
} else if index >= len(t.indexToVal) {
299+
}
300+
if idx > len(t.idxToVal) {
299301
return "", false
300302
}
301-
return t.indexToVal[index], true
303+
// The elements listed in the column specification are assigned index numbers, beginning with 1.
304+
return t.idxToVal[idx-1], true
302305
}
303306

304307
// CharacterSet implements EnumType interface.
@@ -318,7 +321,7 @@ func (t EnumType) IndexOf(v string) int {
318321
}
319322
hashedVal, err := t.collation.HashToUint(v)
320323
if err == nil {
321-
if index, ok := t.hashedValToIndex[hashedVal]; ok {
324+
if index, ok := t.hashedValToIdx[hashedVal]; ok {
322325
return index
323326
}
324327
}
@@ -334,24 +337,24 @@ func (t EnumType) IndexOf(v string) int {
334337

335338
// NumberOfElements implements EnumType interface.
336339
func (t EnumType) NumberOfElements() uint16 {
337-
return uint16(len(t.indexToVal))
340+
return uint16(len(t.idxToVal))
338341
}
339342

340343
// Values implements EnumType interface.
341344
func (t EnumType) Values() []string {
342-
vals := make([]string, len(t.indexToVal))
343-
copy(vals, t.indexToVal)
345+
vals := make([]string, len(t.idxToVal))
346+
copy(vals, t.idxToVal)
344347
return vals
345348
}
346349

347350
// WithNewCollation implements sql.TypeWithCollation interface.
348351
func (t EnumType) WithNewCollation(collation sql.CollationID) (sql.Type, error) {
349-
return CreateEnumType(t.indexToVal, collation)
352+
return CreateEnumType(t.idxToVal, collation)
350353
}
351354

352355
// StringWithTableCollation implements sql.TypeWithCollation interface.
353356
func (t EnumType) StringWithTableCollation(tableCollation sql.CollationID) string {
354-
s := fmt.Sprintf("enum('%v')", strings.Join(t.indexToVal, `','`))
357+
s := fmt.Sprintf("enum('%v')", strings.Join(t.idxToVal, `','`))
355358
if t.CharacterSet() != tableCollation.CharacterSet() {
356359
s += " CHARACTER SET " + t.CharacterSet().String()
357360
}

0 commit comments

Comments
 (0)