Skip to content

Commit 328e24e

Browse files
authored
Merge pull request #3111 from dolthub/angela/compare_type
Enum and set conversions in comparison.Compare
2 parents e4d9d0a + c0ce6d6 commit 328e24e

File tree

4 files changed

+93
-48
lines changed

4 files changed

+93
-48
lines changed

enginetest/queries/script_queries.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9320,6 +9320,12 @@ where
93209320
{"hijkl", []uint8("hijkl")},
93219321
},
93229322
},
9323+
{
9324+
Query: "select e from t where e like 'a%'",
9325+
Expected: []sql.Row{
9326+
{"abc"},
9327+
},
9328+
},
93239329
},
93249330
},
93259331
{
@@ -9764,9 +9770,6 @@ where
97649770
},
97659771
},
97669772
{
9767-
// this is failing due to a type coercion bug in comparison.Compare
9768-
// https://github.com/dolthub/dolt/issues/9510
9769-
Skip: true,
97709773
Query: "select i, s + 0, s from t where s = '';",
97719774
Expected: []sql.Row{
97729775
{0, float64(0), ""},
@@ -9820,8 +9823,6 @@ where
98209823
},
98219824
},
98229825
{
9823-
// https://github.com/dolthub/dolt/issues/9510
9824-
Skip: true,
98259826
Query: "select s from t where s like 'a%' order by s;",
98269827
Expected: []sql.Row{
98279828
{"abc"},

sql/expression/comparison.go

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -141,50 +141,20 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
141141
return c.Left().Type().Compare(ctx, left, right)
142142
}
143143

144-
// ENUM, SET, and TIME must be excluded when doing comparisons, as they're too restrictive to use as a comparison
145-
// base.
146-
//
147-
// The best overall method would be to assign type priority. For example, INT would have a higher priority than
148-
// TINYINT. This could then be combined with the origin of the value (table column, procedure param, etc.) to
149-
// determine the best type for any comparison (tie-breakers can be simple rules such as the current left preference).
150-
var compareType sql.Type
151-
collationPreference := sql.Collation_Default
152-
switch c.Left().(type) {
153-
case *GetField, *UserVar, *SystemVar, *ProcedureParam:
154-
compareType = c.Left().Type()
155-
if twc, ok := compareType.(sql.TypeWithCollation); ok {
156-
collationPreference = twc.Collation()
157-
}
158-
default:
159-
switch c.Right().(type) {
160-
case *GetField, *UserVar, *SystemVar, *ProcedureParam:
161-
compareType = c.Right().Type()
162-
if twc, ok := compareType.(sql.TypeWithCollation); ok {
163-
collationPreference = twc.Collation()
164-
}
165-
}
166-
}
167-
if compareType != nil {
168-
_, isEnum := compareType.(sql.EnumType)
169-
_, isSet := compareType.(sql.SetType)
170-
_, isTime := compareType.(types.TimeType)
171-
if !isEnum && !isSet && !isTime {
172-
compareType = nil
173-
}
144+
l, r, compareType, err := c.castLeftAndRight(ctx, left, right)
145+
if err != nil {
146+
return 0, err
174147
}
175-
if compareType == nil {
176-
left, right, compareType, err = c.castLeftAndRight(ctx, left, right)
177-
if err != nil {
178-
return 0, err
179-
}
148+
149+
// Set comparison relies on empty strings not being converted yet
150+
if types.IsSet(compareType) {
151+
return compareType.Compare(ctx, left, right)
180152
}
181-
if _, isSet := compareType.(sql.SetType); !isSet && types.IsTextOnly(compareType) {
182-
collationPreference, _ = c.CollationCoercibility(ctx)
183-
stringCompareType := compareType.(sql.StringType)
153+
collationPreference, _ := c.CollationCoercibility(ctx)
154+
if stringCompareType, ok := compareType.(sql.StringType); ok && types.IsTextOnly(stringCompareType) {
184155
compareType = types.MustCreateString(stringCompareType.Type(), stringCompareType.Length(), collationPreference)
185156
}
186-
187-
return compareType.Compare(ctx, left, right)
157+
return compareType.Compare(ctx, l, r)
188158
}
189159

190160
func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {
@@ -204,6 +174,49 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{
204174
func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) {
205175
leftType := c.Left().Type()
206176
rightType := c.Right().Type()
177+
178+
leftIsEnumOrSet := types.IsEnum(leftType) || types.IsSet(leftType)
179+
rightIsEnumOrSet := types.IsEnum(rightType) || types.IsSet(rightType)
180+
// Only convert if same Enum or Set
181+
if leftIsEnumOrSet && rightIsEnumOrSet {
182+
if types.TypesEqual(leftType, rightType) {
183+
return left, right, leftType, nil
184+
}
185+
} else {
186+
// If right side is convertible to enum/set, convert. Otherwise, convert left side
187+
if leftIsEnumOrSet && (types.IsText(rightType) || types.IsNumber(rightType)) {
188+
if r, inRange, err := leftType.Convert(ctx, right); inRange && err == nil {
189+
return left, r, leftType, nil
190+
} else {
191+
l, err := types.TypeAwareConversion(ctx, left, leftType, rightType)
192+
if err != nil {
193+
return nil, nil, nil, err
194+
}
195+
return l, right, rightType, nil
196+
}
197+
}
198+
// If left side is convertible to enum/set, convert. Otherwise, convert right side
199+
if rightIsEnumOrSet && (types.IsText(leftType) || types.IsNumber(leftType)) {
200+
if l, inRange, err := rightType.Convert(ctx, left); inRange && err == nil {
201+
return l, right, rightType, nil
202+
} else {
203+
r, err := types.TypeAwareConversion(ctx, right, rightType, leftType)
204+
if err != nil {
205+
return nil, nil, nil, err
206+
}
207+
return left, r, leftType, nil
208+
}
209+
}
210+
}
211+
212+
if types.IsTimespan(leftType) || types.IsTimespan(rightType) {
213+
if l, err := types.Time.ConvertToTimespan(left); err == nil {
214+
if r, err := types.Time.ConvertToTimespan(right); err == nil {
215+
return l, r, types.Time, nil
216+
}
217+
}
218+
}
219+
207220
if types.IsTuple(leftType) && types.IsTuple(rightType) {
208221
return left, right, c.Left().Type(), nil
209222
}

sql/types/enum.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql.
218218
return t.Convert(ctx, string(value))
219219
}
220220

221-
return nil, sql.InRange, ErrConvertingToEnum.New(v)
221+
return nil, sql.OutOfRange, ErrConvertingToEnum.New(v)
222222
}
223223

224224
// Equals implements the Type interface.

sql/types/set.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,24 @@ func (t SetType) Compare(ctx context.Context, a interface{}, b interface{}) (int
129129
au := ai.(uint64)
130130
bu := bi.(uint64)
131131

132+
// If there's an empty string in the set, empty strings should match both 0 and an empty string bit field
133+
if emptyStringBitField, ok := t.emptyStringBitField(); ok {
134+
aIsEmptyString := isEmptyString(a)
135+
bIsEmptyString := isEmptyString(b)
136+
if aIsEmptyString {
137+
if bu == 0 || bu == emptyStringBitField {
138+
return 0, nil
139+
}
140+
return -1, nil
141+
}
142+
if bIsEmptyString {
143+
if au == 0 || au == emptyStringBitField {
144+
return 0, nil
145+
}
146+
return 1, nil
147+
}
148+
}
149+
132150
if au < bu {
133151
return -1, nil
134152
} else if au > bu {
@@ -180,7 +198,7 @@ func (t SetType) Convert(ctx context.Context, v interface{}) (interface{}, sql.C
180198
return t.Convert(ctx, value.Decimal.BigInt().Uint64())
181199
case string:
182200
ret, err := t.convertStringToBitField(value)
183-
return ret, sql.InRange, err
201+
return ret, err == nil, err
184202
case []byte:
185203
return t.Convert(ctx, string(value))
186204
}
@@ -364,7 +382,7 @@ func (t SetType) convertStringToBitField(str string) (uint64, error) {
364382
return 0, nil
365383
}
366384
var bitField uint64
367-
_, allowEmptyString := t.valToBit[""]
385+
_, allowEmptyString := t.emptyStringBitField()
368386
lastI := 0
369387
var val string
370388
for i := 0; i < len(str)+1; i++ {
@@ -410,3 +428,16 @@ func (t SetType) convertStringToBitField(str string) (uint64, error) {
410428
}
411429
return bitField, nil
412430
}
431+
432+
func (t SetType) emptyStringBitField() (bitField uint64, ok bool) {
433+
bitField, ok = t.valToBit[""]
434+
return bitField, ok
435+
}
436+
437+
func isEmptyString(val interface{}) bool {
438+
switch v := val.(type) {
439+
case string:
440+
return v == ""
441+
}
442+
return false
443+
}

0 commit comments

Comments
 (0)