diff --git a/enginetest/queries/insert_queries.go b/enginetest/queries/insert_queries.go index 5b3e6ce1c2..a2e56c2591 100644 --- a/enginetest/queries/insert_queries.go +++ b/enginetest/queries/insert_queries.go @@ -2276,6 +2276,36 @@ var InsertScripts = []ScriptTest{ }, }, }, + { + Name: "insert...returning... statements", + Dialect: "mysql", // actually mariadb + SetUpScript: []string{ + "CREATE TABLE animals (id int, name varchar(20))", + "CREATE TABLE auto_pk (`pk` int NOT NULL AUTO_INCREMENT, `name` varchar(20), PRIMARY KEY (`pk`))", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "insert into animals (id) values (2) returning id", + Expected: []sql.Row{{2}}, + }, + { + Query: "insert into animals(id,name) values (1, 'Dog'),(2,'Lion'),(3,'Tiger'),(4,'Leopard') returning id, id+id", + Expected: []sql.Row{{1, 2}, {2, 4}, {3, 6}, {4, 8}}, + }, + { + Query: "insert into animals set id=1,name='Bear' returning id,name", + Expected: []sql.Row{{1, "Bear"}}, + }, + { + Query: "insert into auto_pk (name) values ('Cat') returning pk,name", + Expected: []sql.Row{{1, "Cat"}}, + }, + { + Query: "insert into auto_pk values (NULL, 'Dog'),(5, 'Fish'),(NULL, 'Horse') returning pk,name", + Expected: []sql.Row{{2, "Dog"}, {5, "Fish"}, {6, "Horse"}}, + }, + }, + }, } var InsertDuplicateKeyKeyless = []ScriptTest{ diff --git a/enginetest/server_engine.go b/enginetest/server_engine.go index 0a222ef534..e2b1bd8f71 100644 --- a/enginetest/server_engine.go +++ b/enginetest/server_engine.go @@ -206,39 +206,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa return s.queryOrExec(ctx, stmt, parsed, query, args) } +func (s *ServerQueryEngine) query(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + var rows *gosql.Rows + var err error + if stmt != nil { + rows, err = stmt.Query(args...) + } else { + rows, err = s.conn.Query(query, args...) + } + if err != nil { + return nil, nil, nil, trimMySQLErrCodePrefix(err) + } + return convertRowsResult(ctx, rows) +} + +func (s *ServerQueryEngine) exec(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + var res gosql.Result + var err error + if stmt != nil { + res, err = stmt.Exec(args...) + } else { + res, err = s.conn.Exec(query, args...) + } + if err != nil { + return nil, nil, nil, trimMySQLErrCodePrefix(err) + } + return convertExecResult(res) +} + // queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan. // If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot // be run as prepared. // TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds. // -// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result. +// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, the result is OkResult. func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { - var err error - switch parsed.(type) { // TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned. - case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain: - var rows *gosql.Rows - if stmt != nil { - rows, err = stmt.Query(args...) - } else { - rows, err = s.conn.Query(query, args...) - } - if err != nil { - return nil, nil, nil, trimMySQLErrCodePrefix(err) - } - return convertRowsResult(ctx, rows) + var shouldQuery bool + switch p := parsed.(type) { + // Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec + case *sqlparser.Insert: + if p.Returning != nil { + shouldQuery = true + } + case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, + *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, + *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, + *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain: + shouldQuery = true default: - var res gosql.Result - if stmt != nil { - res, err = stmt.Exec(args...) - } else { - res, err = s.conn.Exec(query, args...) - } - if err != nil { - return nil, nil, nil, trimMySQLErrCodePrefix(err) - } - return convertExecResult(res) } + + if shouldQuery { + return s.query(ctx, stmt, query, args) + } + return s.exec(ctx, stmt, query, args) } // trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server. diff --git a/server/handler.go b/server/handler.go index e3c7d57a50..113e9cc978 100644 --- a/server/handler.go +++ b/server/handler.go @@ -157,7 +157,10 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p // than they will at execution time. func nodeReturnsOkResultSchema(node sql.Node) bool { switch node.(type) { - case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom: + case *plan.InsertInto: + insertNode, _ := node.(*plan.InsertInto) + return insertNode.Returning == nil + case *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom: return true } return false diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 5c7a24da12..9c5dd6272e 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -72,7 +72,7 @@ type InsertInto struct { LiteralValueSource bool // Returning is a list of expressions to return after the insert operation. This feature is not supported - // in MySQL's syntax, but is exposed through PostgreSQL's syntax. + // in MySQL's syntax, but is exposed through PostgreSQL's and MariaDB's syntax. Returning []sql.Expression // FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.