Skip to content

Commit 0297315

Browse files
committed
Reduce allocs in interpolateParams.
benchmark old ns/op new ns/op delta BenchmarkInterpolation 4065 2533 -37.69% benchmark old allocs new allocs delta BenchmarkInterpolation 15 6 -60.00% benchmark old bytes new bytes delta BenchmarkInterpolation 1144 560 -51.05%
1 parent 468b9e5 commit 0297315

File tree

3 files changed

+84
-36
lines changed

3 files changed

+84
-36
lines changed

connection.go

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -166,73 +166,107 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
166166
}
167167

168168
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
169-
func (mc *mysqlConn) escapeBytes(v []byte) string {
170-
var escape func([]byte) []byte
169+
func (mc *mysqlConn) escapeBytes(buf, v []byte) []byte {
170+
var escape func([]byte, []byte) []byte
171171
if mc.status&statusNoBackslashEscapes == 0 {
172-
escape = EscapeString
172+
escape = escapeString
173173
} else {
174-
escape = EscapeQuotes
174+
escape = escapeQuotes
175175
}
176-
return "'" + string(escape(v)) + "'"
176+
buf = append(buf, '\'')
177+
buf = escape(buf, v)
178+
buf = append(buf, '\'')
179+
return buf
180+
}
181+
182+
func estimateParamLength(args []driver.Value) (int, bool) {
183+
l := 0
184+
for _, a := range args {
185+
switch v := a.(type) {
186+
case int64, float64:
187+
l += 20
188+
case bool:
189+
l += 5
190+
case time.Time:
191+
l += 30
192+
case string:
193+
l += len(v)*2 + 2
194+
case []byte:
195+
l += len(v)*2 + 2
196+
default:
197+
return 0, false
198+
}
199+
}
200+
return l, true
177201
}
178202

179203
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
180-
chunks := strings.Split(query, "?")
181-
if len(chunks) != len(args)+1 {
204+
estimated, ok := estimateParamLength(args)
205+
if !ok {
182206
return "", driver.ErrSkip
183207
}
208+
estimated += len(query)
184209

185-
parts := make([]string, len(chunks)+len(args))
186-
parts[0] = chunks[0]
210+
buf := make([]byte, 0, estimated)
211+
argPos := 0
212+
213+
// Go 1.5 will optimize range([]byte(string)) to skip allocation.
214+
for _, c := range []byte(query) {
215+
if c != '?' {
216+
buf = append(buf, c)
217+
continue
218+
}
219+
220+
arg := args[argPos]
221+
argPos++
187222

188-
for i, arg := range args {
189-
pos := i*2 + 1
190-
parts[pos+1] = chunks[i+1]
191223
if arg == nil {
192-
parts[pos] = "NULL"
224+
buf = append(buf, []byte("NULL")...)
193225
continue
194226
}
227+
195228
switch v := arg.(type) {
196229
case int64:
197-
parts[pos] = strconv.FormatInt(v, 10)
230+
buf = strconv.AppendInt(buf, v, 10)
198231
case float64:
199-
parts[pos] = strconv.FormatFloat(v, 'g', -1, 64)
232+
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
200233
case bool:
201234
if v {
202-
parts[pos] = "1"
235+
buf = append(buf, '1')
203236
} else {
204-
parts[pos] = "0"
237+
buf = append(buf, '0')
205238
}
206239
case time.Time:
207240
if v.IsZero() {
208-
parts[pos] = "'0000-00-00'"
241+
buf = append(buf, []byte("'0000-00-00'")...)
209242
} else {
210243
fmt := "'2006-01-02 15:04:05.999999'"
211244
if v.Nanosecond() == 0 {
212245
fmt = "'2006-01-02 15:04:05'"
213246
}
214-
parts[pos] = v.In(mc.cfg.loc).Format(fmt)
247+
s := v.In(mc.cfg.loc).Format(fmt)
248+
buf = append(buf, []byte(s)...)
215249
}
216250
case []byte:
217251
if v == nil {
218-
parts[pos] = "NULL"
252+
buf = append(buf, []byte("NULL")...)
219253
} else {
220-
parts[pos] = mc.escapeBytes(v)
254+
buf = mc.escapeBytes(buf, v)
221255
}
222256
case string:
223-
parts[pos] = mc.escapeBytes([]byte(v))
257+
buf = mc.escapeBytes(buf, []byte(v))
224258
default:
225259
return "", driver.ErrSkip
226260
}
261+
262+
if len(buf)+4 > mc.maxPacketAllowed {
263+
return "", driver.ErrSkip
264+
}
227265
}
228-
pktSize := len(query) + 4 // 4 bytes for header.
229-
for _, p := range parts {
230-
pktSize += len(p)
231-
}
232-
if pktSize > mc.maxPacketAllowed {
266+
if argPos != len(args) {
233267
return "", driver.ErrSkip
234268
}
235-
return strings.Join(parts, ""), nil
269+
return string(buf), nil
236270
}
237271

238272
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {

utils.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -812,9 +812,16 @@ func appendLengthEncodedInteger(b []byte, n uint64) []byte {
812812
// characters, and turning others into specific escape sequences, such as
813813
// turning newlines into \n and null bytes into \0.
814814
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932
815-
func EscapeString(v []byte) []byte {
816-
buf := make([]byte, len(v)*2)
817-
pos := 0
815+
func escapeString(buf, v []byte) []byte {
816+
pos := len(buf)
817+
end := pos + len(v)*2
818+
if cap(buf) < end {
819+
n := make([]byte, pos+end)
820+
copy(n, buf)
821+
buf = n
822+
}
823+
buf = buf[0:end]
824+
818825
for _, c := range v {
819826
switch c {
820827
case '\x00':
@@ -859,9 +866,16 @@ func EscapeString(v []byte) []byte {
859866
// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in
860867
// effect on the server.
861868
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038
862-
func EscapeQuotes(v []byte) []byte {
863-
buf := make([]byte, len(v)*2)
864-
pos := 0
869+
func escapeQuotes(buf, v []byte) []byte {
870+
pos := len(buf)
871+
end := pos + len(v)*2
872+
if cap(buf) < end {
873+
n := make([]byte, pos+end)
874+
copy(n, buf)
875+
buf = n
876+
}
877+
buf = buf[0:end]
878+
865879
for _, c := range v {
866880
if c == '\'' {
867881
buf[pos] = '\''

utils_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func TestFormatBinaryDateTime(t *testing.T) {
255255

256256
func TestEscapeString(t *testing.T) {
257257
expect := func(expected, value string) {
258-
actual := string(EscapeString([]byte(value)))
258+
actual := string(escapeString([]byte{}, []byte(value)))
259259
if actual != expected {
260260
t.Errorf(
261261
"expected %s, got %s",
@@ -275,7 +275,7 @@ func TestEscapeString(t *testing.T) {
275275

276276
func TestEscapeQuotes(t *testing.T) {
277277
expect := func(expected, value string) {
278-
actual := string(EscapeQuotes([]byte(value)))
278+
actual := string(escapeQuotes([]byte{}, []byte(value)))
279279
if actual != expected {
280280
t.Errorf(
281281
"expected %s, got %s",

0 commit comments

Comments
 (0)