diff --git a/internal/strings/unquote.go b/internal/strings/unquote.go index dc4ccd6c41..f6695ab4d4 100644 --- a/internal/strings/unquote.go +++ b/internal/strings/unquote.go @@ -66,6 +66,64 @@ func Unquote(s string) (string, error) { return str, nil } +// UnquoteBytes is the same as Unquote, except it modifies a byte slice in-place +// Be careful: by reusing the slice, this destroys the original input value. +func UnquoteBytes(b []byte) ([]byte, error) { + outIdx := 0 + for i := 0; i < len(b); i++ { + if b[i] == '\\' { + i++ + if i == len(b) { + b[outIdx] = '\\' + } + switch b[i] { + case '"': + b[outIdx] = '"' + case 'b': + b[outIdx] = '\b' + case 'f': + b[outIdx] = '\f' + case 'n': + b[outIdx] = '\n' + case 'r': + b[outIdx] = '\r' + case 't': + b[outIdx] = '\t' + case '\\': + b[outIdx] = '\\' + case 'u': + if i+4 > len(b) { + return nil, fmt.Errorf("Invalid unicode: %s", b[i+1:]) + } + char, size, err := decodeEscapedUnicode(b[i+1 : i+5]) + if err != nil { + return nil, err + } + for j, c := range char[:size] { + b[outIdx+j] = c + } + i += 4 + default: + // For all other escape sequences, backslash is ignored. + b[outIdx] = b[i] + } + } else { + b[outIdx] = b[i] + } + outIdx++ + } + b = b[:outIdx] + + // Remove prefix and suffix '"'. + if outIdx > 1 { + head, tail := b[0], b[outIdx-1] + if head == '"' && tail == '"' { + return b[1 : outIdx-1], nil + } + } + return b, nil +} + // decodeEscapedUnicode decodes unicode into utf8 bytes specified in RFC 3629. // According RFC 3629, the max length of utf8 characters is 4 bytes. // And MySQL use 4 bytes to represent the unicode which must be in [0, 65536). diff --git a/sql/fulltext/fulltext.go b/sql/fulltext/fulltext.go index 199b7c6536..60cea459c2 100644 --- a/sql/fulltext/fulltext.go +++ b/sql/fulltext/fulltext.go @@ -173,7 +173,7 @@ func writeHashedValue(ctx context.Context, h hash.Hash, val interface{}) (valIsN return false, err } case sql.JSONWrapper: - str, err := types.StringifyJSON(val) + str, err := types.JsonToMySqlString(val) if err != nil { return false, err } diff --git a/sql/types/json.go b/sql/types/json.go index 711374bf86..e4f1757145 100644 --- a/sql/types/json.go +++ b/sql/types/json.go @@ -150,7 +150,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va } js := jsVal.(sql.JSONWrapper) - str, err := StringifyJSON(js) + str, err := JsonToMySqlString(js) if err != nil { return sqltypes.NULL, err } diff --git a/sql/types/json_encode.go b/sql/types/json_encode.go index 293150f6d8..727365b38b 100644 --- a/sql/types/json_encode.go +++ b/sql/types/json_encode.go @@ -142,6 +142,18 @@ func marshalToMySqlString(val interface{}) (string, error) { return b.String(), nil } +// marshalToMySqlBytes is a helper function to marshal a JSONDocument to a byte slice that is +// compatible with MySQL's JSON output, including spaces. +func marshalToMySqlBytes(val interface{}) ([]byte, error) { + b := NewNoCopyBuilder(1024) + err := writeMarshalledValue(b, val) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + func sortKeys[T any](m map[string]T) []string { var keys []string for k := range m { diff --git a/sql/types/json_value.go b/sql/types/json_value.go index 34681bde29..f0c7ea4229 100644 --- a/sql/types/json_value.go +++ b/sql/types/json_value.go @@ -34,16 +34,8 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) -// JSONStringer can be converted to a string representation that is compatible with MySQL's JSON output, including spaces. -type JSONStringer interface { - JSONString() (string, error) -} - -// StringifyJSON generates a string representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces. -func StringifyJSON(jsonWrapper sql.JSONWrapper) (string, error) { - if stringer, ok := jsonWrapper.(JSONStringer); ok { - return stringer.JSONString() - } +// JsonToMySqlString generates a string representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces. +func JsonToMySqlString(jsonWrapper sql.JSONWrapper) (string, error) { val, err := jsonWrapper.ToInterface() if err != nil { return "", err @@ -51,6 +43,15 @@ func StringifyJSON(jsonWrapper sql.JSONWrapper) (string, error) { return marshalToMySqlString(val) } +// JsonToMySqlString generates a byte slice representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces. +func JsonToMySqlBytes(jsonWrapper sql.JSONWrapper) ([]byte, error) { + val, err := jsonWrapper.ToInterface() + if err != nil { + return nil, err + } + return marshalToMySqlBytes(val) +} + // JSONBytes are values which can be represented as JSON. type JSONBytes interface { sql.JSONWrapper @@ -202,12 +203,12 @@ func (j *LazyJSONDocument) GetBytes() ([]byte, error) { // Value implements driver.Valuer for interoperability with other go libraries func (j *LazyJSONDocument) Value() (driver.Value, error) { - return StringifyJSON(j) + return JsonToMySqlString(j) } // LazyJSONDocument implements the fmt.Stringer interface. func (j *LazyJSONDocument) String() string { - s, err := StringifyJSON(j) + s, err := JsonToMySqlString(j) if err != nil { return fmt.Sprintf("error while stringifying JSON: %s", err.Error()) } diff --git a/sql/types/strings.go b/sql/types/strings.go index d1861c5e1d..48365cb89f 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -354,6 +354,9 @@ func ConvertToString(ctx context.Context, v interface{}, t sql.StringType, dest func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest []byte) ([]byte, error) { var val []byte start := len(dest) + // Based on the type of the input, convert it into a byte array, writing it into |dest| to avoid an allocation. + // If the current implementation must make a separate allocation anyway, avoid copying it into dest by replacing + // |val| entirely (and setting |start| to 0). switch s := v.(type) { case bool: if s { @@ -394,7 +397,10 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [ case string: val = append(dest, s...) case []byte: - val = append(dest, s...) + // We can avoid copying the slice if this isn't a conversion to BINARY + // We'll check for that below, immediately before extending the slice. + val = s + start = 0 case time.Time: val = s.AppendFormat(dest, sql.TimestampDatetimeLayout) case decimal.Decimal: @@ -405,24 +411,29 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [ } val = append(dest, s.Decimal.String()...) case sql.JSONWrapper: - jsonString, err := StringifyJSON(s) + var err error + val, err = JsonToMySqlBytes(s) if err != nil { return nil, err } - st, err := strings.Unquote(jsonString) + val, err = strings.UnquoteBytes(val) if err != nil { return nil, err } - val = append(dest, st...) - case sql.AnyWrapper: - unwrapped, err := s.UnwrapAny(ctx) + start = 0 + case sql.Wrapper[string]: + unwrapped, err := s.Unwrap(ctx) if err != nil { return nil, err } - val, err = ConvertToBytes(ctx, unwrapped, t, dest) + val = append(val, unwrapped...) + case sql.Wrapper[[]byte]: + var err error + val, err = s.Unwrap(ctx) if err != nil { return nil, err } + start = 0 case GeometryValue: return s.Serialize(), nil default: @@ -455,6 +466,11 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [ } if st.baseType == sqltypes.Binary { + if b, ok := v.([]byte); ok { + // Make a copy now to avoid overwriting the original allocation. + val = append(dest, b...) + start = len(dest) + } val = append(val, bytes.Repeat([]byte{0}, int(st.maxCharLength)-len(val))...) } }