Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions enginetest/queries/procedure_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2192,6 +2192,70 @@ END;`,
},
},
},
{
Name: "recursive procedure",
SetUpScript: []string{
`
create procedure recursive_proc(in counter int)
begin
set counter := counter + 1;
if counter > 3 then
select concat('ended with value: ', counter) as result;
else
call recursive_proc(counter);
end if;
end;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "call recursive_proc(1);",
Expected: []sql.Row{
{"ended with value: 4"},
},
},
},
},
{
Name: "multi recursive procedures",
SetUpScript: []string{
`
create procedure procA(in counter int)
begin
set counter := counter + 1;
if counter > 3 then
select concat('ended in procA with value: ', counter) as result;
else
call procB(counter);
end if;
end;
`,
`
create procedure procB(in counter int)
begin
set counter := counter + 1;
if counter > 3 then
select concat('ended in procB with value: ', counter) as result;
else
call procA(counter);
end if;
end;
`,
},
Assertions: []ScriptTestAssertion{
{
Query: "call procA(1);",
Expected: []sql.Row{
{"ended in procA with value: 4"},
},
},
{
Query: "call procB(1);",
Expected: []sql.Row{
{"ended in procB with value: 4"},
},
},
},
},
}

var ProcedureCallTests = []ScriptTest{
Expand Down
2 changes: 1 addition & 1 deletion sql/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ type StoredProcParam struct {
func (s *StoredProcParam) SetValue(val any) {
s.Value = val
s.HasBeenSet = true
if s.Reference != nil {
if s.Reference != nil && s != s.Reference {
s.Reference.SetValue(val)
}
}
Expand Down
3 changes: 3 additions & 0 deletions sql/rowexec/proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq
paramName := strings.ToLower(param.Name)
for spp := ctx.Session.GetStoredProcParam(paramName); spp != nil; {
spp.Value = paramVal
if spp.Reference == spp {
break
}
spp = spp.Reference
}
}
Expand Down