Skip to content

Commit 4a6f3fb

Browse files
authored
Merge pull request #3173 from dolthub/angela/triggers
Mask column names from non-triggered tables in UpdateJoins
2 parents a369458 + 98e4cba commit 4a6f3fb

File tree

2 files changed

+28
-37
lines changed

2 files changed

+28
-37
lines changed

enginetest/queries/update_queries.go

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,8 @@ var UpdateScriptTests = []ScriptTest{
598598
},
599599
{
600600
Dialect: "mysql",
601-
Name: "UPDATE join – multiple tables with same column names with triggers",
601+
// https://github.com/dolthub/dolt/issues/9403
602+
Name: "UPDATE join – multiple tables with same column names with triggers",
602603
SetUpScript: []string{
603604
"create table customers (id int primary key, name text, tier text)",
604605
"create table orders (id int primary key, customer_id int, status text)",
@@ -615,19 +616,11 @@ var UpdateScriptTests = []ScriptTest{
615616
end;`,
616617
"insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');",
617618
"insert into orders values (101, 1, 'pending'), (102, 2, 'pending');",
619+
"update customers c join orders o on c.id = o.customer_id " +
620+
"set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
618621
},
619622
Assertions: []ScriptTestAssertion{
620623
{
621-
Query: "update customers c join orders o on c.id = o.customer_id " +
622-
"set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
623-
// TODO: we shouldn't expect an error once we're able to handle conflicting column names
624-
// https://github.com/dolthub/dolt/issues/9403
625-
ExpectedErrStr: "Unable to apply triggers when joined tables have columns with the same name",
626-
},
627-
{
628-
// TODO: unskip once we're able to handle conflicting column names
629-
// https://github.com/dolthub/dolt/issues/9403
630-
Skip: true,
631624
Query: "SELECT * FROM trigger_log order by msg;",
632625
Expected: []sql.Row{
633626
{"Customer 1 tier changed from silver to platinum"},

sql/analyzer/triggers.go

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package analyzer
1616

1717
import (
18-
"errors"
1918
"fmt"
2019
"strings"
2120

@@ -450,15 +449,21 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
450449
})
451450
}
452451

453-
func getUpdateJoinSource(n sql.Node) *plan.UpdateSource {
452+
// getUpdateJoinSource looks for an UpdateJoin child in an Update node and get the UpdateSource and a map of table
453+
// aliases
454+
func getUpdateJoinSource(n sql.Node) (*plan.UpdateSource, map[string]string) {
454455
if updateNode, isUpdate := n.(*plan.Update); isUpdate {
455456
if updateJoin, isUpdateJoin := updateNode.Child.(*plan.UpdateJoin); isUpdateJoin {
456457
if updateSrc, isUpdateSrc := updateJoin.Child.(*plan.UpdateSource); isUpdateSrc {
457-
return updateSrc
458+
tableAliases := make(map[string]string)
459+
for alias, updateTarget := range updateJoin.UpdateTargets {
460+
tableAliases[alias] = getTableName(updateTarget)
461+
}
462+
return updateSrc, tableAliases
458463
}
459464
}
460465
}
461-
return nil
466+
return nil, nil
462467
}
463468

464469
// getTriggerLogic analyzes and returns the Node representing the trigger body for the trigger given, applied to the
@@ -481,7 +486,7 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
481486
triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, DefaultRuleSelector, qFlags)
482487
case sqlparser.UpdateStr:
483488
var scopeNode *plan.Project
484-
if updateSrc := getUpdateJoinSource(n); updateSrc == nil {
489+
if updateSrc, tableAliases := getUpdateJoinSource(n); updateSrc == nil {
485490
scopeNode = plan.NewProject(
486491
[]sql.Expression{expression.NewStar()},
487492
plan.NewCrossJoin(
@@ -490,18 +495,24 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
490495
),
491496
)
492497
} else {
493-
// TODO: We should be able to handle duplicate column names by masking columns that aren't part of the
494-
// triggered table https://github.com/dolthub/dolt/issues/9403
495-
err = validateNoConflictingColumnNames(updateSrc.Child.Schema())
496-
if err != nil {
497-
return nil, err
498+
updateSrcCols := updateSrc.Child.Schema()
499+
triggerTableName := getTableName(trigger.Table)
500+
maskedColNames := make([]string, len(updateSrcCols))
501+
for i, col := range updateSrcCols {
502+
// To avoid confusion when joined tables share a column name, we mask the column names from
503+
// non-triggered tables
504+
if col.Source == triggerTableName || tableAliases[col.Source] == triggerTableName {
505+
maskedColNames[i] = col.Name
506+
} else {
507+
maskedColNames[i] = ""
508+
}
498509
}
499-
// The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old.
510+
// The scopeNode for an UpdateJoin should contain every column in the updateSource as new and old.
500511
scopeNode = plan.NewProject(
501512
[]sql.Expression{expression.NewStar()},
502513
plan.NewCrossJoin(
503-
plan.NewSubqueryAlias("old", "", updateSrc.Child),
504-
plan.NewSubqueryAlias("new", "", updateSrc.Child),
514+
plan.NewSubqueryAlias("old", "", updateSrc.Child).WithColumnNames(maskedColNames),
515+
plan.NewSubqueryAlias("new", "", updateSrc.Child).WithColumnNames(maskedColNames),
505516
),
506517
)
507518
}
@@ -521,19 +532,6 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
521532
return triggerLogic, err
522533
}
523534

524-
// validateNoConflictingColumnNames checks the columns of a joined table to make sure there are no conflicting column
525-
// names
526-
func validateNoConflictingColumnNames(sch sql.Schema) error {
527-
columnNames := make(map[string]struct{})
528-
for _, col := range sch {
529-
if _, ok := columnNames[col.Name]; ok {
530-
return errors.New("Unable to apply triggers when joined tables have columns with the same name")
531-
}
532-
columnNames[col.Name] = struct{}{}
533-
}
534-
return nil
535-
}
536-
537535
// validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any
538536
// table being updated in an outer scope of this analysis)
539537
func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *plan.Scope) error {

0 commit comments

Comments
 (0)