@@ -516,87 +516,66 @@ type accumulatorIter struct {
516516 updateRowHandler accumulatorRowHandler
517517}
518518
519- func AddAccumulatorIter (ctx * sql.Context , node sql.Node , iter sql.RowIter ) (sql.RowIter , sql.Schema , error ) {
520- clientFoundRowsToggled := (ctx .Client ().Capabilities & mysql .CapabilityClientFoundRows ) == mysql .CapabilityClientFoundRows
521519
522- var rowHandler accumulatorRowHandler
523- switch n := node .(type ) {
524- case * plan.TriggerExecutor :
525- return AddAccumulatorIter (ctx , n .Left (), iter )
526- case * plan.InsertInto :
527- if n .IsReplace {
528- rowHandler = & replaceRowHandler {}
529- } else if len (n .OnDupExprs ) > 0 {
530- rowHandler = & onDuplicateUpdateHandler {schema : n .Schema (), clientFoundRowsCapability : clientFoundRowsToggled }
531- } else {
532- rowHandler = & insertRowHandler {}
520+ func getRowHandler (ctx * sql.Context , clientFoundRowsToggled bool , iter sql.RowIter ) accumulatorRowHandler {
521+ switch i := iter .(type ) {
522+ case * plan.TableEditorIter :
523+ return getRowHandler (ctx , clientFoundRowsToggled , i .InnerIter ())
524+ case * plan.CheckpointingTableEditorIter :
525+ return getRowHandler (ctx , clientFoundRowsToggled , i .InnerIter ())
526+ case * ProjectIter :
527+ return getRowHandler (ctx , clientFoundRowsToggled , i .childIter )
528+ case * triggerIter :
529+ return getRowHandler (ctx , clientFoundRowsToggled , i .child )
530+ case * insertIter :
531+ if i .replacer != nil {
532+ return & replaceRowHandler {}
533533 }
534- case * plan.DeleteFrom :
535- rowHandler = & deleteRowHandler {}
536- case * plan.Update :
537- // search for a join
538- hasJoin := false
539- transform .Inspect (n , func (node sql.Node ) bool {
540- switch node .(type ) {
541- case * plan.JoinNode :
542- hasJoin = true
543- return false
544- }
545- return true
546- })
547- if hasJoin {
548- var schema sql.Schema
549- var updaterMap map [string ]sql.RowUpdater
550- transform .Inspect (n , func (node sql.Node ) bool {
551- switch nn := node .(type ) {
552- case * plan.JoinNode , * plan.Project :
553- schema = node .Schema ()
554- return false
555- case * plan.UpdateJoin :
556- updaterMap = nn .Updaters
557- return true
558- default :
559- return true
560- }
561- })
562- if schema == nil {
563- return nil , nil , fmt .Errorf ("error: No JoinNode found in query plan to go along with an UpdateTypeJoinUpdate" )
564- }
565- // assign row handler to updateJoinIter
566- rowHandler = & updateJoinRowHandler {joinSchema : schema , tableMap : plan .RecreateTableSchemaFromJoinSchema (schema ), updaterMap : updaterMap }
567- for i , done := iter , false ; ! done ; {
568- switch ii := i .(type ) {
569- case * plan.TableEditorIter :
570- i = ii .InnerIter ()
571- case * ProjectIter :
572- i = ii .childIter
573- case * plan.CheckpointingTableEditorIter :
574- i = ii .InnerIter ()
575- case * triggerIter :
576- i = ii .child
577- case * updateIter :
578- i = ii .childIter
579- case * updateJoinIter :
580- ii .accumulator = rowHandler .(* updateJoinRowHandler )
581- done = true
582- default :
583- return nil , nil , fmt .Errorf ("failed to apply rowHandler to updateJoin, unknown type: %T" , iter )
584- }
534+ if i .updater != nil {
535+ return & onDuplicateUpdateHandler {schema : i .schema , clientFoundRowsCapability : clientFoundRowsToggled }
536+ }
537+ return & insertRowHandler {}
538+ case * deleteIter :
539+ return & deleteRowHandler {}
540+ case * updateIter :
541+ if updateJoin , ok := i .childIter .(* updateJoinIter ); ok {
542+ return & updateJoinRowHandler {
543+ joinSchema : updateJoin .joinSchema ,
544+ tableMap : plan .RecreateTableSchemaFromJoinSchema (updateJoin .joinSchema ),
545+ updaterMap : updateJoin .updaters ,
585546 }
586- } else {
587- // the schema of the update node is a self-concatenation of the underlying table's, so split it in half
588- // for new / old row comparison purposes
589- sch := n .Schema ()
590- rowHandler = & updateRowHandler {schema : sch [:len (sch )/ 2 ], clientFoundRowsCapability : clientFoundRowsToggled }
591547 }
548+ return & updateRowHandler {schema : i .schema [:len (i .schema )/ 2 ], clientFoundRowsCapability : clientFoundRowsToggled }
592549 default :
593- return iter , nil , nil
550+ return nil
594551 }
552+ }
595553
596- return & accumulatorIter {
597- iter : iter ,
598- updateRowHandler : rowHandler ,
599- }, types .OkResultSchema , nil
554+ func AddAccumulatorIter (ctx * sql.Context , node sql.Node , iter sql.RowIter ) (sql.RowIter , sql.Schema , error ) {
555+ switch i := iter .(type ) {
556+ case * callIter :
557+ childIter , sch , err := AddAccumulatorIter (ctx , node , i .innerIter )
558+ i .innerIter = childIter
559+ return i , sch , err
560+ case * beginEndIter :
561+ childIter , sch , err := AddAccumulatorIter (ctx , node , i .rowIter )
562+ i .rowIter = childIter
563+ return i , sch , err
564+ case * blockIter :
565+ childIter , sch , err := AddAccumulatorIter (ctx , node , i .internalIter )
566+ i .internalIter = childIter
567+ return i , sch , err
568+ default :
569+ clientFoundRowsToggled := (ctx .Client ().Capabilities & mysql .CapabilityClientFoundRows ) == mysql .CapabilityClientFoundRows
570+ rowHandler := getRowHandler (ctx , clientFoundRowsToggled , iter )
571+ if rowHandler == nil {
572+ return iter , nil , nil
573+ }
574+ return & accumulatorIter {
575+ iter : iter ,
576+ updateRowHandler : rowHandler ,
577+ }, types .OkResultSchema , nil
578+ }
600579}
601580
602581func (a * accumulatorIter ) Next (ctx * sql.Context ) (r sql.Row , err error ) {
0 commit comments