@@ -39,6 +39,10 @@ var _ plan.BlockRowIter = (*ifElseIter)(nil)
3939
4040// Next implements the sql.RowIter interface.
4141func (i * ifElseIter ) Next (ctx * sql.Context ) (sql.Row , error ) {
42+ if err := startTransaction (ctx ); err != nil {
43+ return nil , err
44+ }
45+
4246 return i .branchIter .Next (ctx )
4347}
4448
@@ -67,6 +71,10 @@ var _ sql.RowIter = (*beginEndIter)(nil)
6771
6872// Next implements the interface sql.RowIter.
6973func (b * beginEndIter ) Next (ctx * sql.Context ) (sql.Row , error ) {
74+ if err := startTransaction (ctx ); err != nil {
75+ return nil , err
76+ }
77+
7078 row , err := b .rowIter .Next (ctx )
7179 if err != nil {
7280 if exitErr , ok := err .(expression.ProcedureBlockExitError ); ok && b .Pref .CurrentHeight () == int (exitErr ) {
@@ -278,6 +286,10 @@ func (l *loopIter) Next(ctx *sql.Context) (sql.Row, error) {
278286 }
279287 }
280288
289+ if err := startTransaction (ctx ); err != nil {
290+ return nil , err
291+ }
292+
281293 nextRow , err := l .blockIter .Next (ctx )
282294 if err != nil {
283295 restart := false
@@ -395,3 +407,21 @@ func (i *iterateIter) Next(ctx *sql.Context) (sql.Row, error) {
395407func (i * iterateIter ) Close (ctx * sql.Context ) error {
396408 return nil
397409}
410+
411+ // startTransaction begins a new transaction if necessary, e.g. if a statement in a stored procedure committed the
412+ // current one
413+ func startTransaction (ctx * sql.Context ) error {
414+ if ctx .GetTransaction () == nil {
415+ ts , ok := ctx .Session .(sql.TransactionSession )
416+ if ok {
417+ tx , err := ts .StartTransaction (ctx , sql .ReadWrite )
418+ if err != nil {
419+ return err
420+ }
421+
422+ ctx .SetTransaction (tx )
423+ }
424+ }
425+
426+ return nil
427+ }
0 commit comments