Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions enginetest/queries/update_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ var UpdateScriptTests = []ScriptTest{
},
{
Dialect: "mysql",
Name: "UPDATE join – multiple tables with same column names with triggers",
// https://github.com/dolthub/dolt/issues/9403
Name: "UPDATE join – multiple tables with same column names with triggers",
SetUpScript: []string{
"create table customers (id int primary key, name text, tier text)",
"create table orders (id int primary key, customer_id int, status text)",
Expand All @@ -615,19 +616,11 @@ var UpdateScriptTests = []ScriptTest{
end;`,
"insert into customers values(1, 'Alice', 'silver'), (2, 'Bob', 'gold');",
"insert into orders values (101, 1, 'pending'), (102, 2, 'pending');",
"update customers c join orders o on c.id = o.customer_id " +
"set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
},
Assertions: []ScriptTestAssertion{
{
Query: "update customers c join orders o on c.id = o.customer_id " +
"set c.tier = 'platinum', o.status = 'shipped' where o.status = 'pending'",
// TODO: we shouldn't expect an error once we're able to handle conflicting column names
// https://github.com/dolthub/dolt/issues/9403
ExpectedErrStr: "Unable to apply triggers when joined tables have columns with the same name",
},
{
// TODO: unskip once we're able to handle conflicting column names
// https://github.com/dolthub/dolt/issues/9403
Skip: true,
Query: "SELECT * FROM trigger_log order by msg;",
Expected: []sql.Row{
{"Customer 1 tier changed from silver to platinum"},
Expand Down
50 changes: 24 additions & 26 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package analyzer

import (
"errors"
"fmt"
"strings"

Expand Down Expand Up @@ -450,15 +449,21 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
})
}

func getUpdateJoinSource(n sql.Node) *plan.UpdateSource {
// getUpdateJoinSource looks for an UpdateJoin child in an Update node and get the UpdateSource and a map of table
// aliases
func getUpdateJoinSource(n sql.Node) (*plan.UpdateSource, map[string]string) {
if updateNode, isUpdate := n.(*plan.Update); isUpdate {
if updateJoin, isUpdateJoin := updateNode.Child.(*plan.UpdateJoin); isUpdateJoin {
if updateSrc, isUpdateSrc := updateJoin.Child.(*plan.UpdateSource); isUpdateSrc {
return updateSrc
tableAliases := make(map[string]string)
for alias, updateTarget := range updateJoin.UpdateTargets {
tableAliases[alias] = getTableName(updateTarget)
}
return updateSrc, tableAliases
}
}
}
return nil
return nil, nil
}

// getTriggerLogic analyzes and returns the Node representing the trigger body for the trigger given, applied to the
Expand All @@ -481,7 +486,7 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, DefaultRuleSelector, qFlags)
case sqlparser.UpdateStr:
var scopeNode *plan.Project
if updateSrc := getUpdateJoinSource(n); updateSrc == nil {
if updateSrc, tableAliases := getUpdateJoinSource(n); updateSrc == nil {
scopeNode = plan.NewProject(
[]sql.Expression{expression.NewStar()},
plan.NewCrossJoin(
Expand All @@ -490,18 +495,24 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
),
)
} else {
// TODO: We should be able to handle duplicate column names by masking columns that aren't part of the
// triggered table https://github.com/dolthub/dolt/issues/9403
err = validateNoConflictingColumnNames(updateSrc.Child.Schema())
if err != nil {
return nil, err
updateSrcCols := updateSrc.Child.Schema()
triggerTableName := getTableName(trigger.Table)
maskedColNames := make([]string, len(updateSrcCols))
for i, col := range updateSrcCols {
// To avoid confusion when joined tables share a column name, we mask the column names from
// non-triggered tables
if col.Source == triggerTableName || tableAliases[col.Source] == triggerTableName {
maskedColNames[i] = col.Name
} else {
maskedColNames[i] = ""
}
}
// The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old.
// The scopeNode for an UpdateJoin should contain every column in the updateSource as new and old.
scopeNode = plan.NewProject(
[]sql.Expression{expression.NewStar()},
plan.NewCrossJoin(
plan.NewSubqueryAlias("old", "", updateSrc.Child),
plan.NewSubqueryAlias("new", "", updateSrc.Child),
plan.NewSubqueryAlias("old", "", updateSrc.Child).WithColumnNames(maskedColNames),
plan.NewSubqueryAlias("new", "", updateSrc.Child).WithColumnNames(maskedColNames),
),
)
}
Expand All @@ -521,19 +532,6 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
return triggerLogic, err
}

// validateNoConflictingColumnNames checks the columns of a joined table to make sure there are no conflicting column
// names
func validateNoConflictingColumnNames(sch sql.Schema) error {
columnNames := make(map[string]struct{})
for _, col := range sch {
if _, ok := columnNames[col.Name]; ok {
return errors.New("Unable to apply triggers when joined tables have columns with the same name")
}
columnNames[col.Name] = struct{}{}
}
return nil
}

// validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any
// table being updated in an outer scope of this analysis)
func validateNoCircularUpdates(trigger *plan.CreateTrigger, n sql.Node, scope *plan.Scope) error {
Expand Down
Loading