Skip to content

Commit 81264bf

Browse files
committed
Add support for UPDATE ... RETURNING
1 parent 2d6449f commit 81264bf

File tree

6 files changed

+98
-24
lines changed

6 files changed

+98
-24
lines changed

sql/analyzer/fix_exec_indexes.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,10 @@ func (s *idxScope) visitSelf(n sql.Node) error {
514514
newCheck.Expr = newE
515515
s.checks = append(s.checks, &newCheck)
516516
}
517+
for _, r := range n.Returning {
518+
newE := fixExprToScope(r, srcScope)
519+
s.expressions = append(s.expressions, newE)
520+
}
517521
case *plan.LoadData:
518522
scope := &idxScope{}
519523
scope.addSchema(n.DestSch)
@@ -556,6 +560,10 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
556560
nn.Returning = s.expressions[len(n.OnDupExprs):]
557561
return nn.WithChecks(s.checks), nil
558562
default:
563+
if nn, ok := n.(*plan.Update); ok {
564+
nn.Returning = s.expressions
565+
}
566+
559567
s.ids = columnIdsForNode(n)
560568

561569
s.addSchema(n.Schema())

sql/plan/update.go

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/dolthub/go-mysql-server/sql"
2323
"github.com/dolthub/go-mysql-server/sql/expression"
24+
"github.com/dolthub/go-mysql-server/sql/transform"
2425
)
2526

2627
var ErrUpdateNotSupported = errors.NewKind("table doesn't support UPDATE")
@@ -35,6 +36,10 @@ type Update struct {
3536
IsJoin bool
3637
HasSingleRel bool
3738
IsProcNested bool
39+
40+
// Returning is a list of expressions to return after the update operation. This feature is not
41+
// supported in MySQL's syntax, but is exposed through PostgreSQL's syntax.
42+
Returning []sql.Expression
3843
}
3944

4045
var _ sql.Node = (*Update)(nil)
@@ -112,6 +117,24 @@ func GetDatabase(node sql.Node) sql.Database {
112117
return nil
113118
}
114119

120+
// Schema implements the sql.Node interface.
121+
func (u *Update) Schema() sql.Schema {
122+
// Postgres allows the returned values of the update statement to be controlled, so if returning
123+
// expressions were specified, then we return a different schema.
124+
if u.Returning != nil {
125+
// We know that returning exprs are resolved here, because you can't call Schema()
126+
// safely until Resolved() is true.
127+
returningSchema := sql.Schema{}
128+
for _, expr := range u.Returning {
129+
returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, ""))
130+
}
131+
132+
return returningSchema
133+
}
134+
135+
return u.Child.Schema()
136+
}
137+
115138
func (u *Update) Checks() sql.CheckConstraints {
116139
return u.checks
117140
}
@@ -140,24 +163,32 @@ func (u *Update) Database() string {
140163
}
141164

142165
func (u *Update) Expressions() []sql.Expression {
143-
return u.checks.ToExpressions()
166+
exprs := append([]sql.Expression{}, u.checks.ToExpressions()...)
167+
exprs = append(exprs, u.Returning...)
168+
return exprs
144169
}
145170

146171
func (u *Update) Resolved() bool {
147-
return u.Child.Resolved() && expression.ExpressionsResolved(u.checks.ToExpressions()...)
172+
return u.Child.Resolved() &&
173+
expression.ExpressionsResolved(u.checks.ToExpressions()...) &&
174+
expression.ExpressionsResolved(u.Returning...)
175+
148176
}
149177

150178
func (u Update) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
151-
if len(newExprs) != len(u.checks) {
152-
return nil, sql.ErrInvalidChildrenNumber.New(u, len(newExprs), len(u.checks))
179+
expectedLength := len(u.checks) + len(u.Returning)
180+
if len(newExprs) != expectedLength {
181+
return nil, sql.ErrInvalidChildrenNumber.New(u, len(newExprs), expectedLength)
153182
}
154183

155184
var err error
156-
u.checks, err = u.checks.FromExpressions(newExprs)
185+
u.checks, err = u.checks.FromExpressions(newExprs[:len(u.checks)])
157186
if err != nil {
158187
return nil, err
159188
}
160189

190+
u.Returning = newExprs[len(u.checks):]
191+
161192
return &u, nil
162193
}
163194

