Skip to content

Commit 7c05ba9

Browse files
authored
Merge pull request #2932 from dolthub/nicktobey/wrapper-bytes1
Optimize ConvertToBytes by avoiding unnecessary string <-> bytes conversions and copies.
2 parents 5632d67 + 172d577 commit 7c05ba9

File tree

6 files changed

+108
-21
lines changed

6 files changed

+108
-21
lines changed

internal/strings/unquote.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,64 @@ func Unquote(s string) (string, error) {
6666
return str, nil
6767
}
6868

69+
// UnquoteBytes is the same as Unquote, except it modifies a byte slice in-place
70+
// Be careful: by reusing the slice, this destroys the original input value.
71+
func UnquoteBytes(b []byte) ([]byte, error) {
72+
outIdx := 0
73+
for i := 0; i < len(b); i++ {
74+
if b[i] == '\\' {
75+
i++
76+
if i == len(b) {
77+
b[outIdx] = '\\'
78+
}
79+
switch b[i] {
80+
case '"':
81+
b[outIdx] = '"'
82+
case 'b':
83+
b[outIdx] = '\b'
84+
case 'f':
85+
b[outIdx] = '\f'
86+
case 'n':
87+
b[outIdx] = '\n'
88+
case 'r':
89+
b[outIdx] = '\r'
90+
case 't':
91+
b[outIdx] = '\t'
92+
case '\\':
93+
b[outIdx] = '\\'
94+
case 'u':
95+
if i+4 > len(b) {
96+
return nil, fmt.Errorf("Invalid unicode: %s", b[i+1:])
97+
}
98+
char, size, err := decodeEscapedUnicode(b[i+1 : i+5])
99+
if err != nil {
100+
return nil, err
101+
}
102+
for j, c := range char[:size] {
103+
b[outIdx+j] = c
104+
}
105+
i += 4
106+
default:
107+
// For all other escape sequences, backslash is ignored.
108+
b[outIdx] = b[i]
109+
}
110+
} else {
111+
b[outIdx] = b[i]
112+
}
113+
outIdx++
114+
}
115+
b = b[:outIdx]
116+
117+
// Remove prefix and suffix '"'.
118+
if outIdx > 1 {
119+
head, tail := b[0], b[outIdx-1]
120+
if head == '"' && tail == '"' {
121+
return b[1 : outIdx-1], nil
122+
}
123+
}
124+
return b, nil
125+
}
126+
69127
// decodeEscapedUnicode decodes unicode into utf8 bytes specified in RFC 3629.
70128
// According RFC 3629, the max length of utf8 characters is 4 bytes.
71129
// And MySQL use 4 bytes to represent the unicode which must be in [0, 65536).

sql/fulltext/fulltext.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ func writeHashedValue(ctx context.Context, h hash.Hash, val interface{}) (valIsN
173173
return false, err
174174
}
175175
case sql.JSONWrapper:
176-
str, err := types.StringifyJSON(val)
176+
str, err := types.JsonToMySqlString(val)
177177
if err != nil {
178178
return false, err
179179
}

sql/types/json.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func (t JsonType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
150150
}
151151
js := jsVal.(sql.JSONWrapper)
152152

153-
str, err := StringifyJSON(js)
153+
str, err := JsonToMySqlString(js)
154154
if err != nil {
155155
return sqltypes.NULL, err
156156
}

sql/types/json_encode.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,18 @@ func marshalToMySqlString(val interface{}) (string, error) {
142142
return b.String(), nil
143143
}
144144

