15
15
package analyzer
16
16
17
17
import (
18
- "errors"
19
18
"fmt"
20
19
"strings"
21
20
@@ -450,15 +449,21 @@ func applyTrigger(ctx *sql.Context, a *Analyzer, originalNode, n sql.Node, scope
450
449
})
451
450
}
452
451
453
- func getUpdateJoinSource (n sql.Node ) * plan.UpdateSource {
452
+ // getUpdateJoinSource looks for an UpdateJoin child in an Update node and get the UpdateSource and a map of table
453
+ // aliases
454
+ func getUpdateJoinSource (n sql.Node ) (* plan.UpdateSource , map [string ]string ) {
454
455
if updateNode , isUpdate := n .(* plan.Update ); isUpdate {
455
456
if updateJoin , isUpdateJoin := updateNode .Child .(* plan.UpdateJoin ); isUpdateJoin {
456
457
if updateSrc , isUpdateSrc := updateJoin .Child .(* plan.UpdateSource ); isUpdateSrc {
457
- return updateSrc
458
+ tableAliases := make (map [string ]string )
459
+ for alias , updateTarget := range updateJoin .UpdateTargets {
460
+ tableAliases [alias ] = getTableName (updateTarget )
461
+ }
462
+ return updateSrc , tableAliases
458
463
}
459
464
}
460
465
}
461
- return nil
466
+ return nil , nil
462
467
}
463
468
464
469
// getTriggerLogic analyzes and returns the Node representing the trigger body for the trigger given, applied to the
@@ -481,7 +486,7 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
481
486
triggerLogic , _ , err = a .analyzeWithSelector (ctx , trigger .Body , s , SelectAllBatches , DefaultRuleSelector , qFlags )
482
487
case sqlparser .UpdateStr :
483
488
var scopeNode * plan.Project
484
- if updateSrc := getUpdateJoinSource (n ); updateSrc == nil {
489
+ if updateSrc , tableAliases := getUpdateJoinSource (n ); updateSrc == nil {
485
490
scopeNode = plan .NewProject (
486
491
[]sql.Expression {expression .NewStar ()},
487
492
plan .NewCrossJoin (
@@ -490,18 +495,24 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
490
495
),
491
496
)
492
497
} else {
493
- // TODO: We should be able to handle duplicate column names by masking columns that aren't part of the
494
- // triggered table https://github.com/dolthub/dolt/issues/9403
495
- err = validateNoConflictingColumnNames (updateSrc .Child .Schema ())
496
- if err != nil {
497
- return nil , err
498
+ updateSrcCols := updateSrc .Child .Schema ()
499
+ triggerTableName := getTableName (trigger .Table )
500
+ maskedColNames := make ([]string , len (updateSrcCols ))
501
+ for i , col := range updateSrcCols {
502
+ // To avoid confusion when joined tables share a column name, we mask the column names from
503
+ // non-triggered tables
504
+ if col .Source == triggerTableName || tableAliases [col .Source ] == triggerTableName {
505
+ maskedColNames [i ] = col .Name
506
+ } else {
507
+ maskedColNames [i ] = ""
508
+ }
498
509
}
499
- // The scopeNode for an UpdateJoin should contain every node in the updateSource as new and old.
510
+ // The scopeNode for an UpdateJoin should contain every column in the updateSource as new and old.
500
511
scopeNode = plan .NewProject (
501
512
[]sql.Expression {expression .NewStar ()},
502
513
plan .NewCrossJoin (
503
- plan .NewSubqueryAlias ("old" , "" , updateSrc .Child ),
504
- plan .NewSubqueryAlias ("new" , "" , updateSrc .Child ),
514
+ plan .NewSubqueryAlias ("old" , "" , updateSrc .Child ). WithColumnNames ( maskedColNames ) ,
515
+ plan .NewSubqueryAlias ("new" , "" , updateSrc .Child ). WithColumnNames ( maskedColNames ) ,
505
516
),
506
517
)
507
518
}
@@ -521,19 +532,6 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
521
532
return triggerLogic , err
522
533
}
523
534
524
- // validateNoConflictingColumnNames checks the columns of a joined table to make sure there are no conflicting column
525
- // names
526
- func validateNoConflictingColumnNames (sch sql.Schema ) error {
527
- columnNames := make (map [string ]struct {})
528
- for _ , col := range sch {
529
- if _ , ok := columnNames [col .Name ]; ok {
530
- return errors .New ("Unable to apply triggers when joined tables have columns with the same name" )
531
- }
532
- columnNames [col .Name ] = struct {}{}
533
- }
534
- return nil
535
- }
536
-
537
535
// validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any
538
536
// table being updated in an outer scope of this analysis)
539
537
func validateNoCircularUpdates (trigger * plan.CreateTrigger , n sql.Node , scope * plan.Scope ) error {
0 commit comments