diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index cb5455eed8..ae8994fb7a 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -226,7 +226,7 @@ func TestSingleScript(t *testing.T) { for _, test := range scripts { harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil) - harness.UseServer() + //harness.UseServer() engine, err := harness.NewEngine(t) if err != nil { panic(err) diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 1123a5b546..3f06760c67 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -8684,6 +8684,34 @@ where }, }, }, + { + Name: "substring function tests with wrappers", + Dialect: "mysql", + SetUpScript: []string{ + "create table tbl (t text);", + "insert into tbl values ('abcdef');", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select left(t, 3) from tbl;", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "select right(t, 3) from tbl;", + Expected: []sql.Row{ + {"def"}, + }, + }, + { + Query: "select instr(t, 'bcd') from tbl;", + Expected: []sql.Row{ + {2}, + }, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/expression/function/substring.go b/sql/expression/function/substring.go index 36189e10c8..19a51a46f0 100644 --- a/sql/expression/function/substring.go +++ b/sql/expression/function/substring.go @@ -349,8 +349,20 @@ func (l Left) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { switch str := str.(type) { case string: text = []rune(str) + case sql.StringWrapper: + s, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(s) case []byte: text = []rune(string(str)) + case sql.BytesWrapper: + b, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(string(b)) case nil: return nil, nil default: @@ -583,8 +595,20 @@ func (i Instr) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { switch str := str.(type) { case string: text = []rune(str) + case sql.StringWrapper: + s, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(s) case []byte: text = []rune(string(str)) + case sql.BytesWrapper: + s, err := str.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(string(s)) case nil: return nil, nil default: @@ -600,8 +624,20 @@ func (i Instr) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { switch substr := substr.(type) { case string: subtext = []rune(substr) + case sql.StringWrapper: + s, err := substr.Unwrap(ctx) + if err != nil { + return nil, err + } + text = []rune(s) case []byte: - subtext = []rune(string(subtext)) + subtext = []rune(string(substr)) + case sql.BytesWrapper: + s, err := substr.Unwrap(ctx) + if err != nil { + return nil, err + } + subtext = []rune(string(s)) case nil: return nil, nil default: