Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions sql/expression/function/reverse_repeat_replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package function

import (
"fmt"
"reflect"
"strings"

"gopkg.in/src-d/go-errors.v1"
Expand Down Expand Up @@ -64,7 +65,18 @@ func (r *Reverse) Eval(
return nil, err
}

return reverseString(v.(string)), nil
// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
v, err = sql.UnwrapAny(ctx, v)
if err != nil {
return nil, err
}

s, ok := v.(string)
if !ok {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(v).String())
}

return reverseString(s), nil
}

func reverseString(s string) string {
Expand Down Expand Up @@ -162,6 +174,17 @@ func (r *Repeat) Eval(
return nil, err
}

// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
str, err = sql.UnwrapAny(ctx, str)
if err != nil {
return nil, err
}

strVal, ok := str.(string)
if !ok {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
}

count, err := r.RightChild.Eval(ctx, row)
if count == nil || err != nil {
return nil, err
Expand All @@ -171,10 +194,14 @@ func (r *Repeat) Eval(
if err != nil {
return nil, err
}
if count.(int32) < 0 {
countVal, ok := count.(int32)
if !ok {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(count).String())
}
if countVal < 0 {
return nil, ErrNegativeRepeatCount.New(count)
}
return strings.Repeat(str.(string), int(count.(int32))), nil
return strings.Repeat(strVal, int(countVal)), nil
}

// Replace is a function that returns a string with all occurrences of fromStr replaced by the
Expand Down
12 changes: 10 additions & 2 deletions sql/expression/function/substring.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,22 @@ func (s *Substring) Eval(
return nil, err
}

if str == nil {
return nil, nil
}

// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
str, err = sql.UnwrapAny(ctx, str)
if err != nil {
return nil, err
}

var text []rune
switch str := str.(type) {
case string:
text = []rune(str)
case []byte:
text = []rune(string(str))
case nil:
return nil, nil
default:
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
}
Expand Down
34 changes: 28 additions & 6 deletions sql/expression/function/trim_ltrim_rtrim.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ func (t *Trim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(pat).String())
}

// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
pat, err = sql.UnwrapAny(ctx, pat)
if err != nil {
return nil, err
}

// Evaluate string value
str, err := t.str.Eval(ctx, row)
if err != nil {
Expand All @@ -79,15 +85,31 @@ func (t *Trim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, nil
}

// Convert pat into string
// Convert str to text type (may still be wrapped)
str, _, err = types.LongText.Convert(ctx, str)
if err != nil {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
}

// Handle Dolt's TextStorage and other wrapper types that don't convert to plain strings
str, err = sql.UnwrapAny(ctx, str)
if err != nil {
return nil, err
}

strVal, ok := str.(string)
if !ok {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(str).String())
}

patVal, ok := pat.(string)
if !ok {
return nil, sql.ErrInvalidType.New(reflect.TypeOf(pat).String())
}

start := 0
end := len(str.(string))
n := len(pat.(string))
end := len(strVal)
n := len(patVal)

// Empty pattern, do nothing
if n == 0 {
Expand All @@ -96,19 +118,19 @@ func (t *Trim) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {

// Trim Leading
if t.dir == sqlparser.Leading || t.dir == sqlparser.Both {
for start+n <= end && str.(string)[start:start+n] == pat {
for start+n <= end && strVal[start:start+n] == patVal {
start += n
}
}

// Trim Trailing
if t.dir == sqlparser.Trailing || t.dir == sqlparser.Both {
for start+n <= end && str.(string)[end-n:end] == pat {
for start+n <= end && strVal[end-n:end] == patVal {
end -= n
}
}

return str.(string)[start:end], nil
return strVal[start:end], nil
}

// IsNullable implements the Expression interface.
Expand Down
Loading