@@ -492,6 +492,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
492
492
return
493
493
}
494
494
495
+ // buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
496
+ // children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
497
+ // don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
498
+ // analyzer processing that converts the subquery into a join, and then requires the same logic to
499
+ // create an UpdateJoin node under the original Update node.
495
500
func (b * Builder ) buildUpdate (inScope * scope , u * ast.Update ) (outScope * scope ) {
496
501
// TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
497
502
// The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
@@ -534,44 +539,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
534
539
update .IsProcNested = b .ProcCtx ().DbName != ""
535
540
536
541
var checks []* sql.CheckConstraint
537
- if join , ok := outScope .node .(* plan.JoinNode ); ok {
538
- // TODO this doesn't work, a lot of the time the top node
539
- // is a filter. This would have to go before we build the
540
- // filter/accessory nodes. But that errors for a lot of queries.
541
- source := plan .NewUpdateSource (
542
- join ,
543
- ignore ,
544
- updateExprs ,
545
- )
546
- updaters , err := rowUpdatersByTable (b .ctx , source , join )
542
+ if hasJoinNode (outScope .node ) {
543
+ tablesToUpdate , err := getResolvedTablesToUpdate (b .ctx , update .Child , outScope .node )
547
544
if err != nil {
548
545
b .handleErr (err )
549
546
}
550
- updateJoin := plan .NewUpdateJoin (updaters , source )
551
- update .Child = updateJoin
552
- transform .Inspect (update , func (n sql.Node ) bool {
553
- // todo maybe this should be later stage
554
- switch n := n .(type ) {
555
- case sql.NameableNode :
556
- if _ , ok := updaters [n .Name ()]; ok {
557
- rt := getResolvedTable (n )
558
- tableScope := inScope .push ()
559
- for _ , c := range rt .Schema () {
560
- tableScope .addColumn (scopeColumn {
561
- db : rt .SqlDatabase .Name (),
562
- table : strings .ToLower (n .Name ()),
563
- tableId : tableScope .tables [strings .ToLower (n .Name ())],
564
- col : strings .ToLower (c .Name ),
565
- typ : c .Type ,
566
- nullable : c .Nullable ,
567
- })
568
- }
569
- checks = append (checks , b .loadChecksFromTable (tableScope , rt .Table )... )
570
- }
571
- default :
547
+
548
+ for _ , rt := range tablesToUpdate {
549
+ tableScope := inScope .push ()
550
+ for _ , c := range rt .Schema () {
551
+ tableScope .addColumn (scopeColumn {
552
+ db : rt .SqlDatabase .Name (),
553
+ table : strings .ToLower (rt .Name ()),
554
+ tableId : tableScope .tables [strings .ToLower (rt .Name ())],
555
+ col : strings .ToLower (c .Name ),
556
+ typ : c .Type ,
557
+ nullable : c .Nullable ,
558
+ })
572
559
}
573
- return true
574
- })
560
+ checks = append ( checks , b . loadChecksFromTable ( tableScope , rt . Table ) ... )
561
+ }
575
562
} else {
576
563
transform .Inspect (update , func (n sql.Node ) bool {
577
564
// todo maybe this should be later stage
@@ -594,35 +581,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
594
581
return
595
582
}
596
583
597
- // rowUpdatersByTable maps a set of tables to their RowUpdater objects.
598
- func rowUpdatersByTable (ctx * sql.Context , node sql.Node , ij sql.Node ) (map [string ]sql.RowUpdater , error ) {
599
- namesOfTableToBeUpdated := getTablesToBeUpdated (node )
600
- resolvedTables := getTablesByName (ij )
601
-
602
- rowUpdatersByTable := make (map [string ]sql.RowUpdater )
603
- for tableToBeUpdated , _ := range namesOfTableToBeUpdated {
604
- resolvedTable , ok := resolvedTables [strings .ToLower (tableToBeUpdated )]
605
- if ! ok {
606
- return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
584
+ // hasJoinNode returns true if |node| or any child is a JoinNode.
585
+ func hasJoinNode (node sql.Node ) bool {
586
+ updateJoinFound := false
587
+ transform .Inspect (node , func (n sql.Node ) bool {
588
+ if _ , ok := n .(* plan.JoinNode ); ok {
589
+ updateJoinFound = true
607
590
}
591
+ return ! updateJoinFound
592
+ })
593
+ return updateJoinFound
594
+ }
608
595
609
- var table = resolvedTable .UnderlyingTable ()
596
+ func getResolvedTablesToUpdate (_ * sql.Context , node sql.Node , ij sql.Node ) (resolvedTables []* plan.ResolvedTable , err error ) {
597
+ namesOfTablesToBeUpdated := getTablesToBeUpdated (node )
598
+ resolvedTablesMap := getTablesByName (ij )
610
599
611
- // If there is no UpdatableTable for a table being updated, error out
612
- updatable , ok := table .(sql. UpdatableTable )
613
- if ! ok && updatable == nil {
600
+ for tableToBeUpdated , _ := range namesOfTablesToBeUpdated {
601
+ resolvedTable , ok := resolvedTablesMap [ strings . ToLower ( tableToBeUpdated )]
602
+ if ! ok {
614
603
return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
615
604
}
616
605
617
- keyless := sql .IsKeyless (updatable .Schema ())
618
- if keyless {
619
- return nil , sql .ErrUnsupportedFeature .New ("error: keyless tables unsupported for UPDATE JOIN" )
620
- }
621
-
622
- rowUpdatersByTable [tableToBeUpdated ] = updatable .Updater (ctx )
606
+ resolvedTables = append (resolvedTables , resolvedTable )
623
607
}
624
608
625
- return rowUpdatersByTable , nil
609
+ return resolvedTables , nil
626
610
}
627
611
628
612
// getTablesByName takes a node and returns all found resolved tables in a map.
0 commit comments