@@ -206,39 +206,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
206206 return s .queryOrExec (ctx , stmt , parsed , query , args )
207207}
208208
209+ func (s * ServerQueryEngine ) query (ctx * sql.Context , stmt * gosql.Stmt , query string , args []any ) (sql.Schema , sql.RowIter , * sql.QueryFlags , error ) {
210+ var rows * gosql.Rows
211+ var err error
212+ if stmt != nil {
213+ rows , err = stmt .Query (args ... )
214+ } else {
215+ rows , err = s .conn .Query (query , args ... )
216+ }
217+ if err != nil {
218+ return nil , nil , nil , trimMySQLErrCodePrefix (err )
219+ }
220+ return convertRowsResult (ctx , rows )
221+ }
222+
223+ func (s * ServerQueryEngine ) exec (ctx * sql.Context , stmt * gosql.Stmt , query string , args []any ) (sql.Schema , sql.RowIter , * sql.QueryFlags , error ) {
224+ var res gosql.Result
225+ var err error
226+ if stmt != nil {
227+ res , err = stmt .Exec (args ... )
228+ } else {
229+ res , err = s .conn .Exec (query , args ... )
230+ }
231+ if err != nil {
232+ return nil , nil , nil , trimMySQLErrCodePrefix (err )
233+ }
234+ return convertExecResult (res )
235+ }
236+
209237// queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
210238// If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot
211239// be run as prepared.
212240// TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
213241//
214- // for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result .
242+ // for now, we use `query` to get the row results for these statements. For statements that needs `exec`, the result is OkResult .
215243func (s * ServerQueryEngine ) queryOrExec (ctx * sql.Context , stmt * gosql.Stmt , parsed sqlparser.Statement , query string , args []any ) (sql.Schema , sql.RowIter , * sql.QueryFlags , error ) {
216- var err error
217- switch parsed .(type ) {
218244 // TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
219- case * sqlparser.Select , * sqlparser.SetOp , * sqlparser.Show , * sqlparser.Set , * sqlparser.Call , * sqlparser.Begin , * sqlparser.Use , * sqlparser.Load , * sqlparser.Execute , * sqlparser.Analyze , * sqlparser.Flush , * sqlparser.Explain :
220- var rows * gosql.Rows
221- if stmt != nil {
222- rows , err = stmt .Query (args ... )
223- } else {
224- rows , err = s .conn .Query (query , args ... )
225- }
226- if err != nil {
227- return nil , nil , nil , trimMySQLErrCodePrefix (err )
228- }
229- return convertRowsResult (ctx , rows )
245+ var shouldQuery bool
246+ switch p := parsed .(type ) {
247+ // Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec
248+ case * sqlparser.Insert :
249+ if p .Returning != nil {
250+ shouldQuery = true
251+ }
252+ case * sqlparser.Select , * sqlparser.SetOp , * sqlparser.Show ,
253+ * sqlparser.Set , * sqlparser.Call , * sqlparser.Begin ,
254+ * sqlparser.Use , * sqlparser.Load , * sqlparser.Execute ,
255+ * sqlparser.Analyze , * sqlparser.Flush , * sqlparser.Explain :
256+ shouldQuery = true
230257 default :
231- var res gosql.Result
232- if stmt != nil {
233- res , err = stmt .Exec (args ... )
234- } else {
235- res , err = s .conn .Exec (query , args ... )
236- }
237- if err != nil {
238- return nil , nil , nil , trimMySQLErrCodePrefix (err )
239- }
240- return convertExecResult (res )
241258 }
259+
260+ if shouldQuery {
261+ return s .query (ctx , stmt , query , args )
262+ }
263+ return s .exec (ctx , stmt , query , args )
242264}
243265
244266// trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server.
0 commit comments