Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,14 @@ func TestCollationCoercion(t *testing.T) {

func TestRegex(t *testing.T) {
harness := enginetest.NewDefaultMemoryHarness()
harness.Setup(setup.SimpleSetup...)
regexSetup := []setup.SetupScript{
{
"CREATE TABLE tests(pk int primary key, str text, pattern text, flags text);",
"INSERT INTO tests VALUES (1, 'testing', 'TESTING', 'ci');",
},
}
setupsScripts := append(setup.SimpleSetup, regexSetup)
harness.Setup(setupsScripts...)
engine, err := harness.NewEngine(t)
require.NoError(t, err)
defer engine.Close()
Expand Down
4 changes: 4 additions & 0 deletions enginetest/queries/regex_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ var RegexTests = []RegexTest{
Query: "SELECT REGEXP_LIKE('testing', 'TESTING', 'ic');",
Expected: []sql.Row{{0}},
},
{
Query: "SELECT REGEXP_LIKE(str, pattern, flags) from tests;",
Expected: []sql.Row{{1}},
},
{
Query: "SELECT REGEXP_LIKE('testing', 'TESTING' COLLATE utf8mb4_0900_ai_ci);",
Expected: []sql.Row{{1}},
Expand Down
19 changes: 15 additions & 4 deletions sql/expression/function/regexp_like.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,12 @@ func (r *RegexpLike) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
if err != nil {
return nil, err
}
textStr, _, err := sql.Unwrap[string](ctx, text)
if err != nil {
return nil, err
}

err = r.re.SetMatchString(ctx, text.(string))
err = r.re.SetMatchString(ctx, textStr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -220,9 +224,13 @@ func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcNam
if err != nil {
return nil, err
}
patternValStr, _, err := sql.Unwrap[string](ctx, patternVal)
if err != nil {
return nil, err
}

// Empty regex, throw illegal argument
if len(patternVal.(string)) == 0 {
if len(patternValStr) == 0 {
return nil, errors.NewKind("Illegal argument to regular expression.").New()
}

Expand Down Expand Up @@ -250,7 +258,10 @@ func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcNam
return nil, err
}

flagsStr = f.(string)
flagsStr, _, err = sql.Unwrap[string](ctx, f)
if err != nil {
return nil, err
}
flagsStr, err = consolidateRegexpFlags(flagsStr, funcName)
if err != nil {
return nil, err
Expand Down Expand Up @@ -279,7 +290,7 @@ func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcNam
ctx.Warn(1193, `System variable for regular expressions "regexp_buffer_size" is missing`)
}
re := regex.CreateRegex(bufferSize)
if err = re.SetRegexString(ctx, patternVal.(string), regexFlags); err != nil {
if err = re.SetRegexString(ctx, patternValStr, regexFlags); err != nil {
_ = re.Close()
return nil, err
}
Expand Down
25 changes: 21 additions & 4 deletions sql/expression/function/reverse_repeat_replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,26 @@ func (r *Replace) Eval(
return nil, err
}

if fromStr.(string) == "" {
return str, nil
{
str, _, err := sql.Unwrap[string](ctx, str)
if err != nil {
return nil, err
}

fromStr, _, err := sql.Unwrap[string](ctx, fromStr)
if err != nil {
return nil, err
}

toStr, _, err := sql.Unwrap[string](ctx, toStr)
if err != nil {
return nil, err
}

if fromStr == "" {
return str, nil
}

return strings.Replace(str, fromStr, toStr, -1), nil
}

return strings.Replace(str.(string), fromStr.(string), toStr.(string), -1), nil
}
14 changes: 13 additions & 1 deletion sql/expression/function/rpad_lpad.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,19 @@ func (p *Pad) Eval(
return nil, err
}

return padString(str.(string), length.(int64), padStr.(string), p.padType)
{
str, _, err := sql.Unwrap[string](ctx, str)
if err != nil {
return nil, err
}

padStr, _, err := sql.Unwrap[string](ctx, padStr)
if err != nil {
return nil, err
}

return padString(str, length.(int64), padStr, p.padType)
}
}

func padString(str string, length int64, padStr string, padType padType) (string, error) {
Expand Down
Loading