145+
// marshalToMySqlBytes is a helper function to marshal a JSONDocument to a byte slice that is
146+
// compatible with MySQL's JSON output, including spaces.
147+
func marshalToMySqlBytes(val interface{}) ([]byte, error) {
148+
b := NewNoCopyBuilder(1024)
149+
err := writeMarshalledValue(b, val)
150+
if err != nil {
151+
return nil, err
152+
}
153+
154+
return b.Bytes(), nil
155+
}
156+
145157
func sortKeys[T any](m map[string]T) []string {
146158
var keys []string
147159
for k := range m {

sql/types/json_value.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,24 @@ import (
3434
"github.com/dolthub/go-mysql-server/sql"
3535
)
3636

37-
// JSONStringer can be converted to a string representation that is compatible with MySQL's JSON output, including spaces.
38-
type JSONStringer interface {
39-
JSONString() (string, error)
40-
}
41-
42-
// StringifyJSON generates a string representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces.
43-
func StringifyJSON(jsonWrapper sql.JSONWrapper) (string, error) {
44-
if stringer, ok := jsonWrapper.(JSONStringer); ok {
45-
return stringer.JSONString()
46-
}
37+
// JsonToMySqlString generates a string representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces.
38+
func JsonToMySqlString(jsonWrapper sql.JSONWrapper) (string, error) {
4739
val, err := jsonWrapper.ToInterface()
4840
if err != nil {
4941
return "", err
5042
}
5143
return marshalToMySqlString(val)
5244
}
5345

46+
// JsonToMySqlString generates a byte slice representation of a sql.JSONWrapper that is compatible with MySQL's JSON output, including spaces.
47+
func JsonToMySqlBytes(jsonWrapper sql.JSONWrapper) ([]byte, error) {
48+
val, err := jsonWrapper.ToInterface()
49+
if err != nil {
50+
return nil, err
51+
}
52+
return marshalToMySqlBytes(val)
53+
}
54+
5455
// JSONBytes are values which can be represented as JSON.
5556
type JSONBytes interface {
5657
sql.JSONWrapper
@@ -202,12 +203,12 @@ func (j *LazyJSONDocument) GetBytes() ([]byte, error) {
202203

203204
// Value implements driver.Valuer for interoperability with other go libraries
204205
func (j *LazyJSONDocument) Value() (driver.Value, error) {
205-
return StringifyJSON(j)
206+
return JsonToMySqlString(j)
206207
}
207208

208209
// LazyJSONDocument implements the fmt.Stringer interface.
209210
func (j *LazyJSONDocument) String() string {
210-
s, err := StringifyJSON(j)
211+
s, err := JsonToMySqlString(j)
211212
if err != nil {
212213
return fmt.Sprintf("error while stringifying JSON: %s", err.Error())
213214
}

sql/types/strings.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,9 @@ func ConvertToString(ctx context.Context, v interface{}, t sql.StringType, dest
354354
func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest []byte) ([]byte, error) {
355355
var val []byte
356356
start := len(dest)
357+
// Based on the type of the input, convert it into a byte array, writing it into |dest| to avoid an allocation.
358+
// If the current implementation must make a separate allocation anyway, avoid copying it into dest by replacing
359+
// |val| entirely (and setting |start| to 0).
357360
switch s := v.(type) {
358361
case bool:
359362
if s {
@@ -394,7 +397,10 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
394397
case string:
395398
val = append(dest, s...)
396399
case []byte:
397-
val = append(dest, s...)
400+
// We can avoid copying the slice if this isn't a conversion to BINARY
401+
// We'll check for that below, immediately before extending the slice.
402+
val = s
403+
start = 0
398404
case time.Time:
399405
val = s.AppendFormat(dest, sql.TimestampDatetimeLayout)
400406
case decimal.Decimal:
@@ -405,24 +411,29 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
405411
}
406412
val = append(dest, s.Decimal.String()...)
407413
case sql.JSONWrapper:
408-
jsonString, err := StringifyJSON(s)
414+
var err error
415+
val, err = JsonToMySqlBytes(s)
409416
if err != nil {
410417
return nil, err
411418
}
412-
st, err := strings.Unquote(jsonString)
419+
val, err = strings.UnquoteBytes(val)
413420
if err != nil {
414421
return nil, err
415422
}
416-
val = append(dest, st...)
417-
case sql.AnyWrapper:
418-
unwrapped, err := s.UnwrapAny(ctx)
423+
start = 0
424+
case sql.Wrapper[string]:
425+
unwrapped, err := s.Unwrap(ctx)
419426
if err != nil {
420427
return nil, err
421428
}
422-
val, err = ConvertToBytes(ctx, unwrapped, t, dest)
429+
val = append(val, unwrapped...)
430+
case sql.Wrapper[[]byte]:
431+
var err error
432+
val, err = s.Unwrap(ctx)
423433
if err != nil {
424434
return nil, err
425435
}
436+
start = 0
426437
case GeometryValue:
427438
return s.Serialize(), nil
428439
default:
@@ -455,6 +466,11 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
455466
}
456467

457468
if st.baseType == sqltypes.Binary {
469+
if b, ok := v.([]byte); ok {
470+
// Make a copy now to avoid overwriting the original allocation.
471+
val = append(dest, b...)
472+
start = len(dest)
473+
}
458474
val = append(val, bytes.Repeat([]byte{0}, int(st.maxCharLength)-len(val))...)
459475
}
460476
}

0 commit comments

Comments
 (0)