sql/planbuilder/dml.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,15 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
581581
return true
582582
})
583583
}
584+
585+
if len(u.Returning) > 0 {
586+
returningExprs := make([]sql.Expression, len(u.Returning))
587+
for i, selectExpr := range u.Returning {
588+
returningExprs[i] = b.selectExprToExpression(outScope, selectExpr)
589+
}
590+
update.Returning = returningExprs
591+
}
592+
584593
outScope.node = update.WithChecks(checks)
585594
return
586595
}

sql/rowexec/dml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func (b *BaseBuilder) buildUpdate(ctx *sql.Context, n *plan.Update, row sql.Row)
168168
return nil, err
169169
}
170170

171-
return newUpdateIter(iter, updatable.Schema(), updater, n.Checks(), n.Ignore), nil
171+
return newUpdateIter(iter, updatable.Schema(), updater, n.Checks(), n.Ignore, n.Returning, n.Schema()), nil
172172
}
173173

174174
func (b *BaseBuilder) buildDropForeignKey(ctx *sql.Context, n *plan.DropForeignKey, row sql.Row) (sql.RowIter, error) {

sql/rowexec/dml_iters.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -610,9 +610,15 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
610610
return i.WithChildIter(childIter), sch
611611
case *plan.TableEditorIter:
612612
// If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter
613-
innerIter := i.InnerIter()
614-
if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 {
615-
return insertIter, insertIter.returnSchema
613+
switch innerIter := i.InnerIter().(type) {
614+
case *insertIter:
615+
if len(innerIter.returnExprs) > 0 {
616+
return innerIter, innerIter.returnSchema
617+
}
618+
case *updateIter:
619+
if len(innerIter.returnExprs) > 0 {
620+
return innerIter, innerIter.returnSchema
621+
}
616622
}
617623

618624
return defaultAccumulatorIter(ctx, iter)

sql/rowexec/update.go

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@ import (
2323
)
2424

2525
type updateIter struct {
26-
childIter sql.RowIter
27-
schema sql.Schema
28-
updater sql.RowUpdater
29-
checks sql.CheckConstraints
30-
closed bool
31-
ignore bool
26+
childIter sql.RowIter
27+
schema sql.Schema
28+
updater sql.RowUpdater
29+
checks sql.CheckConstraints
30+
closed bool
31+
ignore bool
32+
returnExprs []sql.Expression
33+
returnSchema sql.Schema
3234
}
3335

3436
func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) {
@@ -66,6 +68,18 @@ func (u *updateIter) Next(ctx *sql.Context) (sql.Row, error) {
6668
return nil, u.ignoreOrError(ctx, newRow, err)
6769
}
6870
}
71+
72+
if len(u.returnExprs) > 0 {
73+
var retExprRow sql.Row
74+
for _, returnExpr := range u.returnExprs {
75+
result, err := returnExpr.Eval(ctx, newRow)
76+
if err != nil {
77+
return nil, err
78+
}
79+
retExprRow = append(retExprRow, result)
80+
}
81+
return retExprRow, nil
82+
}
6983
} else {
7084
return nil, err
7185
}
@@ -164,21 +178,27 @@ func newUpdateIter(
164178
updater sql.RowUpdater,
165179
checks sql.CheckConstraints,
166180
ignore bool,
181+
returnExprs []sql.Expression,
182+
returnSchema sql.Schema,
167183
) sql.RowIter {
168184
if ignore {
169185
return plan.NewCheckpointingTableEditorIter(&updateIter{
170-
childIter: childIter,
171-
updater: updater,
172-
schema: schema,
173-
checks: checks,
174-
ignore: true,
186+
childIter: childIter,
187+
updater: updater,
188+
schema: schema,
189+
checks: checks,
190+
ignore: true,
191+
returnExprs: returnExprs,
192+
returnSchema: returnSchema,
175193
}, updater)
176194
} else {
177195
return plan.NewTableEditorIter(&updateIter{
178-
childIter: childIter,
179-
updater: updater,
180-
schema: schema,
181-
checks: checks,
196+
childIter: childIter,
197+
updater: updater,
198+
schema: schema,
199+
checks: checks,
200+
returnExprs: returnExprs,
201+
returnSchema: returnSchema,
182202
}, updater)
183203
}
184204
}

0 commit comments

Comments
 (0)