Skip to content

Commit 3a4afb2

Browse files
authored
Merge pull request #1854 from dolthub/fulghum/bugfix-2
Prevent loops in stored procedures from returning multiple result sets
2 parents 8dd09a6 + dd8c58a commit 3a4afb2

File tree

4 files changed

+185
-106
lines changed

4 files changed

+185
-106
lines changed

enginetest/queries/procedure_queries.go

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,50 @@ import (
2222
)
2323

2424
var ProcedureLogicTests = []ScriptTest{
25+
{
26+
// When a loop is executed once before the first evaluation of the loop condition, we expect the stored
27+
// procedure to return the last result set from that first loop execution.
28+
Name: "REPEAT with OnceBefore returns first loop evaluation result set",
29+
SetUpScript: []string{
30+
`CREATE PROCEDURE p1()
31+
BEGIN
32+
SET @counter = 0;
33+
REPEAT
34+
SELECT 42 from dual;
35+
SET @counter = @counter + 1;
36+
UNTIL @counter >= 0
37+
END REPEAT;
38+
END`,
39+
},
40+
Assertions: []ScriptTestAssertion{
41+
{
42+
Query: "CALL p1;",
43+
Expected: []sql.Row{{42}},
44+
},
45+
},
46+
},
47+
{
48+
// When a loop condition evals to false, we expect the stored procedure to return the last
49+
// result set from the previous loop execution.
50+
Name: "WHILE returns previous loop evaluation result set",
51+
SetUpScript: []string{
52+
`CREATE PROCEDURE p1()
53+
BEGIN
54+
SET @counter = 0;
55+
WHILE @counter <= 0 DO
56+
SET @counter = @counter + 1;
57+
SELECT CAST(@counter + 41 as SIGNED) from dual;
58+
END WHILE;
59+
END`,
60+
},
61+
Assertions: []ScriptTestAssertion{
62+
{
63+
Query: "CALL p1;",
64+
Expected: []sql.Row{{42}},
65+
},
66+
},
67+
},
68+
2569
{
2670
Name: "Simple SELECT",
2771
SetUpScript: []string{
@@ -278,17 +322,25 @@ BEGIN
278322
END`,
279323
},
280324
Assertions: []ScriptTestAssertion{
325+
// TODO: MySQL won't actually return *any* result set for these stored procedures. We have done work
326+
// to filter out all but the last result set generated by the stored procedure, but we still
327+
// need to filter out Result Sets that should be completely omitted.
281328
{
282329
Query: "CALL p1(0)",
283-
Expected: []sql.Row{},
330+
Expected: []sql.Row{{}},
284331
},
285332
{
286333
Query: "CALL p1(1)",
287-
Expected: []sql.Row{{}, {}}, // Next calls return an empty row, but progress the loop
334+
Expected: []sql.Row{{}},
288335
},
289336
{
290337
Query: "CALL p1(2)",
291-
Expected: []sql.Row{{}, {}, {}},
338+
Expected: []sql.Row{{}},
339+
},
340+
{
341+
// https://github.com/dolthub/dolt/issues/6230
342+
Query: "CALL p1(200)",
343+
Expected: []sql.Row{{}},
292344
},
293345
},
294346
},
@@ -304,17 +356,20 @@ BEGIN
304356
END`,
305357
},
306358
Assertions: []ScriptTestAssertion{
359+
// TODO: MySQL won't actually return *any* result set for these stored procedures. We have done work
360+
// to filter out all but the last result set generated by the stored procedure, but we still
361+
// need to filter out Result Sets that should be completely omitted.
307362
{
308363
Query: "CALL p1(0)",
309364
Expected: []sql.Row{{}},
310365
},
311366
{
312367
Query: "CALL p1(1)",
313-
Expected: []sql.Row{{}, {}},
368+
Expected: []sql.Row{{}},
314369
},
315370
{
316371
Query: "CALL p1(2)",
317-
Expected: []sql.Row{{}, {}, {}},
372+
Expected: []sql.Row{{}},
318373
},
319374
},
320375
},

sql/rowexec/proc.go

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ package rowexec
1616

1717
import (
1818
"fmt"
19+
"io"
1920
"strings"
20-
"sync"
2121

2222
"github.com/dolthub/go-mysql-server/sql"
2323
"github.com/dolthub/go-mysql-server/sql/expression"
@@ -194,27 +194,137 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq
194194
}, nil
195195
}
196196

