Skip to content

Commit 6ad8521

Browse files
authored
Merge pull request #2896 from dolthub/zachmu/insert-returning
support for `INSERT .. RETURNING`, a postgres extension
2 parents fa97159 + 2a94ec7 commit 6ad8521

File tree

6 files changed

+97
-33
lines changed

6 files changed

+97
-33
lines changed

sql/analyzer/fix_exec_indexes.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,10 @@ func (s *idxScope) visitSelf(n sql.Node) error {
494494
newCheck.Expr = newE
495495
s.checks = append(s.checks, &newCheck)
496496
}
497+
for _, r := range n.Returning {
498+
newE := fixExprToScope(r, dstScope)
499+
s.expressions = append(s.expressions, newE)
500+
}
497501
case *plan.Update:
498502
newScope := s.copy()
499503
srcScope := s.childScopes[0]
@@ -543,7 +547,8 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) {
543547
nn := *n
544548
nn.Source = s.children[0]
545549
nn.Destination = s.children[1]
546-
nn.OnDupExprs = s.expressions
550+
nn.OnDupExprs = s.expressions[:len(n.OnDupExprs)]
551+
nn.Returning = s.expressions[len(n.OnDupExprs):]
547552
return nn.WithChecks(s.checks), nil
548553
default:
549554
s.ids = columnIdsForNode(n)

sql/plan/insert.go

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"gopkg.in/src-d/go-errors.v1"
2121

2222
"github.com/dolthub/go-mysql-server/sql"
23+
"github.com/dolthub/go-mysql-server/sql/expression"
2324
"github.com/dolthub/go-mysql-server/sql/transform"
2425
)
2526

@@ -70,6 +71,10 @@ type InsertInto struct {
7071
// a |Values| node with only literal expressions.
7172
LiteralValueSource bool
7273

74+
// Returning is a list of expressions to return after the insert operation. This feature is not supported
75+
// in MySQL's syntax, but is exposed through PostgreSQL's syntax.
76+
Returning []sql.Expression
77+
7378
// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
7479
FirstGeneratedAutoIncRowIdx int
7580

@@ -122,6 +127,19 @@ func (ii *InsertInto) Schema() sql.Schema {
122127
if ii.IsReplace {
123128
return append(ii.Destination.Schema(), ii.Destination.Schema()...)
124129
}
130+
131+
// Postgres allows the returned values of the insert statement to be controlled, so if returning expressions
132+
// were specified, then we return a different schema.
133+
if ii.Returning != nil {
134+
// We know that returning exprs are resolved here, because you can't call Schema() safely until Resolved() is true.
135+
returningSchema := sql.Schema{}
136+
for _, expr := range ii.Returning {
137+
returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, ""))
138+
}
139+
140+
return returningSchema
141+
}
142+
125143
return ii.Destination.Schema()
126144
}
127145

@@ -238,24 +256,30 @@ func (ii *InsertInto) DebugString() string {
238256

239257
// Expressions implements the sql.Expressioner interface.
240258
func (ii *InsertInto) Expressions() []sql.Expression {
241-
return append(ii.OnDupExprs, ii.checks.ToExpressions()...)
259+
exprs := append(ii.OnDupExprs, ii.checks.ToExpressions()...)
260+
exprs = append(exprs, ii.Returning...)
261+
return exprs
242262
}
243263

244264
// WithExpressions implements the sql.Expressioner interface.
245265
func (ii *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) {
246-
if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks) {
247-
return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks))
266+
if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning) {
267+
return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning))
248268
}
249269

250270
nii := *ii
251271
nii.OnDupExprs = newExprs[:len(nii.OnDupExprs)]
272+
newExprs = newExprs[len(nii.OnDupExprs):]
252273

253274
var err error
254-
nii.checks, err = nii.checks.FromExpressions(newExprs[len(nii.OnDupExprs):])
275+
nii.checks, err = nii.checks.FromExpressions(newExprs[:len(nii.checks)])
255276
if err != nil {
256277
return nil, err
257278
}
258279

280+
newExprs = newExprs[len(nii.checks):]
281+
nii.Returning = newExprs
282+
259283
return &nii, nil
260284
}
261285

