Skip to content

Commit 802dd50

Browse files
authored
Merge pull request #3245 from dolthub/angela/insert_returning
Allow triggers to fire with `INSERT...RETURNING` statements
2 parents 19e524c + 9014b19 commit 802dd50

File tree

6 files changed

+126
-14
lines changed

6 files changed

+126
-14
lines changed

enginetest/queries/insert_queries.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,6 +2278,77 @@ var InsertScripts = []ScriptTest{
22782278
},
22792279
},
22802280
},
2281+
{
2282+
// https://github.com/dolthub/dolt/issues/9895
2283+
Name: "insert...returning works with after triggers",
2284+
Dialect: "mysql", // actually mariadb
2285+
SetUpScript: []string{
2286+
"create table parent (id int not null auto_increment, primary key (id))",
2287+
"create table child(parent_id int not null, version_id int not null auto_increment, primary key (version_id))",
2288+
"insert into parent () values ()",
2289+
"create trigger trg_child_after_insert after insert on child for each row update parent set id = (id + 1) where id = NEW.parent_id",
2290+
},
2291+
Assertions: []ScriptTestAssertion{
2292+
{
2293+
Query: "insert into child (parent_id) values (1) returning version_id",
2294+
Expected: []sql.Row{{1}},
2295+
},
2296+
{
2297+
Query: "select * from parent",
2298+
Expected: []sql.Row{{2}},
2299+
},
2300+
{
2301+
Query: "insert into child (parent_id) values (2) returning version_id",
2302+
Expected: []sql.Row{{2}},
2303+
},
2304+
{
2305+
Query: "select * from parent",
2306+
Expected: []sql.Row{{3}},
2307+
},
2308+
{
2309+
// https://github.com/dolthub/dolt/issues/9907
2310+
Skip: true,
2311+
Query: "insert into child (parent_id) values ((select id from parent limit 1)) returning parent_id, version_id",
2312+
// TODO: update to actual error
2313+
ExpectedErr: nil,
2314+
},
2315+
},
2316+
},
2317+
{
2318+
Name: "insert...returning works with before triggers",
2319+
Dialect: "mysql", // actually mariadb
2320+
SetUpScript: []string{
2321+
"create table parent (id int not null auto_increment, primary key (id))",
2322+
"create table child(parent_id int not null, version_id int not null auto_increment, primary key (version_id))",
2323+
"insert into parent () values ()",
2324+
"create trigger trg_child_before_insert before insert on child for each row update parent set id = (id + 1) where id = NEW.parent_id",
2325+
},
2326+
Assertions: []ScriptTestAssertion{
2327+
{
2328+
Query: "insert into child (parent_id) values (1) returning version_id",
2329+
Expected: []sql.Row{{1}},
2330+
},
2331+
{
2332+
Query: "select * from parent",
2333+
Expected: []sql.Row{{2}},
2334+
},
2335+
{
2336+
Query: "insert into child (parent_id) values (2) returning version_id",
2337+
Expected: []sql.Row{{2}},
2338+
},
2339+
{
2340+
Query: "select * from parent",
2341+
Expected: []sql.Row{{3}},
2342+
},
2343+
{
2344+
//https://github.com/dolthub/dolt/issues/9907
2345+
Skip: true,
2346+
Query: "insert into child (parent_id) values ((select id from parent limit 1)) returning version_id",
2347+
// TODO: update to actual error
2348+
ExpectedErr: nil,
2349+
},
2350+
},
2351+
},
22812352
}
22822353

