1515package analyzer
1616
1717import (
18- "errors"
1918 "fmt"
2019 "strings"
2120
@@ -450,15 +449,21 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
450449 })
451450}
452451
453- func getUpdateJoinSource (n sql.Node ) * plan.UpdateSource {
452+ // getUpdateJoinSource looks for an UpdateJoin child in an Update node and get the UpdateSource and a map of table
453+ // aliases
454+ func getUpdateJoinSource (n sql.Node ) (* plan.UpdateSource , map [string ]string ) {
454455 if updateNode , isUpdate := n .(* plan.Update ); isUpdate {
455456 if updateJoin , isUpdateJoin := updateNode .Child .(* plan.UpdateJoin ); isUpdateJoin {
456457 if updateSrc , isUpdateSrc := updateJoin .Child .(* plan.UpdateSource ); isUpdateSrc {
457- return updateSrc
458+ tableAliases := make (map [string ]string )
459+ for alias , updateTarget := range updateJoin .UpdateTargets {
460+ tableAliases [alias ] = getTableName (updateTarget )
461+ }
462+ return updateSrc , tableAliases
458463 }
459464 }
460465 }
461- return nil
466+ return nil , nil
462467}
463468
464469// getTriggerLogic analyzes and returns the Node representing the trigger body for the trigger given, applied to the
@@ -481,7 +486,7 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
481486 triggerLogic , _ , err = a .analyzeWithSelector (ctx , trigger .Body , s , SelectAllBatches , DefaultRuleSelector , qFlags )
482487 case sqlparser .UpdateStr :
483488 var scopeNode * plan.Project
484- if updateSrc := getUpdateJoinSource (n ); updateSrc == nil {
489+ if updateSrc , tableAliases := getUpdateJoinSource (n ); updateSrc == nil {
485490 scopeNode = plan .NewProject (
486491 []sql.Expression {expression .NewStar ()},
487492 plan .NewCrossJoin (
@@ -490,18 +495,24 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
490495 ),
491496 )
492497 } else {
493- // TODO: We should be able to handle duplicate column names by masking columns that aren't part of the
494- // triggered table https://github.com/dolthub/dolt/issues/9403
495- err = validateNoConflictingColumnNames (updateSrc .Child .Schema ())
496- if err != nil {
497- return nil , err
498+ updateSrcCols := updateSrc .Child .Schema ()
499+ triggerTableName := getTableName (trigger .Table )
500+ maskedColNames := make ([]string , len (updateSrcCols ))
501+ for i , col := range updateSrcCols {
502+ // To avoid confusion when joined tables share a column name, we mask the column names from
503+ // non-triggered tables
504+ if col .Source == triggerTableName || tableAliases [col .Source ] == triggerTableName {
505+ maskedColNames [i ] = col .Name
506+ } else {
507+ maskedColNames [i ] = ""
508+ }
498509 }
499- // The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old.
510+ // The scopeNode for an UpdateJoin should contain every column in the updateSource as new and old.
500511 scopeNode = plan .NewProject (
501512 []sql.Expression {expression .NewStar ()},
502513 plan .NewCrossJoin (
503- plan .NewSubqueryAlias ("old" , "" , updateSrc .Child ),
504- plan .NewSubqueryAlias ("new" , "" , updateSrc .Child ),
514+ plan .NewSubqueryAlias ("old" , "" , updateSrc .Child ). WithColumnNames ( maskedColNames ) ,
515+ plan .NewSubqueryAlias ("new" , "" , updateSrc .Child ). WithColumnNames ( maskedColNames ) ,
505516 ),
506517 )
507518 }
@@ -521,19 +532,6 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
521532 return triggerLogic , err
522533}
523534
524- // validateNoConflictingColumnNames checks the columns of a joined table to make sure there are no conflicting column
525- // names
526- func validateNoConflictingColumnNames (sch sql.Schema ) error {
527- columnNames := make (map [string ]struct {})
528- for _ , col := range sch {
529- if _ , ok := columnNames [col .Name ]; ok {
530- return errors .New ("Unable to apply triggers when joined tables have columns with the same name" )
531- }
532- columnNames [col .Name ] = struct {}{}
533- }
534- return nil
535- }
536-
537535// validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any
538536// table being updated in an outer scope of this analysis)
539537func validateNoCircularUpdates (trigger * plan.CreateTrigger , n sql.Node , scope * plan.Scope ) error {
0 commit comments