Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion sql/analyzer/fix_exec_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,10 @@ func (s *idxScope) visitSelf(n sql.Node) error {
newCheck.Expr = newE
s.checks = append(s.checks, &newCheck)
}
for _, r := range n.Returning {
newE := fixExprToScope(r, dstScope)
s.expressions = append(s.expressions, newE)
}
case *plan.Update:
newScope := s.copy()
srcScope := s.childScopes[0]
Expand Down Expand Up @@ -543,7 +547,8 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
nn := *n
nn.Source = s.children[0]
nn.Destination = s.children[1]
nn.OnDupExprs = s.expressions
nn.OnDupExprs = s.expressions[:len(n.OnDupExprs)]
nn.Returning = s.expressions[len(n.OnDupExprs):]
return nn.WithChecks(s.checks), nil
default:
s.ids = columnIdsForNode(n)
Expand Down
47 changes: 37 additions & 10 deletions sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/transform"
)

Expand Down Expand Up @@ -70,6 +71,10 @@ type InsertInto struct {
// a |Values| node with only literal expressions.
LiteralValueSource bool

// Returning is a list of expressions to return after the insert operation. This feature is not supported
// in MySQL's syntax, but is exposed through PostgreSQL's syntax.
Returning []sql.Expression

// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
FirstGeneratedAutoIncRowIdx int

Expand Down Expand Up @@ -122,6 +127,24 @@ func (ii *InsertInto) Schema() sql.Schema {
if ii.IsReplace {
return append(ii.Destination.Schema(), ii.Destination.Schema()...)
}

// Postgres allows the returned values of the insert statement to be controlled, so if returning expressions
// were specified, then we return a different schema.
// TODO: does anything else depend on the schema returned by insert statements? triggers?
// TODO: Do we need to check for the expressions being fully resolved? (probably!)
if ii.Returning != nil {
// TODO: If we don't return the destination schema anymore... does that mess up other things, like trigger processing?
// TODO: we need to look at the expressions in the returning clause

// We know that returning exprs are resolved here, because you can't call Schema() safely until Resolved() is true.
returningSchema := sql.Schema{}
for _, expr := range ii.Returning {
returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, ""))
}

return returningSchema
}

return ii.Destination.Schema()
}

Expand Down Expand Up @@ -238,24 +261,30 @@ func (ii *InsertInto) DebugString() string {

// Expressions implements the sql.Expressioner interface.
func (ii *InsertInto) Expressions() []sql.Expression {
return append(ii.OnDupExprs, ii.checks.ToExpressions()...)
exprs := append(ii.OnDupExprs, ii.checks.ToExpressions()...)
exprs = append(exprs, ii.Returning...)
return exprs
}

// WithExpressions implements the sql.Expressioner interface.
func (ii *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks) {
return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks))
if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning) {
return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning))
}

nii := *ii
nii.OnDupExprs = newExprs[:len(nii.OnDupExprs)]
newExprs = newExprs[len(nii.OnDupExprs):]

var err error
nii.checks, err = nii.checks.FromExpressions(newExprs[len(nii.OnDupExprs):])
nii.checks, err = nii.checks.FromExpressions(newExprs[:len(nii.checks)])
if err != nil {
return nil, err
}

newExprs = newExprs[len(nii.checks):]
nii.Returning = newExprs

return &nii, nil
}

Expand All @@ -264,17 +293,15 @@ func (ii *InsertInto) Resolved() bool {
if !ii.Destination.Resolved() || !ii.Source.Resolved() {
return false
}
for _, updateExpr := range ii.OnDupExprs {
if !updateExpr.Resolved() {
return false
}
}

for _, checkExpr := range ii.checks {
if !checkExpr.Expr.Resolved() {
return false
}
}
return true

return expression.ExpressionsResolved(ii.OnDupExprs...) &&
expression.ExpressionsResolved(ii.Returning...)
}

// InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be
Expand Down
8 changes: 8 additions & 0 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore)
ins.LiteralValueSource = srcLiteralOnly

if i.Returning != nil {
returningExprs := make([]sql.Expression, len(i.Returning))
for i, selectExpr := range i.Returning {
returningExprs[i] = b.selectExprToExpression(destScope, selectExpr)
}
ins.Returning = returningExprs
}

b.validateInsert(ins)

outScope = destScope
Expand Down
2 changes: 2 additions & 0 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
ctx: ctx,
ignore: ii.Ignore,
firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx,
returnExprs: ii.Returning,
returnSchema: ii.Schema(),
deferredDefaults: ii.DeferredDefaults,
}

Expand Down
31 changes: 22 additions & 9 deletions sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,17 +590,30 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
childIter := i.GetChildIter()
childIter, sch := AddAccumulatorIter(ctx, childIter)
return i.WithChildIter(childIter), sch
default:
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
if rowHandler == nil {
return iter, nil
case *plan.TableEditorIter:
// If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter
innerIter := i.InnerIter()
if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 {
return insertIter, insertIter.returnSchema
}
return &accumulatorIter{
iter: iter,
updateRowHandler: rowHandler,
}, types.OkResultSchema

return defaultAccumulatorIter(ctx, iter)
default:
return defaultAccumulatorIter(ctx, iter)
}
}

// defaultAccumulatorIter returns the default accumulator iter for a DML node
func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) {
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
if rowHandler == nil {
return iter, nil
}
return &accumulatorIter{
iter: iter,
updateRowHandler: rowHandler,
}, types.OkResultSchema
}

func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {
Expand Down
40 changes: 27 additions & 13 deletions sql/rowexec/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,21 @@ import (
)

type insertIter struct {
schema sql.Schema
inserter sql.RowInserter
replacer sql.RowReplacer
updater sql.RowUpdater
rowSource sql.RowIter
unlocker func()
ctx *sql.Context
insertExprs []sql.Expression
updateExprs []sql.Expression
checks sql.CheckConstraints
tableNode sql.Node
closed bool
ignore bool
schema sql.Schema
inserter sql.RowInserter
replacer sql.RowReplacer
updater sql.RowUpdater
rowSource sql.RowIter
unlocker func()
ctx *sql.Context
insertExprs []sql.Expression
updateExprs []sql.Expression
checks sql.CheckConstraints
tableNode sql.Node
closed bool
ignore bool
returnExprs []sql.Expression
returnSchema sql.Schema

firstGeneratedAutoIncRowIdx int

Expand Down Expand Up @@ -175,6 +177,18 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)

i.updateLastInsertId(ctx, row)

if len(i.returnExprs) > 0 {
var retExprRow sql.Row
for _, returnExpr := range i.returnExprs {
result, err := returnExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
retExprRow = append(retExprRow, result)
}
return retExprRow, nil
}

return row, nil
}

Expand Down
Loading