Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions enginetest/queries/insert_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2276,6 +2276,36 @@ var InsertScripts = []ScriptTest{
},
},
},
{
Name: "insert...returning... statements",
Dialect: "mysql", // actually mariadb
SetUpScript: []string{
"CREATE TABLE animals (id int, name varchar(20))",
"CREATE TABLE auto_pk (`pk` int NOT NULL AUTO_INCREMENT, `name` varchar(20), PRIMARY KEY (`pk`))",
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into animals (id) values (2) returning id",
Expected: []sql.Row{{2}},
},
{
Query: "insert into animals(id,name) values (1, 'Dog'),(2,'Lion'),(3,'Tiger'),(4,'Leopard') returning id, id+id",
Expected: []sql.Row{{1, 2}, {2, 4}, {3, 6}, {4, 8}},
},
{
Query: "insert into animals set id=1,name='Bear' returning id,name",
Expected: []sql.Row{{1, "Bear"}},
},
{
Query: "insert into auto_pk (name) values ('Cat') returning pk,name",
Expected: []sql.Row{{1, "Cat"}},
},
{
Query: "insert into auto_pk values (NULL, 'Dog'),(5, 'Fish'),(NULL, 'Horse') returning pk,name",
Expected: []sql.Row{{2, "Dog"}, {5, "Fish"}, {6, "Horse"}},
},
},
},
}

var InsertDuplicateKeyKeyless = []ScriptTest{
Expand Down
70 changes: 46 additions & 24 deletions enginetest/server_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,39 +206,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
return s.queryOrExec(ctx, stmt, parsed, query, args)
}

func (s *ServerQueryEngine) query(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
var rows *gosql.Rows
var err error
if stmt != nil {
rows, err = stmt.Query(args...)
} else {
rows, err = s.conn.Query(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertRowsResult(ctx, rows)
}

func (s *ServerQueryEngine) exec(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
var res gosql.Result
var err error
if stmt != nil {
res, err = stmt.Exec(args...)
} else {
res, err = s.conn.Exec(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertExecResult(res)
}

// queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
// If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot
// be run as prepared.
// TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
//
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result.
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, the result is OkResult.
func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
var err error
switch parsed.(type) {
// TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
var rows *gosql.Rows
if stmt != nil {
rows, err = stmt.Query(args...)
} else {
rows, err = s.conn.Query(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertRowsResult(ctx, rows)
var shouldQuery bool
switch p := parsed.(type) {
// Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec
case *sqlparser.Insert:
if p.Returning != nil {
shouldQuery = true
}
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show,
*sqlparser.Set, *sqlparser.Call, *sqlparser.Begin,
*sqlparser.Use, *sqlparser.Load, *sqlparser.Execute,
*sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
shouldQuery = true
default:
var res gosql.Result
if stmt != nil {
res, err = stmt.Exec(args...)
} else {
res, err = s.conn.Exec(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertExecResult(res)
}

if shouldQuery {
return s.query(ctx, stmt, query, args)
}
return s.exec(ctx, stmt, query, args)
}

// trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server.
Expand Down
5 changes: 4 additions & 1 deletion server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,10 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p
// than they will at execution time.
func nodeReturnsOkResultSchema(node sql.Node) bool {
switch node.(type) {
case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom:
case *plan.InsertInto:
insertNode, _ := node.(*plan.InsertInto)
return insertNode.Returning == nil
case *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom:
return true
}
return false
Expand Down
2 changes: 1 addition & 1 deletion sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type InsertInto struct {
LiteralValueSource bool

// Returning is a list of expressions to return after the insert operation. This feature is not supported
// in MySQL's syntax, but is exposed through PostgreSQL's syntax.
// in MySQL's syntax, but is exposed through PostgreSQL's and MariaDB's syntax.
Returning []sql.Expression

// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
Expand Down