@@ -516,17 +516,18 @@ type accumulatorIter struct {
516516 updateRowHandler accumulatorRowHandler
517517}
518518
519-
520- func getRowHandler (ctx * sql.Context , clientFoundRowsToggled bool , iter sql.RowIter ) accumulatorRowHandler {
519+ func getRowHandler (clientFoundRowsToggled bool , iter sql.RowIter ) accumulatorRowHandler {
521520 switch i := iter .(type ) {
522521 case * plan.TableEditorIter :
523- return getRowHandler (ctx , clientFoundRowsToggled , i .InnerIter ())
522+ return getRowHandler (clientFoundRowsToggled , i .InnerIter ())
524523 case * plan.CheckpointingTableEditorIter :
525- return getRowHandler (ctx , clientFoundRowsToggled , i .InnerIter ())
524+ return getRowHandler (clientFoundRowsToggled , i .InnerIter ())
526525 case * ProjectIter :
527- return getRowHandler (ctx , clientFoundRowsToggled , i .childIter )
526+ return getRowHandler (clientFoundRowsToggled , i .childIter )
528527 case * triggerIter :
529- return getRowHandler (ctx , clientFoundRowsToggled , i .child )
528+ return getRowHandler (clientFoundRowsToggled , i .child )
529+ case * blockIter :
530+ return getRowHandler (clientFoundRowsToggled , i .repIter )
530531 case * insertIter :
531532 if i .replacer != nil {
532533 return & replaceRowHandler {}
@@ -538,14 +539,25 @@ func getRowHandler(ctx *sql.Context, clientFoundRowsToggled bool, iter sql.RowIt
538539 case * deleteIter :
539540 return & deleteRowHandler {}
540541 case * updateIter :
541- if updateJoin , ok := i .childIter .(* updateJoinIter ); ok {
542- return & updateJoinRowHandler {
543- joinSchema : updateJoin .joinSchema ,
544- tableMap : plan .RecreateTableSchemaFromJoinSchema (updateJoin .joinSchema ),
545- updaterMap : updateJoin .updaters ,
546- }
542+ // it's possible that there's an updateJoinIter that's not the immediate child of updateIter
543+ rowHandler := getRowHandler (clientFoundRowsToggled , i .childIter )
544+ if rowHandler != nil {
545+ return rowHandler
546+ }
547+ sch := i .schema
548+ // special case for foreign keys, plan.ForeignKeyHandler.Schema() returns original schema
549+ if fkHandler , isFk := i .updater .(* plan.ForeignKeyHandler ); isFk {
550+ sch = fkHandler .Sch
547551 }
548- return & updateRowHandler {schema : i .schema [:len (i .schema )/ 2 ], clientFoundRowsCapability : clientFoundRowsToggled }
552+ return & updateRowHandler {schema : sch , clientFoundRowsCapability : clientFoundRowsToggled }
553+ case * updateJoinIter :
554+ rowHandler := & updateJoinRowHandler {
555+ joinSchema : i .joinSchema ,
556+ tableMap : plan .RecreateTableSchemaFromJoinSchema (i .joinSchema ),
557+ updaterMap : i .updaters ,
558+ }
559+ i .accumulator = rowHandler
560+ return rowHandler
549561 default :
550562 return nil
551563 }
@@ -561,13 +573,9 @@ func AddAccumulatorIter(ctx *sql.Context, node sql.Node, iter sql.RowIter) (sql.
561573 childIter , sch , err := AddAccumulatorIter (ctx , node , i .rowIter )
562574 i .rowIter = childIter
563575 return i , sch , err
564- case * blockIter :
565- childIter , sch , err := AddAccumulatorIter (ctx , node , i .internalIter )
566- i .internalIter = childIter
567- return i , sch , err
568576 default :
569- clientFoundRowsToggled := (ctx .Client ().Capabilities & mysql .CapabilityClientFoundRows ) == mysql . CapabilityClientFoundRows
570- rowHandler := getRowHandler (ctx , clientFoundRowsToggled , iter )
577+ clientFoundRowsToggled := (ctx .Client ().Capabilities & mysql .CapabilityClientFoundRows ) > 0
578+ rowHandler := getRowHandler (clientFoundRowsToggled , iter )
571579 if rowHandler == nil {
572580 return iter , nil , nil
573581 }
0 commit comments