@@ -593,35 +593,27 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
593593 case * plan.TableEditorIter :
594594 // If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter
595595 innerIter := i .InnerIter ()
596- if insertIter , ok := innerIter .(* insertIter ); ok {
597- if len (insertIter .returnExprs ) > 0 {
598- return insertIter , insertIter .returnSchema
599- } else {
600- // TODO: How do we use the default logic if this isn't true... ? For now, just copying...
601- clientFoundRowsToggled := (ctx .Client ().Capabilities & mysql .CapabilityClientFoundRows ) > 0
602- rowHandler := getRowHandler (clientFoundRowsToggled , iter )
603- if rowHandler == nil {
604- return iter , nil
605- }
606- return & accumulatorIter {
607- iter : iter ,
608- updateRowHandler : rowHandler ,
609- }, types .OkResultSchema
610- }
596+ if insertIter , ok := innerIter .(* insertIter ); ok && len (insertIter .returnExprs ) > 0 {
597+ return insertIter , insertIter .returnSchema
611598 }
612599
613- return iter , nil
600+ return defaultAccumulatorIter ( ctx , iter )
614601 default :
615- clientFoundRowsToggled := (ctx .Client ().Capabilities & mysql .CapabilityClientFoundRows ) > 0
616- rowHandler := getRowHandler (clientFoundRowsToggled , iter )
617- if rowHandler == nil {
618- return iter , nil
619- }
620- return & accumulatorIter {
621- iter : iter ,
622- updateRowHandler : rowHandler ,
623- }, types .OkResultSchema
602+ return defaultAccumulatorIter (ctx , iter )
603+ }
604+ }
605+
606+ // defaultAccumulatorIter returns the default accumulator iter for a DML node
607+ func defaultAccumulatorIter (ctx * sql.Context , iter sql.RowIter ) (sql.RowIter , sql.Schema ) {
608+ clientFoundRowsToggled := (ctx .Client ().Capabilities & mysql .CapabilityClientFoundRows ) > 0
609+ rowHandler := getRowHandler (clientFoundRowsToggled , iter )
610+ if rowHandler == nil {
611+ return iter , nil
624612 }
613+ return & accumulatorIter {
614+ iter : iter ,
615+ updateRowHandler : rowHandler ,
616+ }, types .OkResultSchema
625617}
626618
627619func (a * accumulatorIter ) Next (ctx * sql.Context ) (r sql.Row , err error ) {
0 commit comments