@@ -166,73 +166,107 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
166
166
}
167
167
168
168
// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/libmysql/libmysql.c#L1150-L1156
169
- func (mc * mysqlConn ) escapeBytes (v []byte ) string {
170
- var escape func ([]byte ) []byte
169
+ func (mc * mysqlConn ) escapeBytes (buf , v []byte ) [] byte {
170
+ var escape func ([]byte , [] byte ) []byte
171
171
if mc .status & statusNoBackslashEscapes == 0 {
172
- escape = EscapeString
172
+ escape = escapeString
173
173
} else {
174
- escape = EscapeQuotes
174
+ escape = escapeQuotes
175
175
}
176
- return "'" + string (escape (v )) + "'"
176
+ buf = append (buf , '\'' )
177
+ buf = escape (buf , v )
178
+ buf = append (buf , '\'' )
179
+ return buf
180
+ }
181
+
182
+ func estimateParamLength (args []driver.Value ) (int , bool ) {
183
+ l := 0
184
+ for _ , a := range args {
185
+ switch v := a .(type ) {
186
+ case int64 , float64 :
187
+ l += 20
188
+ case bool :
189
+ l += 5
190
+ case time.Time :
191
+ l += 30
192
+ case string :
193
+ l += len (v )* 2 + 2
194
+ case []byte :
195
+ l += len (v )* 2 + 2
196
+ default :
197
+ return 0 , false
198
+ }
199
+ }
200
+ return l , true
177
201
}
178
202
179
203
func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
180
- chunks := strings . Split ( query , "?" )
181
- if len ( chunks ) != len ( args ) + 1 {
204
+ estimated , ok := estimateParamLength ( args )
205
+ if ! ok {
182
206
return "" , driver .ErrSkip
183
207
}
208
+ estimated += len (query )
184
209
185
- parts := make ([]string , len (chunks )+ len (args ))
186
- parts [0 ] = chunks [0 ]
210
+ buf := make ([]byte , 0 , estimated )
211
+ argPos := 0
212
+
213
+ // Go 1.5 will optimize range([]byte(string)) to skip allocation.
214
+ for _ , c := range []byte (query ) {
215
+ if c != '?' {
216
+ buf = append (buf , c )
217
+ continue
218
+ }
219
+
220
+ arg := args [argPos ]
221
+ argPos ++
187
222
188
- for i , arg := range args {
189
- pos := i * 2 + 1
190
- parts [pos + 1 ] = chunks [i + 1 ]
191
223
if arg == nil {
192
- parts [ pos ] = "NULL"
224
+ buf = append ( buf , [] byte ( "NULL" ) ... )
193
225
continue
194
226
}
227
+
195
228
switch v := arg .(type ) {
196
229
case int64 :
197
- parts [ pos ] = strconv .FormatInt ( v , 10 )
230
+ buf = strconv .AppendInt ( buf , v , 10 )
198
231
case float64 :
199
- parts [ pos ] = strconv .FormatFloat ( v , 'g' , - 1 , 64 )
232
+ buf = strconv .AppendFloat ( buf , v , 'g' , - 1 , 64 )
200
233
case bool :
201
234
if v {
202
- parts [ pos ] = "1"
235
+ buf = append ( buf , '1' )
203
236
} else {
204
- parts [ pos ] = "0"
237
+ buf = append ( buf , '0' )
205
238
}
206
239
case time.Time :
207
240
if v .IsZero () {
208
- parts [ pos ] = "'0000-00-00'"
241
+ buf = append ( buf , [] byte ( "'0000-00-00'" ) ... )
209
242
} else {
210
243
fmt := "'2006-01-02 15:04:05.999999'"
211
244
if v .Nanosecond () == 0 {
212
245
fmt = "'2006-01-02 15:04:05'"
213
246
}
214
- parts [pos ] = v .In (mc .cfg .loc ).Format (fmt )
247
+ s := v .In (mc .cfg .loc ).Format (fmt )
248
+ buf = append (buf , []byte (s )... )
215
249
}
216
250
case []byte :
217
251
if v == nil {
218
- parts [ pos ] = "NULL"
252
+ buf = append ( buf , [] byte ( "NULL" ) ... )
219
253
} else {
220
- parts [ pos ] = mc .escapeBytes (v )
254
+ buf = mc .escapeBytes (buf , v )
221
255
}
222
256
case string :
223
- parts [ pos ] = mc .escapeBytes ([]byte (v ))
257
+ buf = mc .escapeBytes (buf , []byte (v ))
224
258
default :
225
259
return "" , driver .ErrSkip
226
260
}
261
+
262
+ if len (buf )+ 4 > mc .maxPacketAllowed {
263
+ return "" , driver .ErrSkip
264
+ }
227
265
}
228
- pktSize := len (query ) + 4 // 4 bytes for header.
229
- for _ , p := range parts {
230
- pktSize += len (p )
231
- }
232
- if pktSize > mc .maxPacketAllowed {
266
+ if argPos != len (args ) {
233
267
return "" , driver .ErrSkip
234
268
}
235
- return strings . Join ( parts , "" ), nil
269
+ return string ( buf ), nil
236
270
}
237
271
238
272
func (mc * mysqlConn ) Exec (query string , args []driver.Value ) (driver.Result , error ) {
0 commit comments