Skip to content

Commit 7f38dfa

Browse files
committed
Convert enums and sets to string if comparing to a string
1 parent 9ae1a2f commit 7f38dfa

File tree

2 files changed

+70
-44
lines changed

2 files changed

+70
-44
lines changed

enginetest/queries/script_queries.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9320,6 +9320,12 @@ where
93209320
{"hijkl", []uint8("hijkl")},
93219321
},
93229322
},
9323+
{
9324+
Query: "select e from t where e like 'a%'",
9325+
Expected: []sql.Row{
9326+
{"abc"},
9327+
},
9328+
},
93239329
},
93249330
},
93259331
{
@@ -9764,9 +9770,6 @@ where
97649770
},
97659771
},
97669772
{
9767-
// this is failing due to a type coercion bug in comparison.Compare
9768-
// https://github.com/dolthub/dolt/issues/9510
9769-
Skip: true,
97709773
Query: "select i, s + 0, s from t where s = '';",
97719774
Expected: []sql.Row{
97729775
{0, float64(0), ""},
@@ -9820,8 +9823,6 @@ where
98209823
},
98219824
},
98229825
{
9823-
// https://github.com/dolthub/dolt/issues/9510
9824-
Skip: true,
98259826
Query: "select s from t where s like 'a%' order by s;",
98269827
Expected: []sql.Row{
98279828
{"abc"},

sql/expression/comparison.go

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -141,46 +141,51 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
141141
return c.Left().Type().Compare(ctx, left, right)
142142
}
143143

144-
// ENUM, SET, and TIME must be excluded when doing comparisons, as they're too restrictive to use as a comparison
145-
// base.
146-
//
147-
// The best overall method would be to assign type priority. For example, INT would have a higher priority than
148-
// TINYINT. This could then be combined with the origin of the value (table column, procedure param, etc.) to
149-
// determine the best type for any comparison (tie-breakers can be simple rules such as the current left preference).
150-
var compareType sql.Type
151-
collationPreference := sql.Collation_Default
152-
switch c.Left().(type) {
153-
case *GetField, *UserVar, *SystemVar, *ProcedureParam:
154-
compareType = c.Left().Type()
155-
if twc, ok := compareType.(sql.TypeWithCollation); ok {
156-
collationPreference = twc.Collation()
157-
}
158-
default:
159-
switch c.Right().(type) {
160-
case *GetField, *UserVar, *SystemVar, *ProcedureParam:
161-
compareType = c.Right().Type()
162-
if twc, ok := compareType.(sql.TypeWithCollation); ok {
163-
collationPreference = twc.Collation()
164-
}
165-
}
166-
}
167-
if compareType != nil {
168-
_, isEnum := compareType.(sql.EnumType)
169-
_, isSet := compareType.(sql.SetType)
170-
_, isTime := compareType.(types.TimeType)
171-
if !isEnum && !isSet && !isTime {
172-
compareType = nil
173-
}
174-
}
175-
if compareType == nil {
176-
left, right, compareType, err = c.castLeftAndRight(ctx, left, right)
177-
if err != nil {
178-
return 0, err
179-
}
144+
left, right, compareType, err := c.castLeftAndRight(ctx, left, right)
145+
if err != nil {
146+
return 0, err
180147
}
181-
if _, isSet := compareType.(sql.SetType); !isSet && types.IsTextOnly(compareType) {
182-
collationPreference, _ = c.CollationCoercibility(ctx)
183-
stringCompareType := compareType.(sql.StringType)
148+
149+
//var compareType sql.Type
150+
//// ENUM, SET, and TIME must be excluded when doing comparisons, as they're too restrictive to use as a comparison
151+
//// base.
152+
////
153+
//// The best overall method would be to assign type priority. For example, INT would have a higher priority than
154+
//// TINYINT. This could then be combined with the origin of the value (table column, procedure param, etc.) to
155+
//// determine the best type for any comparison (tie-breakers can be simple rules such as the current left preference).
156+
//
157+
//collationPreference := sql.Collation_Default
158+
//switch c.Left().(type) {
159+
//case *GetField, *UserVar, *SystemVar, *ProcedureParam:
160+
// compareType = c.Left().Type()
161+
// if twc, ok := compareType.(sql.TypeWithCollation); ok {
162+
// collationPreference = twc.Collation()
163+
// }
164+
//default:
165+
// switch c.Right().(type) {
166+
// case *GetField, *UserVar, *SystemVar, *ProcedureParam:
167+
// compareType = c.Right().Type()
168+
// if twc, ok := compareType.(sql.TypeWithCollation); ok {
169+
// collationPreference = twc.Collation()
170+
// }
171+
// }
172+
//}
173+
//if compareType != nil {
174+
// _, isEnum := compareType.(sql.EnumType)
175+
// _, isSet := compareType.(sql.SetType)
176+
// _, isTime := compareType.(types.TimeType)
177+
// if !isEnum && !isSet && !isTime {
178+
// compareType = nil
179+
// }
180+
//}
181+
//if compareType == nil {
182+
// left, right, compareType, err = c.castLeftAndRight(ctx, left, right)
183+
// if err != nil {
184+
// return 0, err
185+
// }
186+
//}
187+
collationPreference, _ := c.CollationCoercibility(ctx)
188+
if stringCompareType, ok := compareType.(sql.StringType); ok {
184189
compareType = types.MustCreateString(stringCompareType.Type(), stringCompareType.Length(), collationPreference)
185190
}
186191

@@ -204,6 +209,26 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{
204209
func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) {
205210
leftType := c.Left().Type()
206211
rightType := c.Right().Type()
212+
213+
leftIsEnumOrSet := types.IsEnum(leftType) || types.IsSet(leftType)
214+
rightIsEnumOrSet := types.IsEnum(rightType) || types.IsSet(rightType)
215+
leftIsText := types.IsTextOnly(leftType)
216+
rightIsText := types.IsTextOnly(rightType)
217+
if (leftIsEnumOrSet && rightIsText) || (rightIsEnumOrSet && !leftIsText) {
218+
l, err := types.TypeAwareConversion(ctx, left, leftType, rightType)
219+
if err != nil {
220+
return nil, nil, nil, err
221+
}
222+
return l, right, rightType, nil
223+
}
224+
if (rightIsEnumOrSet && leftIsText) || (leftIsEnumOrSet && !rightIsText) {
225+
r, err := types.TypeAwareConversion(ctx, right, rightType, leftType)
226+
if err != nil {
227+
return nil, nil, nil, err
228+
}
229+
return left, r, leftType, nil
230+
}
231+
207232
if types.IsTuple(leftType) && types.IsTuple(rightType) {
208233
return left, right, c.Left().Type(), nil
209234
}

0 commit comments

Comments
 (0)