Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions enginetest/server_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,39 +206,61 @@ func (s *ServerQueryEngine) QueryWithBindings(ctx *sql.Context, query string, pa
return s.queryOrExec(ctx, stmt, parsed, query, args)
}

func (s *ServerQueryEngine) query(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
var rows *gosql.Rows
var err error
if stmt != nil {
rows, err = stmt.Query(args...)
} else {
rows, err = s.conn.Query(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertRowsResult(ctx, rows)
}

func (s *ServerQueryEngine) exec(ctx *sql.Context, stmt *gosql.Stmt, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
var res gosql.Result
var err error
if stmt != nil {
res, err = stmt.Exec(args...)
} else {
res, err = s.conn.Exec(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertExecResult(res)
}

// queryOrExec function use `query()` or `exec()` method of go-sql-driver depending on the sql parser plan.
// If |stmt| is nil, then we use the connection db to query/exec the given query statement because some queries cannot
// be run as prepared.
// TODO: for `EXECUTE` and `CALL` statements, it can be either query or exec depending on the statement that prepared or stored procedure holds.
//
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, there will be no result.
// for now, we use `query` to get the row results for these statements. For statements that needs `exec`, the result is OkResult.
func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, parsed sqlparser.Statement, query string, args []any) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) {
var err error
switch parsed.(type) {
// TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned.
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
var rows *gosql.Rows
if stmt != nil {
rows, err = stmt.Query(args...)
} else {
rows, err = s.conn.Query(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertRowsResult(ctx, rows)
var shouldQuery bool
switch p := parsed.(type) {
// Insert statements with a returning clause return rows, not OkResult, so we need to call stmt.Query instead of stmt.Exec
case *sqlparser.Insert:
if p.Returning != nil {
shouldQuery = true
}
case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show,
*sqlparser.Set, *sqlparser.Call, *sqlparser.Begin,
*sqlparser.Use, *sqlparser.Load, *sqlparser.Execute,
*sqlparser.Analyze, *sqlparser.Flush, *sqlparser.Explain:
shouldQuery = true
default:
var res gosql.Result
if stmt != nil {
res, err = stmt.Exec(args...)
} else {
res, err = s.conn.Exec(query, args...)
}
if err != nil {
return nil, nil, nil, trimMySQLErrCodePrefix(err)
}
return convertExecResult(res)
}

if shouldQuery {
return s.query(ctx, stmt, query, args)
}
return s.exec(ctx, stmt, query, args)
}

// trimMySQLErrCodePrefix temporarily removes the error code part of the error message returned from the server.
Expand Down