Skip to content

Commit 22622a0

Browse files
authored
Merge pull request #3042 from dolthub/angela/triggers
Apply multiple triggers to `UPDATE JOIN`
2 parents 7615f0d + e1abac3 commit 22622a0

File tree

4 files changed

+121
-41
lines changed

4 files changed

+121
-41
lines changed

enginetest/queries/script_queries.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,15 +517,15 @@ SET entity_test.value = joined.value;`,
517517
Expected: []sql.Row{{1, "john", "doe", 0, 42}},
518518
},
519519
{
520-
Query: "UPDATE test_users JOIN (SELECT id, 1 FROM test_users) AS tu SET test_users.favorite_number = 420;",
520+
Query: "UPDATE test_users JOIN (SELECT 1 FROM test_users) AS tu SET test_users.favorite_number = 420;",
521521
Expected: []sql.Row{{NewUpdateResult(1, 1)}},
522522
},
523523
{
524524
Query: "select * from test_users;",
525525
Expected: []sql.Row{{1, "john", "doe", 0, 420}},
526526
},
527527
{
528-
Query: "UPDATE test_users JOIN (SELECT id, 1 FROM test_users) AS tu SET test_users.deleted = 1;",
528+
Query: "UPDATE test_users JOIN (SELECT 1 FROM test_users) AS tu SET test_users.deleted = 1;",
529529
Expected: []sql.Row{{NewUpdateResult(1, 1)}},
530530
},
531531
{

enginetest/queries/update_queries.go

Lines changed: 78 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ var UpdateScriptTests = []ScriptTest{
534534
Name: "UPDATE join – multiple tables, with trigger",
535535
SetUpScript: []string{
536536
"CREATE TABLE a (id INT PRIMARY KEY, x INT);",
537-
"CREATE TABLE b (id INT PRIMARY KEY, y INT);",
537+
"CREATE TABLE b (pk INT PRIMARY KEY, y INT);",
538538
"CREATE TABLE logbook (entry TEXT);",
539539
`CREATE TRIGGER trig_a AFTER UPDATE ON a FOR EACH ROW
540540
BEGIN
@@ -550,13 +550,10 @@ var UpdateScriptTests = []ScriptTest{
550550
Assertions: []ScriptTestAssertion{
551551
{
552552
Query: `UPDATE a
553-
JOIN b ON a.id = 5 AND b.id = 6
553+
JOIN b ON a.id = 5 AND b.pk = 6
554554
SET a.x = 101, b.y = 201;`,
555555
},
556556
{
557-
// TODO: UPDATE ... JOIN does not properly apply triggers when multiple tables are being updated,
558-
// and will currently only apply triggers from one of the tables.
559-
Skip: true,
560557
Query: "SELECT * FROM logbook ORDER BY entry;",
561558
Expected: []sql.Row{
562559
{"a updated"},
@@ -565,6 +562,82 @@ var UpdateScriptTests = []ScriptTest{
565562
},
566563
},
567564
},
565+
{
566+
Dialect: "mysql",
567+
Name: "UPDATE join – multiple tables with triggers that reference row values",
568+
SetUpScript: []string{
569+
"create table customers (id int primary key, name text, tier text)",
570+
"create table orders (order_id int primary key, customer_id int, status text)",
571+
"create table trigger_log (msg text)",
572+
`CREATE TRIGGER after_orders_update after update on orders for each row
573+
begin
574+
insert into trigger_log (msg) values(
575+
concat('Order ', OLD.order_id, ' status changed from ', OLD.status, ' to ', NEW.status));
576+
end;`,
577+
`Create trigger after_customers_update after update on customers for each row
578+
begin
579+
insert into trigger_log (msg) values(
580+
concat('Customer ', OLD.id, ' tier changed from ', OLD.tier, ' to ', NEW.tier));
581+
end;`,
582+
"insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');",
583+
"insert into orders values (101, 1, 'pending'), (102, 2, 'pending');",
584+
"update customers c join orders o on c.id = o.customer_id " +
585+
"set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
586+
},
587+
Assertions: []ScriptTestAssertion{
588+
{
589+
Query: "SELECT * FROM trigger_log order by msg;",
590+
Expected: []sql.Row{
591+
{"Customer 1 tier changed from silver to platinum"},
592+
{"Customer 2 tier changed from gold to platinum"},
593+
{"Order 101 status changed from pending to shipped"},
594+
{"Order 102 status changed from pending to shipped"},
595+
},
596+
},
597+
},
598+
},
599+
{
600+
Dialect: "mysql",
601+
Name: "UPDATE join – multiple tables with same column names with triggers",
602+
SetUpScript: []string{
603+
"create table customers (id int primary key, name text, tier text)",
604+
"create table orders (id int primary key, customer_id int, status text)",
605+
"create table trigger_log (msg text)",
606+
`CREATE TRIGGER after_orders_update after update on orders for each row
607+
begin
608+
insert into trigger_log (msg) values(
609+
concat('Order ', OLD.id, ' status changed from ', OLD.status, ' to ', NEW.status));
610+
end;`,
611+
`Create trigger after_customers_update after update on customers for each row
612+
begin
613+
insert into trigger_log (msg) values(
614+
concat('Customer ', OLD.id, ' tier changed from ', OLD.tier, ' to ', NEW.tier));
615+
end;`,
616+
"insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');",
617+
"insert into orders values (101, 1, 'pending'), (102, 2, 'pending');",
618+
},
619+
Assertions: []ScriptTestAssertion{
620+
{
621+
Query: "update customers c join orders o on c.id = o.customer_id " +
622+
"set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
623+
// TODO: we shouldn't expect an error once we're able to handle conflicting column names
624+
// https://github.com/dolthub/dolt/issues/9403
625+
ExpectedErrStr: "Unable to apply triggers when joined tables have columns with the same name",
626+
},
627+
{
628+
// TODO: unskip once we're able to handle conflicting column names
629+
// https://github.com/dolthub/dolt/issues/9403
630+
Skip: true,
631+
Query: "SELECT * FROM trigger_log order by msg;",
632+
Expected: []sql.Row{
633+
{"Customer 1 tier changed from silver to platinum"},
634+
{"Customer 2 tier changed from gold to platinum"},
635+
{"Order 101 status changed from pending to shipped"},
636+
{"Order 102 status changed from pending to shipped"},
637+
},
638+
},
639+
},
640+
},
568641
}
569642

570643
var SpatialUpdateTests = []WriteQueryTest{

sql/analyzer/tables.go

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
"github.com/dolthub/go-mysql-server/sql/transform"
2323
)
2424

25-
// Returns the underlying table name for the node given
25+
// Returns the underlying table name, unaliased, for the node given
2626
func getTableName(node sql.Node) string {
2727
var tableName string
2828
transform.Inspect(node, func(node sql.Node) bool {
@@ -43,27 +43,6 @@ func getTableName(node sql.Node) string {
4343
return tableName
4444
}
4545

46-
// Returns the underlying table name for the node given, ignoring table aliases
47-
func getUnaliasedTableName(node sql.Node) string {
48-
var tableName string
49-
transform.Inspect(node, func(node sql.Node) bool {
50-
switch node := node.(type) {
51-
case *plan.ResolvedTable:
52-
tableName = node.Name()
53-
return false
54-
case *plan.UnresolvedTable:
55-
tableName = node.Name()
56-
return false
57-
case *plan.IndexedTableAccess:
58-
tableName = node.Name()
59-
return false
60-
}
61-
return true
62-
})
63-
64-
return tableName
65-
}
66-
6746
// Finds first table node that is a descendant of the node given
6847
func getTable(node sql.Node) sql.Table {
6948
var table sql.Table

sql/analyzer/triggers.go

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package analyzer
1616

1717
import (
18+
"errors"
1819
"fmt"
1920
"strings"
2021

@@ -158,7 +159,15 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
158159
db = n.Database().Name()
159160
}
160161
case *plan.Update:
161-
affectedTables = append(affectedTables, getTableName(n))
162+
if n.IsJoin {
163+
uj := n.Child.(*plan.UpdateJoin)
164+
updateTargets := uj.UpdateTargets
165+
for _, updateTarget := range updateTargets {
166+
affectedTables = append(affectedTables, getTableName(updateTarget))
167+
}
168+
} else {
169+
affectedTables = append(affectedTables, getTableName(n))
170+
}
162171
triggerEvent = plan.UpdateTrigger
163172
if n.Database() != "" {
164173
db = n.Database()
@@ -355,18 +364,18 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
355364
}
356365
}
357366

358-
return transform.NodeWithCtx(n, nil, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
367+
canApplyTriggerExecutor := func(c transform.Context) bool {
359368
// Don't double-apply trigger executors to the bodies of triggers. To avoid this, don't apply the trigger if the
360-
// parent is a trigger body.
361-
// TODO: this won't work for BEGIN END blocks, stored procedures, etc. For those, we need to examine all ancestors,
362-
// not just the immediate parent. Alternately, we could do something like not walk all children of some node types
363-
// (probably better).
369+
// parent is a trigger body. Having this as a selector function will also prevent walking the child nodes in the
370+
// trigger execution logic.
364371
if _, ok := c.Parent.(*plan.TriggerExecutor); ok {
365372
if c.ChildNum == 1 { // Right child is the trigger execution logic
366-
return c.Node, transform.SameTree, nil
373+
return false
367374
}
368375
}
369-
376+
return true
377+
}
378+
return transform.NodeWithCtx(n, canApplyTriggerExecutor, func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
370379
switch n := c.Node.(type) {
371380
case *plan.InsertInto:
372381
qFlags.Set(sql.QFlagTrigger)
@@ -404,9 +413,9 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
404413
// like we need something like a MultipleTriggerExecutor node
405414
// that could execute multiple triggers on the same row from its
406415
// wrapped iterator. There is also an issue with running triggers
407-
// because their field indexes assume the row they evalute will
416+
// because their field indexes assume the row they evaluate will
408417
// only ever contain the columns from the single table the trigger
409-
// is based on, but this isn't true with UPDATE JOIN or DELETE JOIN.
418+
// is based on.
410419
if n.HasExplicitTargets() {
411420
return nil, transform.SameTree, fmt.Errorf("delete from with explicit target tables " +
412421
"does not support triggers; retry with single table deletes")
@@ -472,6 +481,12 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
472481
),
473482
)
474483
} else {
484+
// TODO: We should be able to handle duplicate column names by masking columns that aren't part of the
485+
// triggered table https://github.com/dolthub/dolt/issues/9403
486+
err = validateNoConflictingColumnNames(updateSrc.Child.Schema())
487+
if err != nil {
488+
return nil, err
489+
}
475490
// The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old.
476491
scopeNode = plan.NewProject(
477492
[]sql.Expression{expression.NewStar()},
@@ -497,6 +512,19 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
497512
return triggerLogic, err
498513
}
499514

515+
// validateNoConflictingColumnNames checks the columns of a joined table to make sure there are no conflicting column
516+
// names
517+
func validateNoConflictingColumnNames(sch sql.Schema) error {
518+
columnNames := make(map[string]struct{})
519+
for _, col := range sch {
520+
if _, ok := columnNames[col.Name]; ok {
521+
return errors.New("Unable to apply triggers when joined tables have columns with the same name")
522+
}
523+
columnNames[col.Name] = struct{}{}
524+
}
525+
return nil
526+
}
527+
500528
// validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any
501529
// table being updated in an outer scope of this analysis)
502530
func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *plan.Scope) error {
@@ -505,8 +533,8 @@ func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *p
505533
switch node := node.(type) {
506534
case *plan.Update, *plan.InsertInto, *plan.DeleteFrom:
507535
for _, n := range append([]sql.Node{n}, scope.MemoNodes()...) {
508-
invokingTableName := getUnaliasedTableName(n)
509-
updatedTable := getUnaliasedTableName(node)
536+
invokingTableName := getTableName(n)
537+
updatedTable := getTableName(node)
510538
// TODO: need to compare DB as well
511539
if updatedTable == invokingTableName {
512540
circularRef = sql.ErrTriggerTableInUse.New(updatedTable)

0 commit comments

Comments
 (0)