Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sql/analyzer/fix_exec_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,10 @@ func (s *idxScope) visitSelf(n sql.Node) error {
newCheck.Expr = newE
s.checks = append(s.checks, &newCheck)
}
for _, r := range n.Returning {
newE := fixExprToScope(r, srcScope)
s.expressions = append(s.expressions, newE)
}
case *plan.LoadData:
scope := &idxScope{}
scope.addSchema(n.DestSch)
Expand Down Expand Up @@ -556,6 +560,10 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
nn.Returning = s.expressions[len(n.OnDupExprs):]
return nn.WithChecks(s.checks), nil
default:
if nn, ok := n.(*plan.Update); ok {
nn.Returning = s.expressions
}

s.ids = columnIdsForNode(n)

s.addSchema(n.Schema())
Expand Down
41 changes: 36 additions & 5 deletions sql/plan/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/transform"
)

var ErrUpdateNotSupported = errors.NewKind("table doesn't support UPDATE")
Expand All @@ -35,6 +36,10 @@ type Update struct {
IsJoin bool
HasSingleRel bool
IsProcNested bool

// Returning is a list of expressions to return after the update operation. This feature is not
// supported in MySQL's syntax, but is exposed through PostgreSQL's syntax.
Returning []sql.Expression
}

var _ sql.Node = (*Update)(nil)
Expand Down Expand Up @@ -112,6 +117,24 @@ func GetDatabase(node sql.Node) sql.Database {
return nil
}

// Schema implements the sql.Node interface.
func (u *Update) Schema() sql.Schema {
// Postgres allows the returned values of the update statement to be controlled, so if returning
// expressions were specified, then we return a different schema.
if u.Returning != nil {
// We know that returning exprs are resolved here, because you can't call Schema()
// safely until Resolved() is true.
returningSchema := sql.Schema{}
for _, expr := range u.Returning {
returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, ""))
}

return returningSchema
}

return u.Child.Schema()
}

func (u *Update) Checks() sql.CheckConstraints {
return u.checks
}
Expand Down Expand Up @@ -140,24 +163,32 @@ func (u *Update) Database() string {
}

func (u *Update) Expressions() []sql.Expression {
return u.checks.ToExpressions()
exprs := append([]sql.Expression{}, u.checks.ToExpressions()...)
exprs = append(exprs, u.Returning...)
return exprs
}

func (u *Update) Resolved() bool {
return u.Child.Resolved() && expression.ExpressionsResolved(u.checks.ToExpressions()...)
return u.Child.Resolved() &&
expression.ExpressionsResolved(u.checks.ToExpressions()...) &&
expression.ExpressionsResolved(u.Returning...)

}

func (u Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
if len(newExprs) != len(u.checks) {
return nil, sql.ErrInvalidChildrenNumber.New(u, len(newExprs), len(u.checks))
expectedLength := len(u.checks) + len(u.Returning)
if len(newExprs) != expectedLength {
return nil, sql.ErrInvalidChildrenNumber.New(u, len(newExprs), expectedLength)
}

var err error
u.checks, err = u.checks.FromExpressions(newExprs)
u.checks, err = u.checks.FromExpressions(newExprs[:len(u.checks)])
if err != nil {
return nil, err
}

u.Returning = newExprs[len(u.checks):]

return &u, nil
}

Expand Down
9 changes: 9 additions & 0 deletions sql/planbuilder/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,15 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
return true
})
}

if len(u.Returning) > 0 {
returningExprs := make([]sql.Expression, len(u.Returning))
for i, selectExpr := range u.Returning {
returningExprs[i] = b.selectExprToExpression(outScope, selectExpr)
}
update.Returning = returningExprs
}

outScope.node = update.WithChecks(checks)
return
}
Expand Down
2 changes: 1 addition & 1 deletion sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (b *BaseBuilder) buildUpdate(ctx *sql.Context, n *plan.Update, row sql.Row)
return nil, err
}

return newUpdateIter(iter, updatable.Schema(), updater, n.Checks(), n.Ignore), nil
return newUpdateIter(iter, updatable.Schema(), updater, n.Checks(), n.Ignore, n.Returning, n.Schema()), nil
}

func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignKey, row sql.Row) (sql.RowIter, error) {
Expand Down
12 changes: 9 additions & 3 deletions sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,15 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
return i.WithChildIter(childIter), sch
case *plan.TableEditorIter:
// If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter
innerIter := i.InnerIter()
if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 {
return insertIter, insertIter.returnSchema
switch innerIter := i.InnerIter().(type) {
case *insertIter:
if len(innerIter.returnExprs) > 0 {
return innerIter, innerIter.returnSchema
}
case *updateIter:
if len(innerIter.returnExprs) > 0 {
return innerIter, innerIter.returnSchema
}
}

return defaultAccumulatorIter(ctx, iter)
Expand Down
50 changes: 35 additions & 15 deletions sql/rowexec/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ import (
)

type updateIter struct {
childIter sql.RowIter
schema sql.Schema
updater sql.RowUpdater
checks sql.CheckConstraints
closed bool
ignore bool
childIter sql.RowIter
schema sql.Schema
updater sql.RowUpdater
checks sql.CheckConstraints
closed bool
ignore bool
returnExprs []sql.Expression
returnSchema sql.Schema
}

func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) {
Expand Down Expand Up @@ -66,6 +68,18 @@ func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, u.ignoreOrError(ctx, newRow, err)
}
}

if len(u.returnExprs) > 0 {
var retExprRow sql.Row
for _, returnExpr := range u.returnExprs {
result, err := returnExpr.Eval(ctx, newRow)
if err != nil {
return nil, err
}
retExprRow = append(retExprRow, result)
}
return retExprRow, nil
}
} else {
return nil, err
}
Expand Down Expand Up @@ -164,21 +178,27 @@ func newUpdateIter(
updater sql.RowUpdater,
checks sql.CheckConstraints,
ignore bool,
returnExprs []sql.Expression,
returnSchema sql.Schema,
) sql.RowIter {
if ignore {
return plan.NewCheckpointingTableEditorIter(&updateIter{
childIter: childIter,
updater: updater,
schema: schema,
checks: checks,
ignore: true,
childIter: childIter,
updater: updater,
schema: schema,
checks: checks,
ignore: true,
returnExprs: returnExprs,
returnSchema: returnSchema,
}, updater)
} else {
return plan.NewTableEditorIter(&updateIter{
childIter: childIter,
updater: updater,
schema: schema,
checks: checks,
childIter: childIter,
updater: updater,
schema: schema,
checks: checks,
returnExprs: returnExprs,
returnSchema: returnSchema,
}, updater)
}
}
Expand Down