@@ -534,20 +534,29 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
534534 update .IsProcNested = b .ProcCtx ().DbName != ""
535535
536536 var checks []* sql.CheckConstraint
537- if join , ok := outScope .node .(* plan.JoinNode ); ok {
538- // TODO this doesn't work, a lot of the time the top node
539- // is a filter. This would have to go before we build the
540- // filter/accessory nodes. But that errors for a lot of queries.
537+ if hasJoinNode (outScope .node ) {
541538 source := plan .NewUpdateSource (
542- join ,
539+ outScope . node ,
543540 ignore ,
544541 updateExprs ,
545542 )
546- updaters , err := rowUpdatersByTable (b .ctx , source , join )
543+ updaters , err := rowUpdatersByTable (b .ctx , source , outScope . node )
547544 if err != nil {
548545 b .handleErr (err )
549546 }
547+
548+ tablesToUpdate , err := getResolvedTablesToUpdate (b .ctx , source , outScope .node )
549+ if err != nil {
550+ b .handleErr (err )
551+ }
552+
553+ tableNames := make ([]string , len (tablesToUpdate ))
554+ for i , tableToUpdate := range tablesToUpdate {
555+ tableNames [i ] = tableToUpdate .Name ()
556+ }
557+
550558 updateJoin := plan .NewUpdateJoin (updaters , source )
559+ updateJoin .TargetTables = tableNames
551560 update .Child = updateJoin
552561 transform .Inspect (update , func (n sql.Node ) bool {
553562 // todo maybe this should be later stage
@@ -594,6 +603,34 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
594603 return
595604}
596605
606+ // hasJoinNode returns true if |node| or any child is a JoinNode.
607+ func hasJoinNode (node sql.Node ) bool {
608+ updateJoinFound := false
609+ transform .Inspect (node , func (n sql.Node ) bool {
610+ if _ , ok := n .(* plan.JoinNode ); ok {
611+ updateJoinFound = true
612+ }
613+ return ! updateJoinFound
614+ })
615+ return updateJoinFound
616+ }
617+
618+ func getResolvedTablesToUpdate (_ * sql.Context , node sql.Node , ij sql.Node ) (resolvedTables []* plan.ResolvedTable , err error ) {
619+ namesOfTableToBeUpdated := getTablesToBeUpdated (node )
620+ resolvedTablesMap := getTablesByName (ij )
621+
622+ for tableToBeUpdated , _ := range namesOfTableToBeUpdated {
623+ resolvedTable , ok := resolvedTablesMap [strings .ToLower (tableToBeUpdated )]
624+ if ! ok {
625+ return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
626+ }
627+
628+ resolvedTables = append (resolvedTables , resolvedTable )
629+ }
630+
631+ return resolvedTables , nil
632+ }
633+
597634// rowUpdatersByTable maps a set of tables to their RowUpdater objects.
598635func rowUpdatersByTable (ctx * sql.Context , node sql.Node , ij sql.Node ) (map [string ]sql.RowUpdater , error ) {
599636 namesOfTableToBeUpdated := getTablesToBeUpdated (node )
0 commit comments