@@ -264,17 +288,15 @@ func (ii *InsertInto) Resolved() bool {
264288
if !ii.Destination.Resolved() || !ii.Source.Resolved() {
265289
return false
266290
}
267-
for _, updateExpr := range ii.OnDupExprs {
268-
if !updateExpr.Resolved() {
269-
return false
270-
}
271-
}
291+
272292
for _, checkExpr := range ii.checks {
273293
if !checkExpr.Expr.Resolved() {
274294
return false
275295
}
276296
}
277-
return true
297+
298+
return expression.ExpressionsResolved(ii.OnDupExprs...) &&
299+
expression.ExpressionsResolved(ii.Returning...)
278300
}
279301

280302
// InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be

sql/planbuilder/dml.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
150150
ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore)
151151
ins.LiteralValueSource = srcLiteralOnly
152152

153+
if i.Returning != nil {
154+
returningExprs := make([]sql.Expression, len(i.Returning))
155+
for i, selectExpr := range i.Returning {
156+
returningExprs[i] = b.selectExprToExpression(destScope, selectExpr)
157+
}
158+
ins.Returning = returningExprs
159+
}
160+
153161
b.validateInsert(ins)
154162

155163
outScope = destScope

sql/rowexec/dml.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
9090
ctx: ctx,
9191
ignore: ii.Ignore,
9292
firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx,
93+
returnExprs: ii.Returning,
94+
returnSchema: ii.Schema(),
9395
deferredDefaults: ii.DeferredDefaults,
9496
}
9597

sql/rowexec/dml_iters.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -590,17 +590,30 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
590590
childIter := i.GetChildIter()
591591
childIter, sch := AddAccumulatorIter(ctx, childIter)
592592
return i.WithChildIter(childIter), sch
593-
default:
594-
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
595-
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
596-
if rowHandler == nil {
597-
return iter, nil
593+
case *plan.TableEditorIter:
594+
// If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter
595+
innerIter := i.InnerIter()
596+
if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 {
597+
return insertIter, insertIter.returnSchema
598598
}
599-
return &accumulatorIter{
600-
iter: iter,
601-
updateRowHandler: rowHandler,
602-
}, types.OkResultSchema
599+
600+
return defaultAccumulatorIter(ctx, iter)
601+
default:
602+
return defaultAccumulatorIter(ctx, iter)
603+
}
604+
}
605+
606+
// defaultAccumulatorIter returns the default accumulator iter for a DML node
607+
func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) {
608+
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
609+
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
610+
if rowHandler == nil {
611+
return iter, nil
603612
}
613+
return &accumulatorIter{
614+
iter: iter,
615+
updateRowHandler: rowHandler,
616+
}, types.OkResultSchema
604617
}
605618

606619
func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {

sql/rowexec/insert.go

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,21 @@ import (
3030
)
3131

3232
type insertIter struct {
33-
schema sql.Schema
34-
inserter sql.RowInserter
35-
replacer sql.RowReplacer
36-
updater sql.RowUpdater
37-
rowSource sql.RowIter
38-
unlocker func()
39-
ctx *sql.Context
40-
insertExprs []sql.Expression
41-
updateExprs []sql.Expression
42-
checks sql.CheckConstraints
43-
tableNode sql.Node
44-
closed bool
45-
ignore bool
33+
schema sql.Schema
34+
inserter sql.RowInserter
35+
replacer sql.RowReplacer
36+
updater sql.RowUpdater
37+
rowSource sql.RowIter
38+
unlocker func()
39+
ctx *sql.Context
40+
insertExprs []sql.Expression
41+
updateExprs []sql.Expression
42+
checks sql.CheckConstraints
43+
tableNode sql.Node
44+
closed bool
45+
ignore bool
46+
returnExprs []sql.Expression
47+
returnSchema sql.Schema
4648

4749
firstGeneratedAutoIncRowIdx int
4850

@@ -175,6 +177,18 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)
175177

176178
i.updateLastInsertId(ctx, row)
177179

180+
if len(i.returnExprs) > 0 {
181+
var retExprRow sql.Row
182+
for _, returnExpr := range i.returnExprs {
183+
result, err := returnExpr.Eval(ctx, row)
184+
if err != nil {
185+
return nil, err
186+
}
187+
retExprRow = append(retExprRow, result)
188+
}
189+
return retExprRow, nil
190+
}
191+
178192
return row, nil
179193
}
180194

0 commit comments

Comments
 (0)