197+
// buildLoop builds and returns an iterator that can be used to iterate over the result set returned from the
198+
// specified loop, |n|, for the specified row, |row|. Note that because of how we execute stored procedures and cache
199+
// the results in order to only send back the LAST result set (instead of supporting multiple results sets from
200+
// stored procedures, like MySQL does), building the iterator here also implicitly means that we're executing the
201+
// loop logic and caching the result set in memory. This will obviously be an issue for very large result sets.
202+
// Unfortunately, we can't know at analysis time what the last result set returned will be, since conditional logic
203+
// in stored procedures can't be known until execution time, hence why we end up caching result sets when we
204+
// see them and just playing back the last one. Adding support for MySQL's multiple result set behavior and better
205+
// matching MySQL on which statements are allowed to return result sets from a stored procedure seems like it could
206+
// potentially allow us to get rid of that caching.
197207
func (b *BaseBuilder) buildLoop(ctx *sql.Context, n *plan.Loop, row sql.Row) (sql.RowIter, error) {
198-
var blockIter sql.RowIter
199-
// Currently, acquiring the RowIter will actually run through the loop once, so we abuse this by grabbing the iter
200-
// only if we're supposed to run through the iter once before evaluating the condition
208+
// Acquiring the RowIter will actually execute the loop body once (because of how we cache/scan for the right
209+
// SELECT result set to return), so we grab the iter ONLY if we're supposed to run through the loop body once
210+
// before evaluating the condition
211+
var loopBodyIter sql.RowIter
201212
if n.OnceBeforeEval {
202213
var err error
203-
blockIter, err = b.loopAcquireRowIter(ctx, row, n.Label, n.Block, true)
214+
loopBodyIter, err = b.loopAcquireRowIter(ctx, row, n.Label, n.Block, true)
204215
if err != nil {
205216
return nil, err
206217
}
207218
}
208-
iter := &loopIter{
209-
block: n.Block,
210-
label: strings.ToLower(n.Label),
211-
condition: n.Condition,
212-
once: sync.Once{},
213-
blockIter: blockIter,
214-
row: row,
215-
loopIteration: 0,
219+
220+
var returnRows []sql.Row
221+
var returnNode sql.Node
222+
var returnSch sql.Schema
223+
selectSeen := false
224+
225+
// It's technically valid to make an infinite loop, but we don't want to actually allow that
226+
const maxIterationCount = 10_000_000_000
227+
228+
for loopIteration := 0; loopIteration <= maxIterationCount; loopIteration++ {
229+
if loopIteration >= maxIterationCount {
230+
return nil, fmt.Errorf("infinite LOOP detected")
231+
}
232+
233+
// If the condition is false, then we stop evaluation
234+
condition, err := n.Condition.Eval(ctx, nil)
235+
if err != nil {
236+
return nil, err
237+
}
238+
conditionBool, err := types.ConvertToBool(condition)
239+
if err != nil {
240+
return nil, err
241+
}
242+
if !conditionBool {
243+
// loopBodyIter should only be set if this is the first time through the loop and the loop has a
244+
// OnceBeforeEval condition. This ensures we return a result set, without us having to drain the iterator,
245+
// recache rows, and return a new iterator.
246+
if loopBodyIter != nil {
247+
return loopBodyIter, nil
248+
} else {
249+
break
250+
}
251+
}
252+
253+
if loopBodyIter == nil {
254+
var err error
255+
loopBodyIter, err = b.loopAcquireRowIter(ctx, nil, strings.ToLower(n.Label), n.Block, false)
256+
if err == io.EOF {
257+
break
258+
} else if err != nil {
259+
return nil, err
260+
}
261+
}
262+
263+
includeResultSet := false
264+
265+
var subIterNode sql.Node = n.Block
266+
subIterSch := n.Block.Schema()
267+
if blockRowIter, ok := loopBodyIter.(plan.BlockRowIter); ok {
268+
subIterNode = blockRowIter.RepresentingNode()
269+
subIterSch = blockRowIter.Schema()
270+
271+
if plan.NodeRepresentsSelect(subIterNode) {
272+
selectSeen = true
273+
includeResultSet = true
274+
returnNode = subIterNode
275+
returnSch = subIterSch
276+
} else if !selectSeen {
277+
includeResultSet = true
278+
returnNode = subIterNode
279+
returnSch = subIterSch
280+
}
281+
}
282+
283+
// Wrap the caching code in an inline function so that we can use defer to safely dispose of the cache
284+
err = func() error {
285+
rowCache, disposeFunc := ctx.Memory.NewRowsCache()
286+
defer disposeFunc()
287+
288+
nextRow, err := loopBodyIter.Next(ctx)
289+
for ; err == nil; nextRow, err = loopBodyIter.Next(ctx) {
290+
rowCache.Add(nextRow)
291+
}
292+
if err != io.EOF {
293+
return err
294+
}
295+
296+
err = loopBodyIter.Close(ctx)
297+
if err != nil {
298+
return err
299+
}
300+
loopBodyIter = nil
301+
302+
if includeResultSet {
303+
returnRows = rowCache.Get()
304+
}
305+
return nil
306+
}()
307+
308+
if err != nil {
309+
if err == io.EOF {
310+
// no-op for an EOF, just execute the next loop iteration
311+
} else if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == n.Label {
312+
if controlFlow.IsExit {
313+
break
314+
}
315+
} else {
316+
// If the error wasn't a control flow error signaling to start the next loop iteration or to
317+
// exit the loop, then it must be a real error, so just return it.
318+
return nil, err
319+
}
320+
}
216321
}
217-
return iter, nil
322+
323+
return &blockIter{
324+
internalIter: sql.RowsToRowIter(returnRows...),
325+
repNode: returnNode,
326+
sch: returnSch,
327+
}, nil
218328
}
219329

220330
func (b *BaseBuilder) buildElseCaseError(ctx *sql.Context, n plan.ElseCaseError, row sql.Row) (sql.RowIter, error) {

sql/rowexec/proc_iters.go

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@ import (
1818
"fmt"
1919
"io"
2020
"strings"
21-
"sync"
2221

2322
"github.com/dolthub/vitess/go/mysql"
2423

2524
"github.com/dolthub/go-mysql-server/sql"
2625
"github.com/dolthub/go-mysql-server/sql/expression"
2726
"github.com/dolthub/go-mysql-server/sql/plan"
28-
"github.com/dolthub/go-mysql-server/sql/types"
2927
)
3028

3129
// ifElseIter is the row iterator for *IfElseBlock.
@@ -245,90 +243,6 @@ func (c *closeIter) Close(ctx *sql.Context) error {
245243
return nil
246244
}
247245

248-
// loopIter is the sql.RowIter of *Loop.
249-
type loopIter struct {
250-
block *plan.Block
251-
label string
252-
condition sql.Expression
253-
once sync.Once
254-
blockIter sql.RowIter
255-
row sql.Row
256-
loopIteration uint64
257-
}
258-
259-
var _ sql.RowIter = (*loopIter)(nil)
260-
261-
// Next implements the interface sql.RowIter.
262-
func (l *loopIter) Next(ctx *sql.Context) (sql.Row, error) {
263-
// It's technically valid to make an infinite loop, but we don't want to actually allow that
264-
const maxIterationCount = 10_000_000_000
265-
l.loopIteration++
266-
for ; l.loopIteration < maxIterationCount; l.loopIteration++ {
267-
// If the condition is false, then we stop evaluation
268-
condition, err := l.condition.Eval(ctx, nil)
269-
if err != nil {
270-
return nil, err
271-
}
272-
conditionBool, err := types.ConvertToBool(condition)
273-
if err != nil {
274-
return nil, err
275-
}
276-
if !conditionBool {
277-
return nil, io.EOF
278-
}
279-
280-
if l.blockIter == nil {
281-
var err error
282-
b := &BaseBuilder{}
283-
l.blockIter, err = b.loopAcquireRowIter(ctx, nil, l.label, l.block, false)
284-
if err != nil {
285-
return nil, err
286-
}
287-
}
288-
289-
if err := startTransaction(ctx); err != nil {
290-
return nil, err
291-
}
292-
293-
nextRow, err := l.blockIter.Next(ctx)
294-
if err != nil {
295-
restart := false
296-
if err == io.EOF {
297-
restart = true
298-
} else if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == l.label {
299-
if controlFlow.IsExit {
300-
return nil, io.EOF
301-
} else {
302-
restart = true
303-
}
304-
}
305-
306-
if restart {
307-
err = l.blockIter.Close(ctx)
308-
if err != nil {
309-
return nil, err
310-
}
311-
l.blockIter = nil
312-
continue
313-
}
314-
return nil, err
315-
}
316-
return nextRow, nil
317-
}
318-
if l.loopIteration >= maxIterationCount {
319-
return nil, fmt.Errorf("infinite LOOP detected")
320-
}
321-
return nil, io.EOF
322-
}
323-
324-
// Close implements the interface sql.RowIter.
325-
func (l *loopIter) Close(ctx *sql.Context) error {
326-
if l.blockIter != nil {
327-
return l.blockIter.Close(ctx)
328-
}
329-
return nil
330-
}
331-
332246
// loopError is an error used to control a loop's flow.
333247
type loopError struct {
334248
Label string

sql/rowexec/rel.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ func (b *BaseBuilder) buildOrderedDistinct(ctx *sql.Context, n *plan.OrderedDist
281281
}
282282

283283
func (b *BaseBuilder) buildWith(ctx *sql.Context, n *plan.With, row sql.Row) (sql.RowIter, error) {
284-
return nil, fmt.Errorf("*plan.With has not execution iterator")
284+
return nil, fmt.Errorf("*plan.With has no execution iterator")
285285
}
286286

287287
func (b *BaseBuilder) buildProject(ctx *sql.Context, n *plan.Project, row sql.Row) (sql.RowIter, error) {

0 commit comments

Comments
 (0)