Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions internal/strings/unquote.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,64 @@ func Unquote(s string) (string, error) {
return str, nil
}

// UnquoteBytes is the same as Unquote, except it modifies a byte slice in-place
// Be careful: by reusing the slice, this destroys the original input value.
func UnquoteBytes(b []byte) ([]byte, error) {
outIdx := 0
for i := 0; i < len(b); i++ {
if b[i] == '\\' {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

huh, didn't realize \\ was a single char

Copy link
Contributor Author

@nicktobey nicktobey Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You got to escape the backslash character. '\' would be escaping the closing quote.

i++
if i == len(b) {
b[outIdx] = '\\'
}
switch b[i] {
case '"':
b[outIdx] = '"'
case 'b':
b[outIdx] = '\b'
case 'f':
b[outIdx] = '\f'
case 'n':
b[outIdx] = '\n'
case 'r':
b[outIdx] = '\r'
case 't':
b[outIdx] = '\t'
case '\\':
b[outIdx] = '\\'
case 'u':
if i+4 > len(b) {
return nil, fmt.Errorf("Invalid unicode: %s", b[i+1:])
}
char, size, err := decodeEscapedUnicode(b[i+1 : i+5])
if err != nil {
return nil, err
}
for j, c := range char[:size] {
b[outIdx+j] = c
}
i += 4
default:
// For all other escape sequences, backslash is ignored.
b[outIdx] = b[i]
}
} else {
b[outIdx] = b[i]
}
outIdx++
}
b = b[:outIdx]

// Remove prefix and suffix '"'.
if outIdx > 1 {
head, tail := b[0], b[outIdx-1]
if head == '"' && tail == '"' {
return b[1 : outIdx-1], nil
}
}
return b, nil
}

// decodeEscapedUnicode decodes unicode into utf8 bytes specified in RFC 3629.
// According to RFC 3629, the max length of utf8 characters is 4 bytes.
// And MySQL uses 4 bytes to represent the unicode which must be in [0, 65536).
Expand Down
2 changes: 1 addition & 1 deletion sql/fulltext/fulltext.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func writeHashedValue(ctx context.Context, h hash.Hash, val interface{}) (valIsN
return false, err
}
case sql.JSONWrapper:
str, err := types.StringifyJSON(val)
str, err := types.JsonToMySqlString(val)
if err != nil {
return false, err
}
Expand Down
2 changes: 1 addition & 1 deletion sql/types/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
}
js := jsVal.(sql.JSONWrapper)

str, err := StringifyJSON(js)
str, err := JsonToMySqlString(js)
if err != nil {
return sqltypes.NULL, err
}
Expand Down
12 changes: 12 additions & 0 deletions sql/types/json_encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,18 @@ func marshalToMySqlString(val interface{}) (string, error) {
return b.String(), nil
}

// marshalToMySqlBytes writes val through writeMarshalledValue into a no-copy
// builder and returns the builder's bytes. The output is meant to be
// compatible with MySQL's JSON output, including spaces.
// NOTE(review): the 1024 passed to NewNoCopyBuilder is presumably an initial
// buffer size — confirm against the builder's definition.
func marshalToMySqlBytes(val interface{}) ([]byte, error) {
	builder := NewNoCopyBuilder(1024)
	if err := writeMarshalledValue(builder, val); err != nil {
		return nil, err
	}
	return builder.Bytes(), nil
}

