Skip to content

Commit 48e833f

Browse files
committed
Unwrap values before casting them to strings or bytes.
1 parent b49ea60 commit 48e833f

File tree

13 files changed

+88
-22
lines changed

13 files changed

+88
-22
lines changed

enginetest/enginetests.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,9 @@ func TestOrderByGroupBy(t *testing.T, harness Harness) {
848848
panic(fmt.Sprintf("unexpected type %T", v))
849849
}
850850

851-
team := row[1].(string)
851+
team, ok, err := sql.Unwrap[string](ctx, row[1])
852+
require.True(t, ok)
853+
require.NoError(t, err)
852854
switch team {
853855
case "red":
854856
require.True(t, val == 3 || val == 4)

sql/expression/function/aggregation/group_concat.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,11 @@ func (g *groupConcatBuffer) Update(ctx *sql.Context, originalRow sql.Row) error
222222
if err != nil {
223223
return err
224224
}
225-
vs = string(v.([]byte))
225+
vb, _, err := sql.Unwrap[[]byte](ctx, v)
226+
if err != nil {
227+
return err
228+
}
229+
vs = string(vb)
226230
if len(vs) == 0 {
227231
return nil
228232
}
@@ -234,7 +238,10 @@ func (g *groupConcatBuffer) Update(ctx *sql.Context, originalRow sql.Row) error
234238
if v == nil {
235239
return nil
236240
}
237-
vs = v.(string)
241+
vs, _, err = sql.Unwrap[string](ctx, v)
242+
if err != nil {
243+
return err
244+
}
238245
}
239246

240247
// Get the current array of rows and the map

sql/expression/function/concat.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ func (c *Concat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
128128
return nil, err
129129
}
130130

131+
val, _, err = sql.Unwrap[string](ctx, val)
132+
if err != nil {
133+
return nil, err
134+
}
135+
131136
parts = append(parts, val.(string))
132137
}
133138

sql/expression/function/convert_tz.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ func (c *ConvertTz) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
9393
}
9494

9595
// If either the date, or the timezones/offsets are not correct types we return NULL.
96-
datetime, err := types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(dt)
96+
datetime, err := types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(ctx, dt)
9797
if err != nil {
9898
return nil, nil
9999
}
@@ -121,7 +121,7 @@ func (c *ConvertTz) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
121121
return nil, nil
122122
}
123123

124-
return types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(converted)
124+
return types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(ctx, converted)
125125
}
126126

127127
// Children implements the sql.Expression interface.

sql/expression/function/extract.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func (td *Extract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
100100
return nil, nil
101101
}
102102

103-
right, err = types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(right)
103+
right, err = types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(ctx, right)
104104
if err != nil {
105105
ctx.Warn(1292, err.Error())
106106
return nil, nil

sql/expression/function/hash.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ func (f *SHA1) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
126126
if err != nil {
127127
return nil, err
128128
}
129+
val, err = sql.UnwrapAny(ctx, val)
130+
if err != nil {
131+
return nil, err
132+
}
129133

130134
h := sha1.New()
131135
_, err = io.WriteString(h, string(val.([]byte)))

sql/expression/function/spatial/geojson.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,11 @@ func (g *GeomFromGeoJSON) Eval(ctx *sql.Context, row sql.Row) (interface{}, erro
672672
return nil, err
673673
}
674674

675+
val, err = sql.UnwrapAny(ctx, val)
676+
if err != nil {
677+
return nil, err
678+
}
679+
675680
switch s := val.(type) {
676681
case string:
677682
val = []byte(s)

sql/expression/function/string.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,18 +185,22 @@ func (h *Hex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
185185
}
186186

187187
switch val := arg.(type) {
188-
case string:
188+
case string, sql.StringWrapper:
189+
s, _, err := sql.Unwrap[string](ctx, val)
190+
if err != nil {
191+
return nil, err
192+
}
189193
childType := h.Child.Type()
190194
if types.IsTextOnly(childType) {
191195
// For string types we need to re-encode the internal string so that we get the correct hex output
192196
encoder := childType.(sql.StringType).Collation().CharacterSet().Encoder()
193-
encodedBytes, ok := encoder.Encode(encodings.StringToBytes(val))
197+
encodedBytes, ok := encoder.Encode(encodings.StringToBytes(s))
194198
if !ok {
195199
return nil, fmt.Errorf("unable to re-encode string for HEX function")
196200
}
197201
return hexForString(encodings.BytesToString(encodedBytes)), nil
198202
} else {
199-
return hexForString(val), nil
203+
return hexForString(s), nil
200204
}
201205

202206
case uint8, uint16, uint32, uint, int, int8, int16, int32, int64:
@@ -244,8 +248,12 @@ func (h *Hex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
244248

245249
return hexForString(s), nil
246250

247-
case []byte:
248-
return hexForString(string(val)), nil
251+
case []byte, sql.BytesWrapper:
252+
b, _, err := sql.Unwrap[[]byte](ctx, val)
253+
if err != nil {
254+
return nil, err
255+
}
256+
return hexForString(string(b)), nil
249257

250258
case types.GeometryValue:
251259
return hexForString(string(val.Serialize())), nil

sql/expression/function/substring.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,20 @@ func (r Right) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
455455
switch str := str.(type) {
456456
case string:
457457
text = []rune(str)
458+
case sql.StringWrapper:
459+
s, err := str.Unwrap(ctx)
460+
if err != nil {
461+
return nil, err
462+
}
463+
text = []rune(s)
458464
case []byte:
459465
text = []rune(string(str))
466+
case sql.BytesWrapper:
467+
b, err := str.Unwrap(ctx)
468+
if err != nil {
469+
return nil, err
470+
}
471+
text = []rune(string(b))
460472
case nil:
461473
return nil, nil
462474
default:

sql/expression/function/time.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func getDate(ctx *sql.Context,
4848
return nil, nil
4949
}
5050

51-
date, err := types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(val)
51+
date, err := types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(ctx, val)
5252
if err != nil {
5353
ctx.Warn(1292, "Incorrect datetime value: '%s'", val)
5454
return nil, nil
@@ -1676,7 +1676,7 @@ func (t *Time) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
16761676
}
16771677

16781678
// convert to date
1679-
date, err := types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(v)
1679+
date, err := types.DatetimeMaxPrecision.ConvertWithoutRangeCheck(ctx, v)
16801680
if err == nil {
16811681
h, m, s := date.Clock()
16821682
us := date.Nanosecond() / 1000

0 commit comments

Comments
 (0)