@@ -206,39 +206,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
206
206
return s .queryOrExec (ctx , stmt , parsed , query , args )
207
207
}
208
208
209
+ func (s * ServerQueryEngine ) query (ctx * sql.Context , stmt * gosql.Stmt , query string , args []any ) (sql.Schema , sql.RowIter , * sql.QueryFlags , error ) {
210
+ var rows * gosql.Rows
211
+ var err error
212
+ if stmt != nil {
213
+ rows , err = stmt .Query (args ... )
214
+ } else {
215
+ rows , err = s .conn .Query (query , args ... )
216
+ }
217
+ if err != nil {
218
+ return nil , nil , nil , trimMySQLErrCodePrefix (err )
219
+ }
220
+ return convertRowsResult (ctx , rows )
221
+ }
222
+
223
+ func (s * ServerQueryEngine ) exec (ctx * sql.Context , stmt * gosql.Stmt , query string , args []any ) (sql.Schema , sql.RowIter , * sql.QueryFlags , error ) {
224
+ var res gosql.Result
225
+ var err error
226
+ if stmt != nil {
227
+ res , err = stmt .Exec (args ... )
228
+ } else {
229
+ res , err = s .conn .Exec (query , args ... )
230
+ }
231
+ if err != nil {
232
+ return nil , nil , nil , trimMySQLErrCodePrefix (err )
233
+ }
234
+ return convertExecResult (res )
235
+ }
236
+
209
237
// queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
210
238
// If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot
211
239
// be run as prepared.
212
240
// TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
213
241
//
214
- // for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result .
242
+ // for now, we use `query` to get the row results for these statements. For statements that needs `exec`, the result is OkResult .
215
243
func (s * ServerQueryEngine ) queryOrExec (ctx * sql.Context , stmt * gosql.Stmt , parsed sqlparser.Statement , query string , args []any ) (sql.Schema , sql.RowIter , * sql.QueryFlags , error ) {
216
- var err error
217
- switch parsed .(type ) {
218
244
// TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
219
- case * sqlparser.Select , * sqlparser.SetOp , * sqlparser.Show , * sqlparser.Set , * sqlparser.Call , * sqlparser.Begin , * sqlparser.Use , * sqlparser.Load , * sqlparser.Execute , * sqlparser.Analyze , * sqlparser.Flush , * sqlparser.Explain :
220
- var rows * gosql.Rows
221
- if stmt != nil {
222
- rows , err = stmt .Query (args ... )
223
- } else {
224
- rows , err = s .conn .Query (query , args ... )
225
- }
226
- if err != nil {
227
- return nil , nil , nil , trimMySQLErrCodePrefix (err )
228
- }
229
- return convertRowsResult (ctx , rows )
245
+ var shouldQuery bool
246
+ switch p := parsed .(type ) {
247
+ // Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec
248
+ case * sqlparser.Insert :
249
+ if p .Returning != nil {
250
+ shouldQuery = true
251
+ }
252
+ case * sqlparser.Select , * sqlparser.SetOp , * sqlparser.Show ,
253
+ * sqlparser.Set , * sqlparser.Call , * sqlparser.Begin ,
254
+ * sqlparser.Use , * sqlparser.Load , * sqlparser.Execute ,
255
+ * sqlparser.Analyze , * sqlparser.Flush , * sqlparser.Explain :
256
+ shouldQuery = true
230
257
default :
231
- var res gosql.Result
232
- if stmt != nil {
233
- res , err = stmt .Exec (args ... )
234
- } else {
235
- res , err = s .conn .Exec (query , args ... )
236
- }
237
- if err != nil {
238
- return nil , nil , nil , trimMySQLErrCodePrefix (err )
239
- }
240
- return convertExecResult (res )
241
258
}
259
+
260
+ if shouldQuery {
261
+ return s .query (ctx , stmt , query , args )
262
+ }
263
+ return s .exec (ctx , stmt , query , args )
242
264
}
243
265
244
266
// trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server.
0 commit comments