diff --git a/sql/analyzer/fix_exec_indexes.go b/sql/analyzer/fix_exec_indexes.go index 92382758e4..2c30e3db94 100644 --- a/sql/analyzer/fix_exec_indexes.go +++ b/sql/analyzer/fix_exec_indexes.go @@ -494,6 +494,10 @@ func (s *idxScope) visitSelf(n sql.Node) error { newCheck.Expr = newE s.checks = append(s.checks, &newCheck) } + for _, r := range n.Returning { + newE := fixExprToScope(r, dstScope) + s.expressions = append(s.expressions, newE) + } case *plan.Update: newScope := s.copy() srcScope := s.childScopes[0] @@ -543,7 +547,8 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) { nn := *n nn.Source = s.children[0] nn.Destination = s.children[1] - nn.OnDupExprs = s.expressions + nn.OnDupExprs = s.expressions[:len(n.OnDupExprs)] + nn.Returning = s.expressions[len(n.OnDupExprs):] return nn.WithChecks(s.checks), nil default: s.ids = columnIdsForNode(n) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 293d5483e2..5c7a24da12 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -20,6 +20,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/transform" ) @@ -70,6 +71,10 @@ type InsertInto struct { // a |Values| node with only literal expressions. LiteralValueSource bool + // Returning is a list of expressions to return after the insert operation. This feature is not supported + // in MySQL's syntax, but is exposed through PostgreSQL's syntax. + Returning []sql.Expression + // FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id. FirstGeneratedAutoIncRowIdx int @@ -122,6 +127,19 @@ func (ii *InsertInto) Schema() sql.Schema { if ii.IsReplace { return append(ii.Destination.Schema(), ii.Destination.Schema()...) } + + // Postgres allows the returned values of the insert statement to be controlled, so if returning expressions + // were specified, then we return a different schema. + if ii.Returning != nil { + // We know that returning exprs are resolved here, because you can't call Schema() safely until Resolved() is true. + returningSchema := sql.Schema{} + for _, expr := range ii.Returning { + returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, "")) + } + + return returningSchema + } + return ii.Destination.Schema() } @@ -238,24 +256,30 @@ func (ii *InsertInto) DebugString() string { // Expressions implements the sql.Expressioner interface. func (ii *InsertInto) Expressions() []sql.Expression { - return append(ii.OnDupExprs, ii.checks.ToExpressions()...) + exprs := append(ii.OnDupExprs, ii.checks.ToExpressions()...) + exprs = append(exprs, ii.Returning...) + return exprs } // WithExpressions implements the sql.Expressioner interface. func (ii *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { - if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks) { - return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks)) + if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning) { + return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning)) } nii := *ii nii.OnDupExprs = newExprs[:len(nii.OnDupExprs)] + newExprs = newExprs[len(nii.OnDupExprs):] var err error - nii.checks, err = nii.checks.FromExpressions(newExprs[len(nii.OnDupExprs):]) + nii.checks, err = nii.checks.FromExpressions(newExprs[:len(nii.checks)]) if err != nil { return nil, err } + newExprs = newExprs[len(nii.checks):] + nii.Returning = newExprs + return &nii, nil } @@ -264,17 +288,15 @@ func (ii *InsertInto) Resolved() bool { if !ii.Destination.Resolved() || !ii.Source.Resolved() { return false } - for _, updateExpr := range ii.OnDupExprs { - if !updateExpr.Resolved() { - return false - } - } + for _, checkExpr := range ii.checks { if !checkExpr.Expr.Resolved() { return false } } - return true + + return expression.ExpressionsResolved(ii.OnDupExprs...) && + expression.ExpressionsResolved(ii.Returning...) } // InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index 71b8f9b56d..1821989956 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -150,6 +150,14 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) { ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore) ins.LiteralValueSource = srcLiteralOnly + if i.Returning != nil { + returningExprs := make([]sql.Expression, len(i.Returning)) + for i, selectExpr := range i.Returning { + returningExprs[i] = b.selectExprToExpression(destScope, selectExpr) + } + ins.Returning = returningExprs + } + b.validateInsert(ins) outScope = destScope diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index f2d937632b..a0676f223f 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -90,6 +90,8 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row ctx: ctx, ignore: ii.Ignore, firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx, + returnExprs: ii.Returning, + returnSchema: ii.Schema(), deferredDefaults: ii.DeferredDefaults, } diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index 7f6c87bb1e..57dfd72555 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -590,17 +590,30 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc childIter := i.GetChildIter() childIter, sch := AddAccumulatorIter(ctx, childIter) return i.WithChildIter(childIter), sch - default: - clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 - rowHandler := getRowHandler(clientFoundRowsToggled, iter) - if rowHandler == nil { - return iter, nil + case *plan.TableEditorIter: + // If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter + innerIter := i.InnerIter() + if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 { + return insertIter, insertIter.returnSchema } - return &accumulatorIter{ - iter: iter, - updateRowHandler: rowHandler, - }, types.OkResultSchema + + return defaultAccumulatorIter(ctx, iter) + default: + return defaultAccumulatorIter(ctx, iter) + } +} + +// defaultAccumulatorIter returns the default accumulator iter for a DML node +func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) { + clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 + rowHandler := getRowHandler(clientFoundRowsToggled, iter) + if rowHandler == nil { + return iter, nil } + return &accumulatorIter{ + iter: iter, + updateRowHandler: rowHandler, + }, types.OkResultSchema } func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index a23f67f1f5..659c508fcc 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -30,19 +30,21 @@ import ( ) type insertIter struct { - schema sql.Schema - inserter sql.RowInserter - replacer sql.RowReplacer - updater sql.RowUpdater - rowSource sql.RowIter - unlocker func() - ctx *sql.Context - insertExprs []sql.Expression - updateExprs []sql.Expression - checks sql.CheckConstraints - tableNode sql.Node - closed bool - ignore bool + schema sql.Schema + inserter sql.RowInserter + replacer sql.RowReplacer + updater sql.RowUpdater + rowSource sql.RowIter + unlocker func() + ctx *sql.Context + insertExprs []sql.Expression + updateExprs []sql.Expression + checks sql.CheckConstraints + tableNode sql.Node + closed bool + ignore bool + returnExprs []sql.Expression + returnSchema sql.Schema firstGeneratedAutoIncRowIdx int @@ -175,6 +177,18 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) i.updateLastInsertId(ctx, row) + if len(i.returnExprs) > 0 { + var retExprRow sql.Row + for _, returnExpr := range i.returnExprs { + result, err := returnExpr.Eval(ctx, row) + if err != nil { + return nil, err + } + retExprRow = append(retExprRow, result) + } + return retExprRow, nil + } + return row, nil }