Skip to content

Commit 3e99e22

Browse files
authored
Merge pull request #3012 from dolthub/angela/returning
Handle `insert...returning...` queries
2 parents fd7fc07 + c9e1686 commit 3e99e22

File tree

4 files changed

+81
-26
lines changed

4 files changed

+81
-26
lines changed

enginetest/queries/insert_queries.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,6 +2276,36 @@ var InsertScripts = []ScriptTest{
22762276
},
22772277
},
22782278
},
2279+
{
2280+
Name: "insert...returning... statements",
2281+
Dialect: "mysql", // actually mariadb
2282+
SetUpScript: []string{
2283+
"CREATE TABLE animals (id int, name varchar(20))",
2284+
"CREATE TABLE auto_pk (`pk` int NOT NULL AUTO_INCREMENT, `name` varchar(20), PRIMARY KEY (`pk`))",
2285+
},
2286+
Assertions: []ScriptTestAssertion{
2287+
{
2288+
Query: "insert into animals (id) values (2) returning id",
2289+
Expected: []sql.Row{{2}},
2290+
},
2291+
{
2292+
Query: "insert into animals(id,name) values (1, 'Dog'),(2,'Lion'),(3,'Tiger'),(4,'Leopard') returning id, id+id",
2293+
Expected: []sql.Row{{1, 2}, {2, 4}, {3, 6}, {4, 8}},
2294+
},
2295+
{
2296+
Query: "insert into animals set id=1,name='Bear' returning id,name",
2297+
Expected: []sql.Row{{1, "Bear"}},
2298+
},
2299+
{
2300+
Query: "insert into auto_pk (name) values ('Cat') returning pk,name",
2301+
Expected: []sql.Row{{1, "Cat"}},
2302+
},
2303+
{
2304+
Query: "insert into auto_pk values (NULL, 'Dog'),(5, 'Fish'),(NULL, 'Horse') returning pk,name",
2305+
Expected: []sql.Row{{2, "Dog"}, {5, "Fish"}, {6, "Horse"}},
2306+
},
2307+
},
2308+
},
22792309
}
22802310

22812311
var InsertDuplicateKeyKeyless = []ScriptTest{

enginetest/server_engine.go

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -206,39 +206,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
206206
return s.queryOrExec(ctx, stmt, parsed, query, args)
207207
}
208208

209+
func (s *ServerQueryEngine) query(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
210+
var rows *gosql.Rows
211+
var err error
212+
if stmt != nil {
213+
rows, err = stmt.Query(args...)
214+
} else {
215+
rows, err = s.conn.Query(query, args...)
216+
}
217+
if err != nil {
218+
return nil, nil, nil, trimMySQLErrCodePrefix(err)
219+
}
220+
return convertRowsResult(ctx, rows)
221+
}
222+
223+
func (s *ServerQueryEngine) exec(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
224+
var res gosql.Result
225+
var err error
226+
if stmt != nil {
227+
res, err = stmt.Exec(args...)
228+
} else {
229+
res, err = s.conn.Exec(query, args...)
230+
}
231+
if err != nil {
232+
return nil, nil, nil, trimMySQLErrCodePrefix(err)
233+
}
234+
return convertExecResult(res)
235+
}
236+
209237
// queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
210238
// If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot
211239
// be run as prepared.
212240
// TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
213241
//
214-
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result.
242+
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, the result is OkResult.
215243
func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
216-
var err error
217-
switch parsed.(type) {
218244
// TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
219-
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
220-
var rows *gosql.Rows
221-
if stmt != nil {
222-
rows, err = stmt.Query(args...)
223-
} else {
224-
rows, err = s.conn.Query(query, args...)
225-
}
226-
if err != nil {
227-
return nil, nil, nil, trimMySQLErrCodePrefix(err)
228-
}
229-
return convertRowsResult(ctx, rows)
245+
var shouldQuery bool
246+
switch p := parsed.(type) {
247+
// Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec
248+
case *sqlparser.Insert:
249+
if p.Returning != nil {
250+
shouldQuery = true
251+
}
252+
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show,
253+
*sqlparser.Set, *sqlparser.Call, *sqlparser.Begin,
254+
*sqlparser.Use, *sqlparser.Load, *sqlparser.Execute,
255+
*sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
256+
shouldQuery = true
230257
default:
231-
var res gosql.Result
232-
if stmt != nil {
233-
res, err = stmt.Exec(args...)
234-
} else {
235-
res, err = s.conn.Exec(query, args...)
236-
}
237-
if err != nil {
238-
return nil, nil, nil, trimMySQLErrCodePrefix(err)
239-
}
240-
return convertExecResult(res)
241258
}
259+
260+
if shouldQuery {
261+
return s.query(ctx, stmt, query, args)
262+
}
263+
return s.exec(ctx, stmt, query, args)
242264
}
243265

244266
// trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server.

server/handler.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,10 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p
157157
// than they will at execution time.
158158
func nodeReturnsOkResultSchema(node sql.Node) bool {
159159
switch node.(type) {
160-
case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom:
160+
case *plan.InsertInto:
161+
insertNode, _ := node.(*plan.InsertInto)
162+
return insertNode.Returning == nil
163+
case *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom:
161164
return true
162165
}
163166
return false

sql/plan/insert.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ type InsertInto struct {
7272
LiteralValueSource bool
7373

7474
// Returning is a list of expressions to return after the insert operation. This feature is not supported
75-
// in MySQL's syntax, but is exposed through PostgreSQL's syntax.
75+
// in MySQL's syntax, but is exposed through PostgreSQL's and MariaDB's syntax.
7676
Returning []sql.Expression
7777

7878
// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.

0 commit comments

Comments
 (0)