@@ -150,12 +150,10 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) {
150150 ins := plan .NewInsertInto (db , plan .NewInsertDestination (sch , dest ), srcScope .node , isReplace , columns , onDupExprs , ignore )
151151 ins .LiteralValueSource = srcLiteralOnly
152152
153- if i .Returning != nil {
154- returningExprs := make ([]sql.Expression , len (i .Returning ))
155- for i , selectExpr := range i .Returning {
156- returningExprs [i ] = b .selectExprToExpression (destScope , selectExpr )
157- }
158- ins .Returning = returningExprs
153+ if len (i .Returning ) > 0 {
154+ // TODO: read returning results from outScope instead of ins.Returning so that there is no need to return list
155+ // of expressions
156+ ins .Returning = b .analyzeSelectList (destScope , destScope , i .Returning )
159157 }
160158
161159 b .validateInsert (ins )
@@ -492,6 +490,11 @@ func (b *Builder) buildDelete(inScope *scope, d *ast.Delete) (outScope *scope) {
492490 return
493491}
494492
493+ // buildUpdate builds a Update node from |u|. If the update joins tables, the returned Update node's
494+ // children will have a JoinNode, which will later be replaced by an UpdateJoin node during analysis. We
495+ // don't create the UpdateJoin node here, because some query plans, such as IN SUBQUERY nodes, require
496+ // analyzer processing that converts the subquery into a join, and then requires the same logic to
497+ // create an UpdateJoin node under the original Update node.
495498func (b * Builder ) buildUpdate (inScope * scope , u * ast.Update ) (outScope * scope ) {
496499 // TODO: this shouldn't be called during ComPrepare or `PREPARE ... FROM ...` statements, but currently it is.
497500 // The end result is that the ComDelete counter is incremented during prepare statements, which is incorrect.
@@ -534,44 +537,26 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
534537 update .IsProcNested = b .ProcCtx ().DbName != ""
535538
536539 var checks []* sql.CheckConstraint
537- if join , ok := outScope .node .(* plan.JoinNode ); ok {
538- // TODO this doesn't work, a lot of the time the top node
539- // is a filter. This would have to go before we build the
540- // filter/accessory nodes. But that errors for a lot of queries.
541- source := plan .NewUpdateSource (
542- join ,
543- ignore ,
544- updateExprs ,
545- )
546- updaters , err := rowUpdatersByTable (b .ctx , source , join )
540+ if hasJoinNode (outScope .node ) {
541+ tablesToUpdate , err := getResolvedTablesToUpdate (b .ctx , update .Child , outScope .node )
547542 if err != nil {
548543 b .handleErr (err )
549544 }
550- updateJoin := plan .NewUpdateJoin (updaters , source )
551- update .Child = updateJoin
552- transform .Inspect (update , func (n sql.Node ) bool {
553- // todo maybe this should be later stage
554- switch n := n .(type ) {
555- case sql.NameableNode :
556- if _ , ok := updaters [n .Name ()]; ok {
557- rt := getResolvedTable (n )
558- tableScope := inScope .push ()
559- for _ , c := range rt .Schema () {
560- tableScope .addColumn (scopeColumn {
561- db : rt .SqlDatabase .Name (),
562- table : strings .ToLower (n .Name ()),
563- tableId : tableScope .tables [strings .ToLower (n .Name ())],
564- col : strings .ToLower (c .Name ),
565- typ : c .Type ,
566- nullable : c .Nullable ,
567- })
568- }
569- checks = append (checks , b .loadChecksFromTable (tableScope , rt .Table )... )
570- }
571- default :
545+
546+ for _ , rt := range tablesToUpdate {
547+ tableScope := inScope .push ()
548+ for _ , c := range rt .Schema () {
549+ tableScope .addColumn (scopeColumn {
550+ db : rt .SqlDatabase .Name (),
551+ table : strings .ToLower (rt .Name ()),
552+ tableId : tableScope .tables [strings .ToLower (rt .Name ())],
553+ col : strings .ToLower (c .Name ),
554+ typ : c .Type ,
555+ nullable : c .Nullable ,
556+ })
572557 }
573- return true
574- })
558+ checks = append ( checks , b . loadChecksFromTable ( tableScope , rt . Table ) ... )
559+ }
575560 } else {
576561 transform .Inspect (update , func (n sql.Node ) bool {
577562 // todo maybe this should be later stage
@@ -583,46 +568,39 @@ func (b *Builder) buildUpdate(inScope *scope, u *ast.Update) (outScope *scope) {
583568 }
584569
585570 if len (u .Returning ) > 0 {
586- returningExprs := make ([]sql.Expression , len (u .Returning ))
587- for i , selectExpr := range u .Returning {
588- returningExprs [i ] = b .selectExprToExpression (outScope , selectExpr )
589- }
590- update .Returning = returningExprs
571+ update .Returning = b .analyzeSelectList (outScope , outScope , u .Returning )
591572 }
592573
593574 outScope .node = update .WithChecks (checks )
594575 return
595576}
596577
597- // rowUpdatersByTable maps a set of tables to their RowUpdater objects.
598- func rowUpdatersByTable (ctx * sql.Context , node sql.Node , ij sql.Node ) (map [string ]sql.RowUpdater , error ) {
599- namesOfTableToBeUpdated := getTablesToBeUpdated (node )
600- resolvedTables := getTablesByName (ij )
601-
602- rowUpdatersByTable := make (map [string ]sql.RowUpdater )
603- for tableToBeUpdated , _ := range namesOfTableToBeUpdated {
604- resolvedTable , ok := resolvedTables [strings .ToLower (tableToBeUpdated )]
605- if ! ok {
606- return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
578+ // hasJoinNode returns true if |node| or any child is a JoinNode.
579+ func hasJoinNode (node sql.Node ) bool {
580+ updateJoinFound := false
581+ transform .Inspect (node , func (n sql.Node ) bool {
582+ if _ , ok := n .(* plan.JoinNode ); ok {
583+ updateJoinFound = true
607584 }
585+ return ! updateJoinFound
586+ })
587+ return updateJoinFound
588+ }
608589
609- var table = resolvedTable .UnderlyingTable ()
590+ func getResolvedTablesToUpdate (_ * sql.Context , node sql.Node , ij sql.Node ) (resolvedTables []* plan.ResolvedTable , err error ) {
591+ namesOfTablesToBeUpdated := getTablesToBeUpdated (node )
592+ resolvedTablesMap := getTablesByName (ij )
610593
611- // If there is no UpdatableTable for a table being updated, error out
612- updatable , ok := table .(sql. UpdatableTable )
613- if ! ok && updatable == nil {
594+ for tableToBeUpdated , _ := range namesOfTablesToBeUpdated {
595+ resolvedTable , ok := resolvedTablesMap [ strings . ToLower ( tableToBeUpdated )]
596+ if ! ok {
614597 return nil , plan .ErrUpdateForTableNotSupported .New (tableToBeUpdated )
615598 }
616599
617- keyless := sql .IsKeyless (updatable .Schema ())
618- if keyless {
619- return nil , sql .ErrUnsupportedFeature .New ("error: keyless tables unsupported for UPDATE JOIN" )
620- }
621-
622- rowUpdatersByTable [tableToBeUpdated ] = updatable .Updater (ctx )
600+ resolvedTables = append (resolvedTables , resolvedTable )
623601 }
624602
625- return rowUpdatersByTable , nil
603+ return resolvedTables , nil
626604}
627605
628606// getTablesByName takes a node and returns all found resolved tables in a map.
0 commit comments