Skip to content

Commit 05e4a9d

Browse files
authored
Merge pull request #1852 from dolthub/zachmu/procedures
Alter stored procedure execution to deal with statements that commit transactions
2 parents 2fc0613 + c97cb0a commit 05e4a9d

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

sql/rowexec/other.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,13 @@ func (b *BaseBuilder) buildBlock(ctx *sql.Context, n *plan.Block, row sql.Row) (
250250

251251
selectSeen := false
252252
for _, s := range n.Children() {
253-
err := func() error {
253+
// TODO: this should happen at iteration time, but this call is where the actual iteration happens
254+
err := startTransaction(ctx)
255+
if err != nil {
256+
return nil, err
257+
}
258+
259+
err = func() error {
254260
rowCache, disposeFunc := ctx.Memory.NewRowsCache()
255261
defer disposeFunc()
256262

sql/rowexec/proc.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ func (b *BaseBuilder) buildIfElseBlock(ctx *sql.Context, n *plan.IfElseBlock, ro
8989
continue
9090
}
9191

92+
// TODO: this should happen at iteration time, but this call is where the actual iteration happens
93+
err = startTransaction(ctx)
94+
if err != nil {
95+
return nil, err
96+
}
97+
9298
branchIter, err = b.buildNodeExec(ctx, ifConditional, row)
9399
if err != nil {
94100
return nil, err
@@ -105,6 +111,12 @@ func (b *BaseBuilder) buildIfElseBlock(ctx *sql.Context, n *plan.IfElseBlock, ro
105111
}, nil
106112
}
107113

114+
// TODO: this should happen at iteration time, but this call is where the actual iteration happens
115+
err = startTransaction(ctx)
116+
if err != nil {
117+
return nil, err
118+
}
119+
108120
// All conditions failed so we run the else
109121
branchIter, err = b.buildNodeExec(ctx, n.Else, row)
110122
if err != nil {

sql/rowexec/proc_iters.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ var _ plan.BlockRowIter = (*ifElseIter)(nil)
3939

4040
// Next implements the sql.RowIter interface.
4141
func (i *ifElseIter) Next(ctx *sql.Context) (sql.Row, error) {
42+
if err := startTransaction(ctx); err != nil {
43+
return nil, err
44+
}
45+
4246
return i.branchIter.Next(ctx)
4347
}
4448

@@ -67,6 +71,10 @@ var _ sql.RowIter = (*beginEndIter)(nil)
6771

6872
// Next implements the interface sql.RowIter.
6973
func (b *beginEndIter) Next(ctx *sql.Context) (sql.Row, error) {
74+
if err := startTransaction(ctx); err != nil {
75+
return nil, err
76+
}
77+
7078
row, err := b.rowIter.Next(ctx)
7179
if err != nil {
7280
if exitErr, ok := err.(expression.ProcedureBlockExitError); ok && b.Pref.CurrentHeight() == int(exitErr) {
@@ -278,6 +286,10 @@ func (l *loopIter) Next(ctx *sql.Context) (sql.Row, error) {
278286
}
279287
}
280288

289+
if err := startTransaction(ctx); err != nil {
290+
return nil, err
291+
}
292+
281293
nextRow, err := l.blockIter.Next(ctx)
282294
if err != nil {
283295
restart := false
@@ -395,3 +407,21 @@ func (i *iterateIter) Next(ctx *sql.Context) (sql.Row, error) {
395407
func (i *iterateIter) Close(ctx *sql.Context) error {
396408
return nil
397409
}
410+
411+
// startTransaction begins a new transaction if necessary, e.g. if a statement in a stored procedure committed the
412+
// current one
413+
func startTransaction(ctx *sql.Context) error {
414+
if ctx.GetTransaction() == nil {
415+
ts, ok := ctx.Session.(sql.TransactionSession)
416+
if ok {
417+
tx, err := ts.StartTransaction(ctx, sql.ReadWrite)
418+
if err != nil {
419+
return err
420+
}
421+
422+
ctx.SetTransaction(tx)
423+
}
424+
}
425+
426+
return nil
427+
}

0 commit comments

Comments
 (0)