Skip to content

Commit 0362674

Browse files
authored
Merge pull request #2968 from dolthub/nicktobey/unwrap3
[no-release-notes] Rollforward "Unwrap inputs to REGEXP_LIKE, RPAD, LPAD, and REPLACE functions."
2 parents 3d51c5c + 87b662f commit 0362674

File tree

5 files changed

+61
-10
lines changed

5 files changed

+61
-10
lines changed

enginetest/engine_only_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -717,7 +717,14 @@ func TestCollationCoercion(t *testing.T) {
717717

718718
func TestRegex(t *testing.T) {
719719
harness := enginetest.NewDefaultMemoryHarness()
720-
harness.Setup(setup.SimpleSetup...)
720+
regexSetup := []setup.SetupScript{
721+
{
722+
"CREATE TABLE tests(pk int primary key, str text, pattern text, flags text);",
723+
"INSERT INTO tests VALUES (1, 'testing', 'TESTING', 'ci');",
724+
},
725+
}
726+
setupsScripts := append(setup.SimpleSetup, regexSetup)
727+
harness.Setup(setupsScripts...)
721728
engine, err := harness.NewEngine(t)
722729
require.NoError(t, err)
723730
defer engine.Close()

enginetest/queries/regex_queries.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ var RegexTests = []RegexTest{
5555
Query: "SELECT REGEXP_LIKE('testing', 'TESTING', 'ic');",
5656
Expected: []sql.Row{{0}},
5757
},
58+
{
59+
Query: "SELECT REGEXP_LIKE(str, pattern, flags) from tests;",
60+
Expected: []sql.Row{{1}},
61+
},
5862
{
5963
Query: "SELECT REGEXP_LIKE('testing', 'TESTING' COLLATE utf8mb4_0900_ai_ci);",
6064
Expected: []sql.Row{{1}},

sql/expression/function/regexp_like.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,12 @@ func (r *RegexpLike) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
179179
if err != nil {
180180
return nil, err
181181
}
182+
textStr, _, err := sql.Unwrap[string](ctx, text)
183+
if err != nil {
184+
return nil, err
185+
}
182186

183-
err = r.re.SetMatchString(ctx, text.(string))
187+
err = r.re.SetMatchString(ctx, textStr)
184188
if err != nil {
185189
return nil, err
186190
}
@@ -220,9 +224,13 @@ func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcNam
220224
if err != nil {
221225
return nil, err
222226
}
227+
patternValStr, _, err := sql.Unwrap[string](ctx, patternVal)
228+
if err != nil {
229+
return nil, err
230+
}
223231

224232
// Empty regex, throw illegal argument
225-
if len(patternVal.(string)) == 0 {
233+
if len(patternValStr) == 0 {
226234
return nil, errors.NewKind("Illegal argument to regular expression.").New()
227235
}
228236

@@ -250,7 +258,10 @@ func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcNam
250258
return nil, err
251259
}
252260

253-
flagsStr = f.(string)
261+
flagsStr, _, err = sql.Unwrap[string](ctx, f)
262+
if err != nil {
263+
return nil, err
264+
}
254265
flagsStr, err = consolidateRegexpFlags(flagsStr, funcName)
255266
if err != nil {
256267
return nil, err
@@ -279,7 +290,7 @@ func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcNam
279290
ctx.Warn(1193, `System variable for regular expressions "regexp_buffer_size" is missing`)
280291
}
281292
re := regex.CreateRegex(bufferSize)
282-
if err = re.SetRegexString(ctx, patternVal.(string), regexFlags); err != nil {
293+
if err = re.SetRegexString(ctx, patternValStr, regexFlags); err != nil {
283294
_ = re.Close()
284295
return nil, err
285296
}

sql/expression/function/reverse_repeat_replace.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,26 @@ func (r *Replace) Eval(
280280
return nil, err
281281
}
282282

283-
if fromStr.(string) == "" {
284-
return str, nil
283+
{
284+
str, _, err := sql.Unwrap[string](ctx, str)
285+
if err != nil {
286+
return nil, err
287+
}
288+
289+
fromStr, _, err := sql.Unwrap[string](ctx, fromStr)
290+
if err != nil {
291+
return nil, err
292+
}
293+
294+
toStr, _, err := sql.Unwrap[string](ctx, toStr)
295+
if err != nil {
296+
return nil, err
297+
}
298+
299+
if fromStr == "" {
300+
return str, nil
301+
}
302+
303+
return strings.Replace(str, fromStr, toStr, -1), nil
285304
}
286-
287-
return strings.Replace(str.(string), fromStr.(string), toStr.(string), -1), nil
288305
}

sql/expression/function/rpad_lpad.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,19 @@ func (p *Pad) Eval(
169169
return nil, err
170170
}
171171

172-
return padString(str.(string), length.(int64), padStr.(string), p.padType)
172+
{
173+
str, _, err := sql.Unwrap[string](ctx, str)
174+
if err != nil {
175+
return nil, err
176+
}
177+
178+
padStr, _, err := sql.Unwrap[string](ctx, padStr)
179+
if err != nil {
180+
return nil, err
181+
}
182+
183+
return padString(str, length.(int64), padStr, p.padType)
184+
}
173185
}
174186

175187
func padString(str string, length int64, padStr string, padType padType) (string, error) {

0 commit comments

Comments
 (0)