Skip to content

Commit bf5a416

Browse files
author
James Cor
committed
implement case errors
1 parent c41fed4 commit bf5a416

File tree

6 files changed

+24
-67
lines changed

6 files changed

+24
-67
lines changed

sql/planbuilder/scalar.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
106106
case *ast.NullVal:
107107
return expression.NewLiteral(nil, types.Null)
108108
case *ast.ColName:
109-
if v.Metadata != nil {
110-
return b.ConvertVal(v.Metadata)
109+
if v.StoredProcVal != nil {
110+
return b.ConvertVal(v.StoredProcVal)
111111
}
112112
dbName := strings.ToLower(v.Qualifier.DbQualifier.String())
113113
tblName := strings.ToLower(v.Qualifier.Name.String())

sql/planbuilder/show.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,8 +604,8 @@ func (b *Builder) buildAsOfExpr(inScope *scope, time ast.Expr) sql.Expression {
604604
}
605605
return expression.NewLiteral(ret.(string), types.LongText)
606606
case *ast.ColName:
607-
if v.Metadata != nil {
608-
return b.buildAsOfExpr(inScope, v.Metadata)
607+
if v.StoredProcVal != nil {
608+
return b.buildAsOfExpr(inScope, v.StoredProcVal)
609609
}
610610
sysVar, _, ok := b.buildSysVar(v, ast.SetScope_None)
611611
if ok {

sql/procedures/interpreter_logic.go

Lines changed: 7 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -39,42 +39,6 @@ type Parameter struct {
3939
Value any
4040
}
4141

42-
func unreplaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) ast.SQLNode {
43-
switch e := expr.(type) {
44-
case *ast.AliasedExpr:
45-
newExpr := unreplaceVariablesInExpr(stack, e.Expr)
46-
e.Expr = newExpr.(ast.Expr)
47-
case *ast.BinaryExpr:
48-
newLeftExpr := unreplaceVariablesInExpr(stack, e.Left)
49-
newRightExpr := unreplaceVariablesInExpr(stack, e.Right)
50-
e.Left = newLeftExpr.(ast.Expr)
51-
e.Right = newRightExpr.(ast.Expr)
52-
case *ast.ComparisonExpr:
53-
newLeftExpr := unreplaceVariablesInExpr(stack, e.Left)
54-
newRightExpr := unreplaceVariablesInExpr(stack, e.Right)
55-
e.Left = newLeftExpr.(ast.Expr)
56-
e.Right = newRightExpr.(ast.Expr)
57-
case *ast.FuncExpr:
58-
for i := range e.Exprs {
59-
newExpr := unreplaceVariablesInExpr(stack, e.Exprs[i])
60-
e.Exprs[i] = newExpr.(ast.SelectExpr)
61-
}
62-
case *ast.NotExpr:
63-
newExpr := unreplaceVariablesInExpr(stack, e.Expr)
64-
e.Expr = newExpr.(ast.Expr)
65-
case *ast.Set:
66-
for _, setExpr := range e.Exprs {
67-
newExpr := unreplaceVariablesInExpr(stack, setExpr.Expr)
68-
setExpr.Expr = newExpr.(ast.Expr)
69-
}
70-
case *ast.SQLVal:
71-
if oldVal, ok := stack.replaceMap[expr]; ok {
72-
return oldVal
73-
}
74-
}
75-
return expr
76-
}
77-
7842
func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLNode, error) {
7943
switch e := expr.(type) {
8044
case *ast.AliasedExpr:
@@ -136,8 +100,11 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN
136100
return expr, nil
137101
}
138102
newExpr := iv.ToAST()
139-
stack.replaceMap[newExpr] = e
140-
return newExpr, nil
103+
return &ast.ColName{
104+
Name: e.Name,
105+
Qualifier: e.Qualifier,
106+
StoredProcVal: newExpr,
107+
}, nil
141108
}
142109
return expr, nil
143110
}
@@ -207,12 +174,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
207174
}
208175
rowIters = append(rowIters, rowIter)
209176

210-
for i := range selectStmt.SelectExprs {
211-
newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i])
212-
selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr)
213-
}
214-
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
215-
216177
case OpCode_Declare:
217178
declareStmt := operation.PrimaryData.(*ast.Declare)
218179
for _, decl := range declareStmt.Variables.Names {
@@ -259,14 +220,9 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
259220
return nil, err
260221
}
261222

262-
for i := range selectStmt.SelectExprs {
263-
newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i])
264-
selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr)
265-
}
266-
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
267-
268223
case OpCode_Exception:
269-
// TODO: implement
224+
return nil, operation.Error
225+
270226
case OpCode_Execute:
271227
// TODO: replace variables
272228
stmt, err := replaceVariablesInExpr(&stack, operation.PrimaryData)
@@ -279,9 +235,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
279235
}
280236
rowIters = append(rowIters, rowIter)
281237

282-
stmt = unreplaceVariablesInExpr(&stack, stmt)
283-
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
284-
285238
case OpCode_Goto:
286239
// We must compare to the index - 1, so that the increment hits our target
287240
if counter <= operation.Index {
@@ -341,12 +294,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
341294
counter = operation.Index - 1 // index of the else block, offset by 1
342295
}
343296

344-
for i := range selectStmt.SelectExprs {
345-
newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i])
346-
selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr)
347-
}
348-
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
349-
350297
case OpCode_ScopeBegin:
351298
stack.PushScope()
352299
case OpCode_ScopeEnd:

sql/procedures/interpreter_operation.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,5 @@ type InterpreterOperation struct {
3636
SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc.
3737
Target string // This is the variable that will store the results (if applicable)
3838
Index int // This is the index that should be set for operations that move the function counter
39+
Error error // This is the error that should be returned for OpCode_Exception
3940
}

sql/procedures/interpreter_stack.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ type InterpreterScopeDetails struct {
123123
// general purpose.
124124
type InterpreterStack struct {
125125
stack *Stack[*InterpreterScopeDetails]
126-
replaceMap map[ast.SQLNode]ast.SQLNode
127126
}
128127

129128
// NewInterpreterStack creates a new InterpreterStack.
@@ -135,7 +134,6 @@ func NewInterpreterStack() InterpreterStack {
135134
})
136135
return InterpreterStack{
137136
stack: stack,
138-
replaceMap: map[ast.SQLNode]ast.SQLNode{},
139137
}
140138
}
141139

sql/procedures/parse.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package procedures
1616

1717
import (
18+
"github.com/dolthub/vitess/go/mysql"
19+
1820
ast "github.com/dolthub/vitess/go/vt/sqlparser"
1921
)
2022

@@ -153,7 +155,16 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast
153155

154156
caseOp.Index = len(*ops) // start of next case
155157
}
156-
if s.Else != nil {
158+
if s.Else == nil {
159+
// throw an error if when there is no else block
160+
// this is just an empty case statement that will always hit the else
161+
// todo: alternatively, use an error opcode
162+
errOp := &InterpreterOperation{
163+
OpCode: OpCode_Exception,
164+
Error: mysql.NewSQLError(1339, "20000", "Case not found for CASE statement"),
165+
}
166+
*ops = append(*ops, errOp)
167+
} else {
157168
for _, elseStmt := range s.Else {
158169
if err := ConvertStmt(ops, stack, elseStmt); err != nil {
159170
return err

0 commit comments

Comments
 (0)