Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sql/analyzer/fix_exec_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ func (s *idxScope) visitSelf(n sql.Node) error {
newCheck.Expr = newE
s.checks = append(s.checks, &newCheck)
}
for _, r := range n.Returning {
newE := fixExprToScope(r, dstScope)
s.expressions = append(s.expressions, newE)
}
case *plan.Update:
newScope := s.copy()
srcScope := s.childScopes[0]
Expand Down Expand Up @@ -543,7 +547,8 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
nn := *n
nn.Source = s.children[0]
nn.Destination = s.children[1]
nn.OnDupExprs = s.expressions
nn.OnDupExprs = s.expressions[:len(n.OnDupExprs)]
nn.Returning = s.expressions[len(n.OnDupExprs):]
return nn.WithChecks(s.checks), nil
default:
s.ids = columnIdsForNode(n)
Expand Down
42 changes: 32 additions & 10 deletions sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/transform"
)

Expand Down Expand Up @@ -70,6 +71,10 @@ type InsertInto struct {
// a |Values| node with only literal expressions.
LiteralValueSource bool

// Returning is a list of expressions to return after the insert operation. This feature is not supported
// in MySQL's syntax, but is exposed through PostgreSQL's syntax.
Returning []sql.Expression

// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
FirstGeneratedAutoIncRowIdx int

Expand Down Expand Up @@ -122,6 +127,19 @@ func (ii *InsertInto) Schema() sql.Schema {
if ii.IsReplace {
return append(ii.Destination.Schema(), ii.Destination.Schema()...)
}

// Postgres allows the returned values of the insert statement to be controlled, so if returning expressions
// were specified, then we return a different schema.
if ii.Returning != nil {
// We know that returning exprs are resolved here, because you can't call Schema() safely until Resolved() is true.
returningSchema := sql.Schema{}
for _, expr := range ii.Returning {
returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, ""))
}

return returningSchema
}

return ii.Destination.Schema()
}

Expand Down Expand Up @@ -238,24 +256,30 @@ func (ii *InsertInto) DebugString() string {

// Expressions implements the sql.Expressioner interface.
func (ii *InsertInto) Expressions() []sql.Expression {
return append(ii.OnDupExprs, ii.checks.ToExpressions()...)
exprs := append(ii.OnDupExprs, ii.checks.ToExpressions()...)
exprs = append(exprs, ii.Returning...)
return exprs
}

// WithExpressions implements the sql.Expressioner interface.
func (ii *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks) {
return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks))
if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning) {
return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning))
}

nii := *ii
nii.OnDupExprs = newExprs[:len(nii.OnDupExprs)]
newExprs = newExprs[len(nii.OnDupExprs):]

var err error
nii.checks, err = nii.checks.FromExpressions(newExprs[len(nii.OnDupExprs):])
nii.checks, err = nii.checks.FromExpressions(newExprs[:len(nii.checks)])
if err != nil {
return nil, err
}

newExprs = newExprs[len(nii.checks):]
nii.Returning = newExprs

return &nii, nil
}

Expand All @@ -264,17 +288,15 @@ func (ii *InsertInto) Resolved() bool {
if !ii.Destination.Resolved() || !ii.Source.Resolved() {
return false
}
for _, updateExpr := range ii.OnDupExprs {
if !updateExpr.Resolved() {
return false
}
}

for _, checkExpr := range ii.checks {
if !checkExpr.Expr.Resolved() {
return false
}
}
return true

return expression.ExpressionsResolved(ii.OnDupExprs...) &&
expression.ExpressionsResolved(ii.Returning...)
}

// InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be
Expand Down
8 changes: 8 additions & 0 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore)
ins.LiteralValueSource = srcLiteralOnly

if i.Returning != nil {
returningExprs := make([]sql.Expression, len(i.Returning))
for i, selectExpr := range i.Returning {
returningExprs[i] = b.selectExprToExpression(destScope, selectExpr)
}
ins.Returning = returningExprs
}

b.validateInsert(ins)

outScope = destScope
Expand Down
2 changes: 2 additions & 0 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
ctx: ctx,
ignore: ii.Ignore,
firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx,
returnExprs: ii.Returning,
returnSchema: ii.Schema(),
deferredDefaults: ii.DeferredDefaults,
}

Expand Down
31 changes: 22 additions & 9 deletions sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,17 +590,30 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
childIter := i.GetChildIter()
childIter, sch := AddAccumulatorIter(ctx, childIter)
return i.WithChildIter(childIter), sch
default:
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
if rowHandler == nil {
return iter, nil
case *plan.TableEditorIter:
// If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter
innerIter := i.InnerIter()
if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 {
return insertIter, insertIter.returnSchema
}
return &accumulatorIter{
iter: iter,
updateRowHandler: rowHandler,
}, types.OkResultSchema

return defaultAccumulatorIter(ctx, iter)
default:
return defaultAccumulatorIter(ctx, iter)
}
}

// defaultAccumulatorIter returns the default accumulator iter for a DML node
func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) {
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
if rowHandler == nil {
return iter, nil
}
return &accumulatorIter{
iter: iter,
updateRowHandler: rowHandler,
}, types.OkResultSchema
}

func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {
Expand Down
40 changes: 27 additions & 13 deletions sql/rowexec/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,21 @@ import (
)

type insertIter struct {
schema sql.Schema
inserter sql.RowInserter
replacer sql.RowReplacer
updater sql.RowUpdater
rowSource sql.RowIter
unlocker func()
ctx *sql.Context
insertExprs []sql.Expression
updateExprs []sql.Expression
checks sql.CheckConstraints
tableNode sql.Node
closed bool
ignore bool
schema sql.Schema
inserter sql.RowInserter
replacer sql.RowReplacer
updater sql.RowUpdater
rowSource sql.RowIter
unlocker func()
ctx *sql.Context
insertExprs []sql.Expression
updateExprs []sql.Expression
checks sql.CheckConstraints
tableNode sql.Node
closed bool
ignore bool
returnExprs []sql.Expression
returnSchema sql.Schema

firstGeneratedAutoIncRowIdx int

Expand Down Expand Up @@ -175,6 +177,18 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)

i.updateLastInsertId(ctx, row)

if len(i.returnExprs) > 0 {
var retExprRow sql.Row
for _, returnExpr := range i.returnExprs {
result, err := returnExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
retExprRow = append(retExprRow, result)
}
return retExprRow, nil
}

return row, nil
}

Expand Down