Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions enginetest/queries/insert_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2278,6 +2278,77 @@ var InsertScripts = []ScriptTest{
},
},
},
{
// https://github.com/dolthub/dolt/issues/9895
Name: "insert...returning works with after triggers",
Dialect: "mysql", // actually mariadb
SetUpScript: []string{
"create table parent (id int not null auto_increment, primary key (id))",
"create table child(parent_id int not null, version_id int not null auto_increment, primary key (version_id))",
"insert into parent () values ()",
"create trigger trg_child_after_insert after insert on child for each row update parent set id = (id + 1) where id = NEW.parent_id",
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into child (parent_id) values (1) returning version_id",
Expected: []sql.Row{{1}},
},
{
Query: "select * from parent",
Expected: []sql.Row{{2}},
},
{
Query: "insert into child (parent_id) values (2) returning version_id",
Expected: []sql.Row{{2}},
},
{
Query: "select * from parent",
Expected: []sql.Row{{3}},
},
{
// https://github.com/dolthub/dolt/issues/9907
Skip: true,
Query: "insert into child (parent_id) values ((select id from parent limit 1)) returning parent_id, version_id",
// TODO: update to actual error
ExpectedErr: nil,
},
},
},
{
Name: "insert...returning works with before triggers",
Dialect: "mysql", // actually mariadb
SetUpScript: []string{
"create table parent (id int not null auto_increment, primary key (id))",
"create table child(parent_id int not null, version_id int not null auto_increment, primary key (version_id))",
"insert into parent () values ()",
"create trigger trg_child_before_insert before insert on child for each row update parent set id = (id + 1) where id = NEW.parent_id",
},
Assertions: []ScriptTestAssertion{
{
Query: "insert into child (parent_id) values (1) returning version_id",
Expected: []sql.Row{{1}},
},
{
Query: "select * from parent",
Expected: []sql.Row{{2}},
},
{
Query: "insert into child (parent_id) values (2) returning version_id",
Expected: []sql.Row{{2}},
},
{
Query: "select * from parent",
Expected: []sql.Row{{3}},
},
{
//https://github.com/dolthub/dolt/issues/9907
Skip: true,
Query: "insert into child (parent_id) values ((select id from parent limit 1)) returning version_id",
// TODO: update to actual error
ExpectedErr: nil,
},
},
},
}

var InsertDuplicateKeyKeyless = []ScriptTest{
Expand Down
2 changes: 2 additions & 0 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
})
return n.WithSource(triggerExecutor), transform.NewTree, nil
} else {
n.HasAfterTrigger = true
return plan.NewTriggerExecutor(n, triggerLogic, plan.InsertTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
CreateStatement: trigger.CreateTriggerString,
Expand Down Expand Up @@ -438,6 +439,7 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
node, err := n.WithChildren(triggerExecutor)
return node, transform.NewTree, err
} else {
// TODO: add HasAfterTrigger flag for DeleteFrom node once DELETE...RETURNING has been implemented
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do have DELETE ... RETURNING support in GMS, but it may only be tested currently through Doltgres' test suite.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not yet fully implemented for GMS and Dolt (we don't currently support the grammar in vitess) so I'll leave it as a TODO for now

return plan.NewTriggerExecutor(n, triggerLogic, plan.DeleteTrigger, plan.TriggerTime(trigger.TriggerTime), sql.TriggerDefinition{
Name: trigger.TriggerName,
CreateStatement: trigger.CreateTriggerString,
Expand Down
1 change: 1 addition & 0 deletions sql/plan/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ type InsertInto struct {
// LiteralValueSource is set to |true| when |Source| is
// a |Values| node with only literal expressions.
LiteralValueSource bool
HasAfterTrigger bool
}

var _ sql.Databaser = (*InsertInto)(nil)
Expand Down
1 change: 1 addition & 0 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
returnExprs: ii.Returning,
returnSchema: ii.Schema(),
deferredDefaults: ii.DeferredDefaults,
hasAfterTrigger: ii.HasAfterTrigger,
}

var ed sql.EditOpenerCloser
Expand Down
40 changes: 36 additions & 4 deletions sql/rowexec/dml_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
logicRow = row
}

if returnRow, err, ok := t.getReturningRow(ctx, childRow); ok {
return returnRow, err
}

// For some logic statements, we want to return the result of the logic operation as our row, e.g. a Set that alters
// the fields of the new row
if ok, returnRow := shouldUseLogicResult(logic, logicRow); ok {
Expand All @@ -293,6 +297,20 @@ func (t *triggerIter) Next(ctx *sql.Context) (row sql.Row, returnErr error) {
return childRow, nil
}

func (t *triggerIter) getReturningRow(ctx *sql.Context, row sql.Row) (sql.Row, error, bool) {
if tableEditor, isTableEditor := t.child.(*plan.TableEditorIter); isTableEditor {
// TODO: get returning rows for REPLACE and DELETE once REPLACE...RETURNING and DELETE...RETURNING have been
// implemented
if insert, isInsert := tableEditor.InnerIter().(*insertIter); isInsert {
if len(insert.returnExprs) > 0 {
retRow, err := insert.getReturningRow(ctx, row)
return retRow, err, true
}
}
}
return nil, nil, false
}

func shouldUseLogicResult(logic sql.Node, row sql.Row) (bool, sql.Row) {
switch logic := logic.(type) {
// TODO: are there other statement types that we should use here?
Expand Down Expand Up @@ -624,11 +642,25 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
return innerIter, innerIter.returnSchema
}
}

return defaultAccumulatorIter(ctx, iter)
default:
return defaultAccumulatorIter(ctx, iter)
case *triggerIter:
if tableEditor, ok := i.child.(*plan.TableEditorIter); ok {
switch innerIter := tableEditor.InnerIter().(type) {
case *insertIter:
if len(innerIter.returnExprs) > 0 {
return i, innerIter.returnSchema
}
case *updateIter:
if len(innerIter.returnExprs) > 0 {
return i, innerIter.returnSchema
}
case *deleteIter:
if len(innerIter.returnExprs) > 0 {
return i, innerIter.returnSchema
}
}
}
}
return defaultAccumulatorIter(ctx, iter)
}

// defaultAccumulatorIter returns the default accumulator iter for a DML node
Expand Down
25 changes: 15 additions & 10 deletions sql/rowexec/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type insertIter struct {
rowNumber int64
closed bool
ignore bool
hasAfterTrigger bool
}

func getInsertExpressions(values sql.Node) []sql.Expression {
Expand Down Expand Up @@ -217,21 +218,25 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)

i.updateLastInsertId(ctx, row)

if len(i.returnExprs) > 0 {
var retExprRow sql.Row
for _, returnExpr := range i.returnExprs {
result, err := returnExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
retExprRow = append(retExprRow, result)
}
return retExprRow, nil
if len(i.returnExprs) > 0 && !i.hasAfterTrigger {
return i.getReturningRow(ctx, row)
}

return row, nil
}

func (i *insertIter) getReturningRow(ctx *sql.Context, row sql.Row) (sql.Row, error) {
var retExprRow sql.Row
for _, returnExpr := range i.returnExprs {
result, err := returnExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
retExprRow = append(retExprRow, result)
}
return retExprRow, nil
}

func (i *insertIter) handleOnDuplicateKeyUpdate(ctx *sql.Context, oldRow, newRow sql.Row) (returnRow sql.Row, returnErr error) {
var err error
updateAcc := append(oldRow, newRow...)
Expand Down