Skip to content

Commit 561fd82

Browse files
committed
add unwrapany calls for dolt textstorage
1 parent f65884e commit 561fd82

File tree

3 files changed

+68
-11
lines changed

3 files changed

+68
-11
lines changed

sql/expression/function/reverse_repeat_replace.go

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package function
1616

1717
import (
1818
"fmt"
19+
"reflect"
1920
"strings"
2021

2122
"gopkg.in/src-d/go-errors.v1"
@@ -64,7 +65,18 @@ func (r *Reverse) Eval(
6465
return nil, err
6566
}
6667

67-
return reverseString(v.(string)), nil
68+
// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
69+
v, err = sql.UnwrapAny(ctx, v)
70+
if err != nil {
71+
return nil, err
72+
}
73+
74+
s, ok := v.(string)
75+
if !ok {
76+
return nil, sql.ErrInvalidType.New(reflect.TypeOf(v).String())
77+
}
78+
79+
return reverseString(s), nil
6880
}
6981

7082
func reverseString(s string) string {
@@ -162,6 +174,17 @@ func (r *Repeat) Eval(
162174
return nil, err
163175
}
164176

177+
// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
178+
str, err = sql.UnwrapAny(ctx, str)
179+
if err != nil {
180+
return nil, err
181+
}
182+
183+
strVal, ok := str.(string)
184+
if !ok {
185+
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
186+
}
187+
165188
count, err := r.RightChild.Eval(ctx, row)
166189
if count == nil || err != nil {
167190
return nil, err
@@ -171,10 +194,14 @@ func (r *Repeat) Eval(
171194
if err != nil {
172195
return nil, err
173196
}
174-
if count.(int32) < 0 {
197+
countVal, ok := count.(int32)
198+
if !ok {
199+
return nil, sql.ErrInvalidType.New(reflect.TypeOf(count).String())
200+
}
201+
if countVal < 0 {
175202
return nil, ErrNegativeRepeatCount.New(count)
176203
}
177-
return strings.Repeat(str.(string), int(count.(int32))), nil
204+
return strings.Repeat(strVal, int(countVal)), nil
178205
}
179206

180207
// Replace is a function that returns a string with all occurrences of fromStr replaced by the

sql/expression/function/substring.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,22 @@ func (s *Substring) Eval(
8484
return nil, err
8585
}
8686

87+
if str == nil {
88+
return nil, nil
89+
}
90+
91+
// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
92+
str, err = sql.UnwrapAny(ctx, str)
93+
if err != nil {
94+
return nil, err
95+
}
96+
8797
var text []rune
8898
switch str := str.(type) {
8999
case string:
90100
text = []rune(str)
91101
case []byte:
92102
text = []rune(string(str))
93-
case nil:
94-
return nil, nil
95103
default:
96104
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
97105
}

sql/expression/function/trim_ltrim_rtrim.go

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@ func (t *Trim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
6868
return nil, sql.ErrInvalidType.New(reflect.TypeOf(pat).String())
6969
}
7070

71+
// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
72+
pat, err = sql.UnwrapAny(ctx, pat)
73+
if err != nil {
74+
return nil, err
75+
}
76+
7177
// Evaluate string value
7278
str, err := t.str.Eval(ctx, row)
7379
if err != nil {
@@ -79,15 +85,31 @@ func (t *Trim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
7985
return nil, nil
8086
}
8187

82-
// Convert pat into string
88+
// Convert str to text type (may still be wrapped)
8389
str, _, err = types.LongText.Convert(ctx, str)
8490
if err != nil {
8591
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
8692
}
8793

94+
// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
95+
str, err = sql.UnwrapAny(ctx, str)
96+
if err != nil {
97+
return nil, err
98+
}
99+
100+
strVal, ok := str.(string)
101+
if !ok {
102+
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
103+
}
104+
105+
patVal, ok := pat.(string)
106+
if !ok {
107+
return nil, sql.ErrInvalidType.New(reflect.TypeOf(pat).String())
108+
}
109+
88110
start := 0
89-
end := len(str.(string))
90-
n := len(pat.(string))
111+
end := len(strVal)
112+
n := len(patVal)
91113

92114
// Empty pattern, do nothing
93115
if n == 0 {
@@ -96,19 +118,19 @@ func (t *Trim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
96118

97119
// Trim Leading
98120
if t.dir == sqlparser.Leading || t.dir == sqlparser.Both {
99-
for start+n <= end && str.(string)[start:start+n] == pat {
121+
for start+n <= end && strVal[start:start+n] == patVal {
100122
start += n
101123
}
102124
}
103125

104126
// Trim Trailing
105127
if t.dir == sqlparser.Trailing || t.dir == sqlparser.Both {
106-
for start+n <= end && str.(string)[end-n:end] == pat {
128+
for start+n <= end && strVal[end-n:end] == patVal {
107129
end -= n
108130
}
109131
}
110132

111-
return str.(string)[start:end], nil
133+
return strVal[start:end], nil
112134
}
113135

114136
// IsNullable implements the Expression interface.

0 commit comments

Comments
 (0)