Skip to content

Commit 5be1404

Browse files
committed
handle empty strings in set comparisons
1 parent 4fa5f89 commit 5be1404

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

sql/expression/comparison.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,20 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
141141
return c.Left().Type().Compare(ctx, left, right)
142142
}
143143

144-
left, right, compareType, err := c.castLeftAndRight(ctx, left, right)
144+
l, r, compareType, err := c.castLeftAndRight(ctx, left, right)
145145
if err != nil {
146146
return 0, err
147147
}
148148

149+
// Set comparison relies on empty strings not being converted yet
150+
if types.IsSet(compareType) {
151+
return compareType.Compare(ctx, left, right)
152+
}
149153
collationPreference, _ := c.CollationCoercibility(ctx)
150154
if stringCompareType, ok := compareType.(sql.StringType); ok {
151155
compareType = types.MustCreateString(stringCompareType.Type(), stringCompareType.Length(), collationPreference)
152156
}
153-
154-
return compareType.Compare(ctx, left, right)
157+
return compareType.Compare(ctx, l, r)
155158
}
156159

157160
func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) {

sql/types/set.go

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,24 @@ func (t SetType) Compare(ctx context.Context, a interface{}, b interface{}) (int
129129
au := ai.(uint64)
130130
bu := bi.(uint64)
131131

132+
// If there's an empty string in the set, empty strings should match both 0 and an empty string bit field
133+
if emptyStringBitField, ok := t.emptyStringBitField(); ok {
134+
aIsEmptyString := isEmptyString(a)
135+
bIsEmptyString := isEmptyString(b)
136+
if aIsEmptyString {
137+
if bu == 0 || bu == emptyStringBitField {
138+
return 0, nil
139+
}
140+
return -1, nil
141+
}
142+
if bIsEmptyString {
143+
if au == 0 || au == emptyStringBitField {
144+
return 0, nil
145+
}
146+
return 1, nil
147+
}
148+
}
149+
132150
if au < bu {
133151
return -1, nil
134152
} else if au > bu {
@@ -180,7 +198,7 @@ func (t SetType) Convert(ctx context.Context, v interface{}) (interface{}, sql.C
180198
return t.Convert(ctx, value.Decimal.BigInt().Uint64())
181199
case string:
182200
ret, err := t.convertStringToBitField(value)
183-
return ret, sql.InRange, err
201+
return ret, err == nil, err
184202
case []byte:
185203
return t.Convert(ctx, string(value))
186204
}
@@ -364,7 +382,7 @@ func (t SetType) convertStringToBitField(str string) (uint64, error) {
364382
return 0, nil
365383
}
366384
var bitField uint64
367-
_, allowEmptyString := t.valToBit[""]
385+
_, allowEmptyString := t.emptyStringBitField()
368386
lastI := 0
369387
var val string
370388
for i := 0; i < len(str)+1; i++ {
@@ -410,3 +428,16 @@ func (t SetType) convertStringToBitField(str string) (uint64, error) {
410428
}
411429
return bitField, nil
412430
}
431+
432+
func (t SetType) emptyStringBitField() (bitField uint64, ok bool) {
433+
bitField, ok = t.valToBit[""]
434+
return bitField, ok
435+
}
436+
437+
func isEmptyString(val interface{}) bool {
438+
switch v := val.(type) {
439+
case string:
440+
return v == ""
441+
}
442+
return false
443+
}

0 commit comments

Comments
 (0)