22832354
var InsertDuplicateKeyKeyless = []ScriptTest{

sql/analyzer/triggers.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
394394
})
395395
return n.WithSource(triggerExecutor), transform.NewTree, nil
396396
} else {
397+
n.HasAfterTrigger = true
397398
return plan.NewTriggerExecutor(n, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
398399
Name: trigger.TriggerName,
399400
CreateStatement: trigger.CreateTriggerString,
@@ -438,6 +439,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
438439
node, err := n.WithChildren(triggerExecutor)
439440
return node, transform.NewTree, err
440441
} else {
442+
// TODO: add HasAfterTrigger flag for DeleteFrom node once DELETE...RETURNING has been implemented
441443
return plan.NewTriggerExecutor(n, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
442444
Name: trigger.TriggerName,
443445
CreateStatement: trigger.CreateTriggerString,

sql/plan/insert.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ type InsertInto struct {
8181
// LiteralValueSource is set to |true| when |Source| is
8282
// a |Values| node with only literal expressions.
8383
LiteralValueSource bool
84+
HasAfterTrigger bool
8485
}
8586

8687
var _ sql.Databaser = (*InsertInto)(nil)

sql/rowexec/dml.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
9393
returnExprs: ii.Returning,
9494
returnSchema: ii.Schema(),
9595
deferredDefaults: ii.DeferredDefaults,
96+
hasAfterTrigger: ii.HasAfterTrigger,
9697
}
9798

9899
var ed sql.EditOpenerCloser

sql/rowexec/dml_iters.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
284284
logicRow = row
285285
}
286286

287+
if returnRow, err, ok := t.getReturningRow(ctx, childRow); ok {
288+
return returnRow, err
289+
}
290+
287291
// For some logic statements, we want to return the result of the logic operation as our row, e.g. a Set that alters
288292
// the fields of the new row
289293
if ok, returnRow := shouldUseLogicResult(logic, logicRow); ok {
@@ -293,6 +297,20 @@ func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
293297
return childRow, nil
294298
}
295299

300+
func (t *triggerIter) getReturningRow(ctx *sql.Context, row sql.Row) (sql.Row, error, bool) {
301+
if tableEditor, isTableEditor := t.child.(*plan.TableEditorIter); isTableEditor {
302+
// TODO: get returning rows for REPLACE and DELETE once REPLACE...RETURNING and DELETE...RETURNING have been
303+
// implemented
304+
if insert, isInsert := tableEditor.InnerIter().(*insertIter); isInsert {
305+
if len(insert.returnExprs) > 0 {
306+
retRow, err := insert.getReturningRow(ctx, row)
307+
return retRow, err, true
308+
}
309+
}
310+
}
311+
return nil, nil, false
312+
}
313+
296314
func shouldUseLogicResult(logic sql.Node, row sql.Row) (bool, sql.Row) {
297315
switch logic := logic.(type) {
298316
// TODO: are there other statement types that we should use here?
@@ -624,11 +642,25 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
624642
return innerIter, innerIter.returnSchema
625643
}
626644
}
627-
628-
return defaultAccumulatorIter(ctx, iter)
629-
default:
630-
return defaultAccumulatorIter(ctx, iter)
645+
case *triggerIter:
646+
if tableEditor, ok := i.child.(*plan.TableEditorIter); ok {
647+
switch innerIter := tableEditor.InnerIter().(type) {
648+
case *insertIter:
649+
if len(innerIter.returnExprs) > 0 {
650+
return i, innerIter.returnSchema
651+
}
652+
case *updateIter:
653+
if len(innerIter.returnExprs) > 0 {
654+
return i, innerIter.returnSchema
655+
}
656+
case *deleteIter:
657+
if len(innerIter.returnExprs) > 0 {
658+
return i, innerIter.returnSchema
659+
}
660+
}
661+
}
631662
}
663+
return defaultAccumulatorIter(ctx, iter)
632664
}
633665

634666
// defaultAccumulatorIter returns the default accumulator iter for a DML node

sql/rowexec/insert.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ type insertIter struct {
5252
rowNumber int64
5353
closed bool
5454
ignore bool
55+
hasAfterTrigger bool
5556
}
5657

5758
func getInsertExpressions(values sql.Node) []sql.Expression {
@@ -217,21 +218,25 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)
217218

218219
i.updateLastInsertId(ctx, row)
219220

220-
if len(i.returnExprs) > 0 {
221-
var retExprRow sql.Row
222-
for _, returnExpr := range i.returnExprs {
223-
result, err := returnExpr.Eval(ctx, row)
224-
if err != nil {
225-
return nil, err
226-
}
227-
retExprRow = append(retExprRow, result)
228-
}
229-
return retExprRow, nil
221+
if len(i.returnExprs) > 0 && !i.hasAfterTrigger {
222+
return i.getReturningRow(ctx, row)
230223
}
231224

232225
return row, nil
233226
}
234227

228+
func (i *insertIter) getReturningRow(ctx *sql.Context, row sql.Row) (sql.Row, error) {
229+
var retExprRow sql.Row
230+
for _, returnExpr := range i.returnExprs {
231+
result, err := returnExpr.Eval(ctx, row)
232+
if err != nil {
233+
return nil, err
234+
}
235+
retExprRow = append(retExprRow, result)
236+
}
237+
return retExprRow, nil
238+
}
239+
235240
func (i *insertIter) handleOnDuplicateKeyUpdate(ctx *sql.Context, oldRow, newRow sql.Row) (returnRow sql.Row, returnErr error) {
236241
var err error
237242
updateAcc := append(oldRow, newRow...)

0 commit comments

Comments
 (0)