@@ -490,6 +490,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
490490 return
491491}
492492
493+ // buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
494+ // children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
495+ // don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
496+ // analyzer processing that converts the subquery into a join, and then requires the same logic to
497+ // create an UpdateJoin node under the original Update node.
493498func (b * Builder ) buildUpdate (inScope * scope , u * ast.Update ) (outScope * scope ) {
494499 // TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
495500 // The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
@@ -532,44 +537,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
532537 update .IsProcNested = b .ProcCtx ().DbName != ""
533538
534539 var checks []* sql.CheckConstraint
535- if join , ok := outScope .node .(* plan.JoinNode ); ok {
536- // TODO this doesn't work, a lot of the time the top node
537- // is a filter. This would have to go before we build the
538- // filter/accessory nodes. But that errors for a lot of queries.
539- source := plan .NewUpdateSource (
540- join ,
541- ignore ,
542- updateExprs ,
543- )
544- updaters , err := rowUpdatersByTable (b .ctx , source , join )
540+ if hasJoinNode (outScope .node ) {
541+ tablesToUpdate , err := getResolvedTablesToUpdate (b .ctx , update .Child , outScope .node )
545542 if err != nil {
546543 b .handleErr (err )
547544 }
548- updateJoin := plan .NewUpdateJoin (updaters , source )
549- update .Child = updateJoin
550- transform .Inspect (update , func (n sql.Node ) bool {
551- // todo maybe this should be later stage
552- switch n := n .(type ) {
553- case sql.NameableNode :
554- if _ , ok := updaters [n .Name ()]; ok {
555- rt := getResolvedTable (n )
556- tableScope := inScope .push ()
557- for _ , c := range rt .Schema () {
558- tableScope .addColumn (scopeColumn {
559- db : rt .SqlDatabase .Name (),
560- table : strings .ToLower (n .Name ()),
561- tableId : tableScope .tables [strings .ToLower (n .Name ())],
562- col : strings .ToLower (c .Name ),
563- typ : c .Type ,
564- nullable : c .Nullable ,
565- })
566- }
567- checks = append (checks , b .loadChecksFromTable (tableScope , rt .Table )... )
568- }
569- default :
545+
546+ for _ , rt := range tablesToUpdate {
547+ tableScope := inScope .push ()
548+ for _ , c := range rt .Schema () {
549+ tableScope .addColumn (scopeColumn {
550+ db : rt .SqlDatabase .Name (),
551+ table : strings .ToLower (rt .Name ()),
552+ tableId : tableScope .tables [strings .ToLower (rt .Name ())],
553+ col : strings .ToLower (c .Name ),
554+ typ : c .Type ,
555+ nullable : c .Nullable ,
556+ })
570557 }
571- return true
572- })
558+ checks = append ( checks , b . loadChecksFromTable ( tableScope , rt . Table ) ... )
559+ }
573560 } else {
574561 transform .Inspect (update , func (n sql.Node ) bool {
575562 // todo maybe this should be later stage
@@ -588,35 +575,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
588575 return
589576}
590577
591- // rowUpdatersByTable maps a set of tables to their RowUpdater objects.
592- func rowUpdatersByTable (ctx * sql.Context , node sql.Node , ij sql.Node ) (map [string ]sql.RowUpdater , error ) {
593- namesOfTableToBeUpdated := getTablesToBeUpdated (node )
594- resolvedTables := getTablesByName (ij )
595-
596- rowUpdatersByTable := make (map [string ]sql.RowUpdater )
597- for tableToBeUpdated , _ := range namesOfTableToBeUpdated {
598- resolvedTable , ok := resolvedTables [strings .ToLower (tableToBeUpdated )]
599- if ! ok {
600- return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
578+ // hasJoinNode returns true if |node| or any child is a JoinNode.
579+ func hasJoinNode (node sql.Node ) bool {
580+ updateJoinFound := false
581+ transform .Inspect (node , func (n sql.Node ) bool {
582+ if _ , ok := n .(* plan.JoinNode ); ok {
583+ updateJoinFound = true
601584 }
585+ return ! updateJoinFound
586+ })
587+ return updateJoinFound
588+ }
602589
603- var table = resolvedTable .UnderlyingTable ()
590+ func getResolvedTablesToUpdate (_ * sql.Context , node sql.Node , ij sql.Node ) (resolvedTables []* plan.ResolvedTable , err error ) {
591+ namesOfTablesToBeUpdated := getTablesToBeUpdated (node )
592+ resolvedTablesMap := getTablesByName (ij )
604593
605- // If there is no UpdatableTable for a table being updated, error out
606- updatable , ok := table .(sql. UpdatableTable )
607- if ! ok && updatable == nil {
594+ for tableToBeUpdated , _ := range namesOfTablesToBeUpdated {
595+ resolvedTable , ok := resolvedTablesMap [ strings . ToLower ( tableToBeUpdated )]
596+ if ! ok {
608597 return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
609598 }
610599
611- keyless := sql .IsKeyless (updatable .Schema ())
612- if keyless {
613- return nil , sql .ErrUnsupportedFeature .New ("error: keyless tables unsupported for UPDATE JOIN" )
614- }
615-
616- rowUpdatersByTable [tableToBeUpdated ] = updatable .Updater (ctx )
600+ resolvedTables = append (resolvedTables , resolvedTable )
617601 }
618602
619- return rowUpdatersByTable , nil
603+ return resolvedTables , nil
620604}
621605
622606// getTablesByName takes a node and returns all found resolved tables in a map.
0 commit comments