Skip to content

Commit 6df0972

Browse files
committed
Bug fix for omitting an accumulator for many DML nodes
1 parent 43b5a64 commit 6df0972

File tree

1 file changed

+17
-25
lines changed

1 file changed

+17
-25
lines changed

sql/rowexec/dml_iters.go

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -593,35 +593,27 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc
593593
case *plan.TableEditorIter:
594594
// If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter
595595
innerIter := i.InnerIter()
596-
if insertIter, ok := innerIter.(*insertIter); ok {
597-
if len(insertIter.returnExprs) > 0 {
598-
return insertIter, insertIter.returnSchema
599-
} else {
600-
// TODO: How do we use the default logic if this isn't true... ? For now, just copying...
601-
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
602-
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
603-
if rowHandler == nil {
604-
return iter, nil
605-
}
606-
return &accumulatorIter{
607-
iter: iter,
608-
updateRowHandler: rowHandler,
609-
}, types.OkResultSchema
610-
}
596+
if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 {
597+
return insertIter, insertIter.returnSchema
611598
}
612599

613-
return iter, nil
600+
return defaultAccumulatorIter(ctx, iter)
614601
default:
615-
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
616-
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
617-
if rowHandler == nil {
618-
return iter, nil
619-
}
620-
return &accumulatorIter{
621-
iter: iter,
622-
updateRowHandler: rowHandler,
623-
}, types.OkResultSchema
602+
return defaultAccumulatorIter(ctx, iter)
603+
}
604+
}
605+
606+
// defaultAccumulatorIter returns the default accumulator iter for a DML node
607+
func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) {
608+
clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0
609+
rowHandler := getRowHandler(clientFoundRowsToggled, iter)
610+
if rowHandler == nil {
611+
return iter, nil
624612
}
613+
return &accumulatorIter{
614+
iter: iter,
615+
updateRowHandler: rowHandler,
616+
}, types.OkResultSchema
625617
}
626618

627619
func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) {

0 commit comments

Comments
 (0)