Skip to content

Commit c127a7e

Browse files
author
James Cor
committed
Merge branch 'james/proc' of github.com:dolthub/go-mysql-server into james/proc
2 parents 4823d27 + a34f76e commit c127a7e

File tree

4 files changed

+77
-0
lines changed

4 files changed

+77
-0
lines changed

sql/procedures/interpreter_logic.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast.
147147
if e.AsOf == nil && asOf != nil {
148148
e.AsOf = asOf.Time
149149
}
150+
if len(e.ProcName.Qualifier.String()) == 0 {
151+
e.ProcName.Qualifier = ast.NewTableIdent(stack.GetDatabase())
152+
}
150153
case *ast.Limit:
151154
newOffset, err := replaceVariablesInExpr(ctx, stack, e.Offset, asOf)
152155
if err != nil {
@@ -251,12 +254,18 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast.
251254
e.Values[i] = newExpr.(ast.ValTuple)
252255
}
253256
case *ast.Insert:
257+
if asOf != nil {
258+
return nil, sql.ErrProcedureCallAsOfReadOnly.New()
259+
}
254260
newExpr, err := replaceVariablesInExpr(ctx, stack, e.Rows, asOf)
255261
if err != nil {
256262
return nil, err
257263
}
258264
e.Rows = newExpr.(ast.InsertRows)
259265
case *ast.Delete:
266+
if asOf != nil {
267+
return nil, sql.ErrProcedureCallAsOfReadOnly.New()
268+
}
260269
if e.Where != nil {
261270
newExpr, err := replaceVariablesInExpr(ctx, stack, e.Where.Expr, asOf)
262271
if err != nil {
@@ -265,6 +274,9 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast.
265274
e.Where.Expr = newExpr.(ast.Expr)
266275
}
267276
case *ast.Update:
277+
if asOf != nil {
278+
return nil, sql.ErrProcedureCallAsOfReadOnly.New()
279+
}
268280
if e.Where != nil {
269281
newExpr, err := replaceVariablesInExpr(ctx, stack, e.Where.Expr, asOf)
270282
if err != nil {
@@ -706,6 +718,42 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac
706718
}
707719
}
708720

721+
case OpCode_Call:
722+
stmt, err := replaceVariablesInExpr(ctx, stack, operation.PrimaryData, asOf)
723+
if err != nil {
724+
return 0, nil, nil, nil, err
725+
}
726+
// put stack variables into session variables
727+
callStmt := stmt.(*ast.Call)
728+
stackToParam := make(map[*InterpreterVariable]*sql.StoredProcParam)
729+
for _, param := range callStmt.Params {
730+
colName, isColName := param.(*ast.ColName)
731+
if !isColName {
732+
continue
733+
}
734+
paramName := colName.Name.String()
735+
iv := stack.GetVariable(paramName)
736+
if iv == nil {
737+
continue
738+
}
739+
spp := &sql.StoredProcParam{
740+
Type: iv.Type,
741+
Value: iv.Value,
742+
}
743+
ctx.Session.NewStoredProcParam(paramName, spp)
744+
stackToParam[iv] = spp
745+
}
746+
sch, rowIter, err := query(ctx, runner, callStmt)
747+
if err != nil {
748+
return 0, nil, nil, nil, err
749+
}
750+
// assign stored proc params to stack variables
751+
for iv, spp := range stackToParam {
752+
iv.Value = spp.Value
753+
}
754+
755+
return counter, sch, nil, rowIter, err
756+
709757
case OpCode_If:
710758
selectStmt := operation.PrimaryData.(*ast.Select)
711759
if selectStmt.SelectExprs == nil {
@@ -828,6 +876,9 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta
828876
var retSch sql.Schema
829877
runner := iNode.GetRunner()
830878
statements := iNode.GetStatements()
879+
if dbNode, isDbNode := iNode.(sql.Databaser); isDbNode {
880+
stack.SetDatabase(dbNode.Database().Name())
881+
}
831882
for {
832883
counter++
833884
if counter < 0 {

sql/procedures/interpreter_operation.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const (
2323
OpCode_Fetch
2424
OpCode_Close
2525
OpCode_Set
26+
OpCode_Call
2627
OpCode_If
2728
OpCode_Goto
2829
OpCode_Execute

sql/procedures/interpreter_stack.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ type InterpreterScopeDetails struct {
152152

153153
// labels mark the counter of the start of a loop or block.
154154
labels map[string]int
155+
156+
// database is the current database for this scope.
157+
database string
155158
}
156159

157160
// InterpreterStack represents the working information that an interpreter will use during execution. It is not exactly
@@ -304,6 +307,21 @@ func (is *InterpreterStack) NewLabel(name string, index int) {
304307
is.stack.Peek().labels[name] = index
305308
}
306309

310+
// GetDatabase returns the current database for this scope.
311+
func (is *InterpreterStack) GetDatabase() string {
312+
for i := 0; i < is.stack.Len(); i++ {
313+
if db := is.stack.PeekDepth(i).database; db != "" {
314+
return db
315+
}
316+
}
317+
return ""
318+
}
319+
320+
// SetDatabase sets the current database for this scope.
321+
func (is *InterpreterStack) SetDatabase(db string) {
322+
is.stack.Peek().database = db
323+
}
324+
307325
// GetLabel traverses the stack (starting from the top) to find a label with a matching name. Returns -1 if no
308326
// variable was found.
309327
func (is *InterpreterStack) GetLabel(name string) int {

sql/procedures/parse.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast
141141
}
142142
*ops = append(*ops, setOp)
143143

144+
case *ast.Call:
145+
callOp := &InterpreterOperation{
146+
OpCode: OpCode_Call,
147+
PrimaryData: s,
148+
}
149+
*ops = append(*ops, callOp)
150+
144151
case *ast.IfStatement:
145152
var ifElseGotoOps []*InterpreterOperation
146153
for _, ifCond := range s.Conditions {

0 commit comments

Comments
 (0)