Skip to content

Commit c9e1686

Browse files
authored
Merge pull request #3013 from dolthub/james/returning_server
handle insert returning for server context
2 parents 6a1307e + bab43af commit c9e1686

File tree

1 file changed

+46
-24
lines changed

1 file changed

+46
-24
lines changed

enginetest/server_engine.go

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -206,39 +206,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
206206
return s.queryOrExec(ctx, stmt, parsed, query, args)
207207
}
208208

209+
func (s *ServerQueryEngine) query(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
210+
var rows *gosql.Rows
211+
var err error
212+
if stmt != nil {
213+
rows, err = stmt.Query(args...)
214+
} else {
215+
rows, err = s.conn.Query(query, args...)
216+
}
217+
if err != nil {
218+
return nil, nil, nil, trimMySQLErrCodePrefix(err)
219+
}
220+
return convertRowsResult(ctx, rows)
221+
}
222+
223+
func (s *ServerQueryEngine) exec(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
224+
var res gosql.Result
225+
var err error
226+
if stmt != nil {
227+
res, err = stmt.Exec(args...)
228+
} else {
229+
res, err = s.conn.Exec(query, args...)
230+
}
231+
if err != nil {
232+
return nil, nil, nil, trimMySQLErrCodePrefix(err)
233+
}
234+
return convertExecResult(res)
235+
}
236+
209237
// queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
210238
// If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot
211239
// be run as prepared.
212240
// TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
213241
//
214-
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result.
242+
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, the result is OkResult.
215243
func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
216-
var err error
217-
switch parsed.(type) {
218244
// TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
219-
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
220-
var rows *gosql.Rows
221-
if stmt != nil {
222-
rows, err = stmt.Query(args...)
223-
} else {
224-
rows, err = s.conn.Query(query, args...)
225-
}
226-
if err != nil {
227-
return nil, nil, nil, trimMySQLErrCodePrefix(err)
228-
}
229-
return convertRowsResult(ctx, rows)
245+
var shouldQuery bool
246+
switch p := parsed.(type) {
247+
// Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec
248+
case *sqlparser.Insert:
249+
if p.Returning != nil {
250+
shouldQuery = true
251+
}
252+
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show,
253+
*sqlparser.Set, *sqlparser.Call, *sqlparser.Begin,
254+
*sqlparser.Use, *sqlparser.Load, *sqlparser.Execute,
255+
*sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
256+
shouldQuery = true
230257
default:
231-
var res gosql.Result
232-
if stmt != nil {
233-
res, err = stmt.Exec(args...)
234-
} else {
235-
res, err = s.conn.Exec(query, args...)
236-
}
237-
if err != nil {
238-
return nil, nil, nil, trimMySQLErrCodePrefix(err)
239-
}
240-
return convertExecResult(res)
241258
}
259+
260+
if shouldQuery {
261+
return s.query(ctx, stmt, query, args)
262+
}
263+
return s.exec(ctx, stmt, query, args)
242264
}
243265

244266
// trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server.

0 commit comments

Comments
 (0)