Skip to content

Commit dd846ab

Browse files
committed
fixed bug where column names conflict but only works when Project is the direct child of UpdateSource. will revert. committing for future reference
1 parent a0af54c commit dd846ab

File tree

1 file changed

+45
-9
lines changed

1 file changed

+45
-9
lines changed

sql/analyzer/triggers.go

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
241241
return nil, transform.SameTree, err
242242
}
243243

244+
// triggerTable = getTableName(ct)
244245
var triggerTable string
245246
switch t := ct.Table.(type) {
246247
case *plan.ResolvedTable:
@@ -450,6 +451,39 @@ func getUpdateJoinSource(n sql.Node) *plan.UpdateSource {
450451
return nil
451452
}
452453

454+
// Determines if a GetField expression references the triggered table in an UpdateJoin
455+
func isUpdateJoinTriggerField(getField *expression.GetField, updateJoin *plan.UpdateJoin, trigger *plan.CreateTrigger) bool {
456+
updateTargets := updateJoin.UpdateTargets
457+
if updateTarget, isUpdateTarget := updateTargets[getField.Table()]; isUpdateTarget {
458+
if getTableName(updateTarget) == getTableName(trigger.Table) {
459+
return true
460+
}
461+
}
462+
return false
463+
}
464+
465+
// Returns the projection from an UpdateJoin with the non-triggered tables masked. This is to prevent conflicts if two
466+
// joined tables have columns with the same name
467+
func getMaskedUpdateJoinProject(updateJoin *plan.UpdateJoin, trigger *plan.CreateTrigger) *plan.Project {
468+
if updateSrc, isUpdateSrc := updateJoin.Child.(*plan.UpdateSource); isUpdateSrc {
469+
// get project parent
470+
if project, isProject := updateSrc.Child.(*plan.Project); isProject {
471+
projections := project.Projections
472+
maskedProjections := make([]sql.Expression, len(projections))
473+
for i, projection := range projections {
474+
maskedProjections[i] = projection
475+
if gf, isGf := projection.(*expression.GetField); isGf {
476+
if !isUpdateJoinTriggerField(gf, updateJoin, trigger) {
477+
maskedProjections[i] = gf.WithName("")
478+
}
479+
}
480+
}
481+
return plan.NewProject(maskedProjections, project.Child)
482+
}
483+
}
484+
panic("UpdateJoin node is not correctly structured")
485+
}
486+
453487
// getTriggerLogic analyzes and returns the Node representing the trigger body for the trigger given, applied to the
454488
// plan node given, which must be an insert, update, or delete.
455489
func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, trigger *plan.CreateTrigger, qFlags *sql.QueryFlags) (sql.Node, error) {
@@ -458,41 +492,43 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
458492
// fabricate one with the right properties (its child schema matches the table schema, with the right aliased name)
459493
var triggerLogic sql.Node
460494
var err error
495+
var scopeNode *plan.Project
461496
qFlags = nil
462497

463498
switch trigger.TriggerEvent {
464499
case sqlparser.InsertStr:
465-
scopeNode := plan.NewProject(
500+
scopeNode = plan.NewProject(
466501
[]sql.Expression{expression.NewStar()},
467502
plan.NewTableAlias("new", trigger.Table),
468503
)
469504
s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache())
470505
triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, DefaultRuleSelector, qFlags)
471506
case sqlparser.UpdateStr:
472-
var scopeNode *plan.Project
473-
if updateSrc := getUpdateJoinSource(n); updateSrc == nil {
507+
if updateJoin, isUpdateJoin := n.(*plan.Update).Child.(*plan.UpdateJoin); isUpdateJoin {
508+
masked := getMaskedUpdateJoinProject(updateJoin, trigger)
509+
// The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old but should
510+
// have placeholder expressions for non-triggered tables.
474511
scopeNode = plan.NewProject(
475512
[]sql.Expression{expression.NewStar()},
476513
plan.NewCrossJoin(
477-
plan.NewTableAlias("old", trigger.Table),
478-
plan.NewTableAlias("new", trigger.Table),
514+
plan.NewSubqueryAlias("old", "", masked),
515+
plan.NewSubqueryAlias("new", "", masked),
479516
),
480517
)
481518
} else {
482-
// The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old.
483519
scopeNode = plan.NewProject(
484520
[]sql.Expression{expression.NewStar()},
485521
plan.NewCrossJoin(
486-
plan.NewSubqueryAlias("old", "", updateSrc.Child),
487-
plan.NewSubqueryAlias("new", "", updateSrc.Child),
522+
plan.NewTableAlias("old", trigger.Table),
523+
plan.NewTableAlias("new", trigger.Table),
488524
),
489525
)
490526
}
491527
// Triggers are wrapped in prepend nodes, which means that the parent scope is included
492528
s := (*plan.Scope)(nil).NewScope(scopeNode).WithMemos(scope.Memo(n).MemoNodes()).WithProcedureCache(scope.ProcedureCache())
493529
triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, DefaultRuleSelector, qFlags)
494530
case sqlparser.DeleteStr:
495-
scopeNode := plan.NewProject(
531+
scopeNode = plan.NewProject(
496532
[]sql.Expression{expression.NewStar()},
497533
plan.NewTableAlias("old", trigger.Table),
498534
)

0 commit comments

Comments
 (0)