func sortKeys[T any](m map[string]T) []string {
var keys []string
for k := range m {
Expand Down
25 changes: 13 additions & 12 deletions sql/types/json_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,24 @@ import (
"github.com/dolthub/go-mysql-server/sql"
)

// JSONStringer can be converted to a string representation that is compatible with MySQL's JSON output, including spaces.
type JSONStringer interface {
JSONString() (string, error)
}

// StringifyJSON generates a string representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces.
func StringifyJSON(jsonWrapper sql.JSONWrapper) (string, error) {
if stringer, ok := jsonWrapper.(JSONStringer); ok {
return stringer.JSONString()
}
// JsonToMySqlString converts a sql.JSONWrapper to its interface form and
// marshals that into a string compatible with MySQL's JSON output, including
// spaces. Any error from the conversion or the marshalling is returned as-is.
func JsonToMySqlString(jsonWrapper sql.JSONWrapper) (string, error) {
	decoded, err := jsonWrapper.ToInterface()
	if err != nil {
		return "", err
	}
	return marshalToMySqlString(decoded)
}

// JsonToMySqlBytes generates a byte slice representation of a sql.JSONWrapper
// that is compatible with MySQL's JSON output, including spaces. It is the
// []byte counterpart of JsonToMySqlString: the wrapper is converted with
// ToInterface and the result is marshalled by marshalToMySqlBytes. Any error
// from either step is returned unchanged, with a nil slice.
func JsonToMySqlBytes(jsonWrapper sql.JSONWrapper) ([]byte, error) {
	val, err := jsonWrapper.ToInterface()
	if err != nil {
		return nil, err
	}
	return marshalToMySqlBytes(val)
}

// JSONBytes are values which can be represented as JSON.
type JSONBytes interface {
sql.JSONWrapper
Expand Down Expand Up @@ -202,12 +203,12 @@ func (j *LazyJSONDocument) GetBytes() ([]byte, error) {

// Value implements driver.Valuer for interoperability with other go libraries
func (j *LazyJSONDocument) Value() (driver.Value, error) {
return StringifyJSON(j)
return JsonToMySqlString(j)
}

// LazyJSONDocument implements the fmt.Stringer interface.
func (j *LazyJSONDocument) String() string {
s, err := StringifyJSON(j)
s, err := JsonToMySqlString(j)
if err != nil {
return fmt.Sprintf("error while stringifying JSON: %s", err.Error())
}
Expand Down
30 changes: 23 additions & 7 deletions sql/types/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ func ConvertToString(ctx context.Context, v interface{}, t sql.StringType, dest
func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest []byte) ([]byte, error) {
var val []byte
start := len(dest)
// Based on the type of the input, convert it into a byte array, writing it into |dest| to avoid an allocation.
// If the current implementation must make a separate allocation anyway, avoid copying it into dest by replacing
// |val| entirely (and setting |start| to 0).
switch s := v.(type) {
case bool:
if s {
Expand Down Expand Up @@ -394,7 +397,10 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
case string:
val = append(dest, s...)
case []byte:
val = append(dest, s...)
// We can avoid copying the slice if this isn't a conversion to BINARY
// We'll check for that below, immediately before extending the slice.
val = s
start = 0
case time.Time:
val = s.AppendFormat(dest, sql.TimestampDatetimeLayout)
case decimal.Decimal:
Expand All @@ -405,24 +411,29 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
}
val = append(dest, s.Decimal.String()...)
case sql.JSONWrapper:
jsonString, err := StringifyJSON(s)
var err error
val, err = JsonToMySqlBytes(s)
if err != nil {
return nil, err
}
st, err := strings.Unquote(jsonString)
val, err = strings.UnquoteBytes(val)
if err != nil {
return nil, err
}
val = append(dest, st...)
case sql.AnyWrapper:
unwrapped, err := s.UnwrapAny(ctx)
start = 0
case sql.Wrapper[string]:
unwrapped, err := s.Unwrap(ctx)
if err != nil {
return nil, err
}
val, err = ConvertToBytes(ctx, unwrapped, t, dest)
val = append(val, unwrapped...)
case sql.Wrapper[[]byte]:
var err error
val, err = s.Unwrap(ctx)
if err != nil {
return nil, err
}
start = 0
case GeometryValue:
return s.Serialize(), nil
default:
Expand Down Expand Up @@ -455,6 +466,11 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
}

if st.baseType == sqltypes.Binary {
if b, ok := v.([]byte); ok {
// Make a copy now to avoid overwriting the original allocation.
val = append(dest, b...)
start = len(dest)
}
val = append(val, bytes.Repeat([]byte{0}, int(st.maxCharLength)-len(val))...)
}
}
Expand Down
Loading