Skip to content

Commit e64f37f

Browse files
author
James Cor
committed
some feedback
1 parent b7d7c61 commit e64f37f

File tree

7 files changed

+33
-12
lines changed

7 files changed

+33
-12
lines changed

sql/base_session.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,20 +254,24 @@ func (s *BaseSession) IncrementStatusVariable(ctx *Context, statVarName string,
254254
return
255255
}
256256

257+
// NewStoredProcParam creates a new Stored Procedure Parameter in the Session
257258
func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) {
258259
if _, ok := s.storedProcParams[name]; ok {
259260
return
260261
}
261262
s.storedProcParams[name] = param
262263
}
263264

265+
// GetStoredProcParam retrieves the named stored procedure parameter, from the Session, returning nil if not found.
264266
func (s *BaseSession) GetStoredProcParam(name string) *StoredProcParam {
265267
if param, ok := s.storedProcParams[name]; ok {
266268
return param
267269
}
268270
return nil
269271
}
270272

273+
// SetStoredProcParam sets the named Stored Procedure Parameter from the Session to val and marks it as HasSet.
274+
// If the Parameter has not been initialized, this will throw an error.
271275
func (s *BaseSession) SetStoredProcParam(name string, val any) error {
272276
param := s.GetStoredProcParam(name)
273277
if param == nil {

sql/core.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,13 +880,17 @@ func IncrementStatusVariable(ctx *Context, name string, val int) {
880880
ctx.Session.IncrementStatusVariable(ctx, name, val)
881881
}
882882

883+
// StoredProcParam is a Parameter for a Stored Procedure.
884+
// Stored Procedures Parameters can be referenced from within other Stored Procedures, so we need to store them
885+
// somewhere that is accessible between interpreter calls to the engine.
883886
type StoredProcParam struct {
884887
Type Type
885888
Value any
886889
HasBeenSet bool
887890
Reference *StoredProcParam
888891
}
889892

893+
// SetValue saves val to the StoredProcParam, and set HasBeenSet to true.
890894
func (s *StoredProcParam) SetValue(val any) {
891895
s.Value = val
892896
s.HasBeenSet = true

sql/plan/call.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ import (
2323
"github.com/dolthub/go-mysql-server/sql/types"
2424
)
2525

26-
// TODO: we need different types of calls: one for external procedures one for stored procedures
27-
2826
type Call struct {
2927
db sql.Database
3028
Name string
@@ -39,7 +37,7 @@ type Call struct {
3937
Runner sql.StatementRunner
4038
Ops []procedures.InterpreterOperation
4139

42-
// TODO: sure whatever
40+
// retain the result schema
4341
resSch sql.Schema
4442
}
4543

sql/plan/procedure.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ func (p *Procedure) Resolved() bool {
138138
return true
139139
}
140140

141+
// IsReadOnly implements the sql.Node interface.
141142
func (p *Procedure) IsReadOnly() bool {
142143
if p.ExternalProc != nil {
143144
return p.ExternalProc.IsReadOnly()

sql/procedures/interpreter_logic.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ type InterpreterNode interface {
4040
SetSchema(sch sql.Schema)
4141
}
4242

43+
// replaceVariablesInExpr will search for every ast.Node and handle each one on a case by case basis.
44+
// If a new ast.Node is added to the vitess parser we may need to add a case for it here.
4345
func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast.SQLNode, asOf *ast.AsOf) (ast.SQLNode, error) {
4446
switch e := expr.(type) {
4547
case *ast.ColName:
@@ -286,6 +288,9 @@ func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sq
286288
if rErr == io.EOF {
287289
break
288290
}
291+
if cErr := rowIter.Close(ctx); cErr != nil {
292+
return nil, nil, cErr
293+
}
289294
return nil, nil, rErr
290295
}
291296
rows = append(rows, row)
@@ -402,6 +407,9 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac
402407
return 0, nil, nil, nil, err
403408
}
404409
if _, err = rowIter.Next(ctx); err != io.EOF {
410+
if rErr := rowIter.Close(ctx); rErr != nil {
411+
return 0, nil, nil, nil, rErr
412+
}
405413
return 0, nil, nil, nil, err
406414
}
407415
if err = rowIter.Close(ctx); err != nil {
@@ -507,7 +515,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac
507515
return 0, nil, nil, nil, fmt.Errorf("warnings not yet implemented")
508516
}
509517
} else {
510-
cond := stack.GetCondition(strings.ToLower(signalStmt.ConditionName))
518+
cond := stack.GetCondition(signalStmt.ConditionName)
511519
if cond == nil {
512520
return 0, nil, nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName)
513521
}
@@ -587,7 +595,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac
587595

588596
case OpCode_Open:
589597
openCur := operation.PrimaryData.(*ast.OpenCursor)
590-
cursor := stack.GetCursor(strings.ToLower(openCur.Name))
598+
cursor := stack.GetCursor(openCur.Name)
591599
if cursor == nil {
592600
return 0, nil, nil, nil, sql.ErrCursorNotFound.New(openCur.Name)
593601
}
@@ -607,7 +615,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac
607615

608616
case OpCode_Fetch:
609617
fetchCur := operation.PrimaryData.(*ast.FetchCursor)
610-
cursor := stack.GetCursor(strings.ToLower(fetchCur.Name))
618+
cursor := stack.GetCursor(fetchCur.Name)
611619
if cursor == nil {
612620
return 0, nil, nil, nil, sql.ErrCursorNotFound.New(fetchCur.Name)
613621
}
@@ -641,7 +649,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac
641649

642650
case OpCode_Close:
643651
closeCur := operation.PrimaryData.(*ast.CloseCursor)
644-
cursor := stack.GetCursor(strings.ToLower(closeCur.Name))
652+
cursor := stack.GetCursor(closeCur.Name)
645653
if cursor == nil {
646654
return 0, nil, nil, nil, sql.ErrCursorNotFound.New(closeCur.Name)
647655
}
@@ -680,7 +688,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac
680688
return 0, nil, nil, nil, err
681689
}
682690

683-
err = stack.SetVariable(strings.ToLower(operation.Target), row[0])
691+
err = stack.SetVariable(operation.Target, row[0])
684692
if err != nil {
685693
err = ctx.Session.SetStoredProcParam(operation.Target, row[0])
686694
if err != nil {

sql/procedures/interpreter_stack.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package procedures
1717
import (
1818
"fmt"
1919
"strconv"
20+
"strings"
2021

2122
"github.com/dolthub/go-mysql-server/sql"
2223
"github.com/dolthub/go-mysql-server/sql/types"
@@ -205,6 +206,7 @@ func (is *InterpreterStack) NewVariableAlias(alias string, variable *Interpreter
205206
// GetVariable traverses the stack (starting from the top) to find a variable with a matching name. Returns nil if no
206207
// variable was found.
207208
func (is *InterpreterStack) GetVariable(name string) *InterpreterVariable {
209+
name = strings.ToLower(name)
208210
for i := 0; i < is.stack.Len(); i++ {
209211
if iv, ok := is.stack.PeekDepth(i).variables[name]; ok {
210212
return iv
@@ -248,6 +250,7 @@ func (is *InterpreterStack) NewCondition(name string, sqlState string, mysqlErrC
248250
// GetCondition traverses the stack (starting from the top) to find a condition with a matching name. Returns nil if no
249251
// variable was found.
250252
func (is *InterpreterStack) GetCondition(name string) *InterpreterCondition {
253+
name = strings.ToLower(name)
251254
for i := 0; i < is.stack.Len(); i++ {
252255
if ic, ok := is.stack.PeekDepth(i).conditions[name]; ok {
253256
return ic
@@ -266,6 +269,7 @@ func (is *InterpreterStack) NewCursor(name string, selStmt ast.SelectStatement)
266269
// GetCursor traverses the stack (starting from the top) to find a condition with a matching name. Returns nil if no
267270
// variable was found.
268271
func (is *InterpreterStack) GetCursor(name string) *InterpreterCursor {
272+
name = strings.ToLower(name)
269273
for i := 0; i < is.stack.Len(); i++ {
270274
if ic, ok := is.stack.PeekDepth(i).cursors[name]; ok {
271275
return ic

sql/rowexec/proc.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,13 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq
262262

263263
// We might close transactions in the procedure, so we need to start a new one if we're not in one already
264264
if sess, ok := ctx.Session.(sql.TransactionSession); ok {
265-
tx, tErr := sess.StartTransaction(ctx, sql.ReadWrite)
266-
if tErr != nil {
267-
return nil, tErr
265+
if tx := ctx.GetTransaction(); tx == nil {
266+
tx, err = sess.StartTransaction(ctx, sql.ReadWrite)
267+
if err != nil {
268+
return nil, err
269+
}
270+
ctx.SetTransaction(tx)
268271
}
269-
ctx.SetTransaction(tx)
270272
}
271273

272274
return &callIter{

0 commit comments

Comments
 (0)