@@ -492,6 +492,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
492492 return
493493}
494494
495+ // buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
496+ // children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
497+ // don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
498+ // analyzer processing that converts the subquery into a join, and then requires the same logic to
499+ // create an UpdateJoin node under the original Update node.
495500func (b * Builder ) buildUpdate (inScope * scope , u * ast.Update ) (outScope * scope ) {
496501 // TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
497502 // The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
@@ -534,44 +539,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
534539 update .IsProcNested = b .ProcCtx ().DbName != ""
535540
536541 var checks []* sql.CheckConstraint
537- if join , ok := outScope .node .(* plan.JoinNode ); ok {
538- // TODO this doesn't work, a lot of the time the top node
539- // is a filter. This would have to go before we build the
540- // filter/accessory nodes. But that errors for a lot of queries.
541- source := plan .NewUpdateSource (
542- join ,
543- ignore ,
544- updateExprs ,
545- )
546- updaters , err := rowUpdatersByTable (b .ctx , source , join )
542+ if hasJoinNode (outScope .node ) {
543+ tablesToUpdate , err := getResolvedTablesToUpdate (b .ctx , update .Child , outScope .node )
547544 if err != nil {
548545 b .handleErr (err )
549546 }
550- updateJoin := plan .NewUpdateJoin (updaters , source )
551- update .Child = updateJoin
552- transform .Inspect (update , func (n sql.Node ) bool {
553- // todo maybe this should be later stage
554- switch n := n .(type ) {
555- case sql.NameableNode :
556- if _ , ok := updaters [n .Name ()]; ok {
557- rt := getResolvedTable (n )
558- tableScope := inScope .push ()
559- for _ , c := range rt .Schema () {
560- tableScope .addColumn (scopeColumn {
561- db : rt .SqlDatabase .Name (),
562- table : strings .ToLower (n .Name ()),
563- tableId : tableScope .tables [strings .ToLower (n .Name ())],
564- col : strings .ToLower (c .Name ),
565- typ : c .Type ,
566- nullable : c .Nullable ,
567- })
568- }
569- checks = append (checks , b .loadChecksFromTable (tableScope , rt .Table )... )
570- }
571- default :
547+
548+ for _ , rt := range tablesToUpdate {
549+ tableScope := inScope .push ()
550+ for _ , c := range rt .Schema () {
551+ tableScope .addColumn (scopeColumn {
552+ db : rt .SqlDatabase .Name (),
553+ table : strings .ToLower (rt .Name ()),
554+ tableId : tableScope .tables [strings .ToLower (rt .Name ())],
555+ col : strings .ToLower (c .Name ),
556+ typ : c .Type ,
557+ nullable : c .Nullable ,
558+ })
572559 }
573- return true
574- })
560+ checks = append ( checks , b . loadChecksFromTable ( tableScope , rt . Table ) ... )
561+ }
575562 } else {
576563 transform .Inspect (update , func (n sql.Node ) bool {
577564 // todo maybe this should be later stage
@@ -594,35 +581,32 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
594581 return
595582}
596583
597- // rowUpdatersByTable maps a set of tables to their RowUpdater objects.
598- func rowUpdatersByTable (ctx * sql.Context , node sql.Node , ij sql.Node ) (map [string ]sql.RowUpdater , error ) {
599- namesOfTableToBeUpdated := getTablesToBeUpdated (node )
600- resolvedTables := getTablesByName (ij )
601-
602- rowUpdatersByTable := make (map [string ]sql.RowUpdater )
603- for tableToBeUpdated , _ := range namesOfTableToBeUpdated {
604- resolvedTable , ok := resolvedTables [strings .ToLower (tableToBeUpdated )]
605- if ! ok {
606- return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
584+ // hasJoinNode returns true if |node| or any child is a JoinNode.
585+ func hasJoinNode (node sql.Node ) bool {
586+ updateJoinFound := false
587+ transform .Inspect (node , func (n sql.Node ) bool {
588+ if _ , ok := n .(* plan.JoinNode ); ok {
589+ updateJoinFound = true
607590 }
591+ return ! updateJoinFound
592+ })
593+ return updateJoinFound
594+ }
608595
609- var table = resolvedTable .UnderlyingTable ()
596+ func getResolvedTablesToUpdate (_ * sql.Context , node sql.Node , ij sql.Node ) (resolvedTables []* plan.ResolvedTable , err error ) {
597+ namesOfTablesToBeUpdated := getTablesToBeUpdated (node )
598+ resolvedTablesMap := getTablesByName (ij )
610599
611- // If there is no UpdatableTable for a table being updated, error out
612- updatable , ok := table .(sql. UpdatableTable )
613- if ! ok && updatable == nil {
600+ for tableToBeUpdated , _ := range namesOfTablesToBeUpdated {
601+ resolvedTable , ok := resolvedTablesMap [ strings . ToLower ( tableToBeUpdated )]
602+ if ! ok {
614603 return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
615604 }
616605
617- keyless := sql .IsKeyless (updatable .Schema ())
618- if keyless {
619- return nil , sql .ErrUnsupportedFeature .New ("error: keyless tables unsupported for UPDATE JOIN" )
620- }
621-
622- rowUpdatersByTable [tableToBeUpdated ] = updatable .Updater (ctx )
606+ resolvedTables = append (resolvedTables , resolvedTable )
623607 }
624608
625- return rowUpdatersByTable , nil
609+ return resolvedTables , nil
626610}
627611
628612// getTablesByName takes a node and returns all found resolved tables in a map.
0 commit comments