Skip to content

Commit da6db05

Browse files
author
James Cor
committed
double replace issue
1 parent b19b8b5 commit da6db05

File tree

4 files changed

+237
-12
lines changed

4 files changed

+237
-12
lines changed

sql/procedures/interpreter_logic.go

Lines changed: 122 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,42 @@ type Parameter struct {
3939
Value any
4040
}
4141

42+
func unreplaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) ast.SQLNode {
43+
switch e := expr.(type) {
44+
case *ast.AliasedExpr:
45+
newExpr := unreplaceVariablesInExpr(stack, e.Expr)
46+
e.Expr = newExpr.(ast.Expr)
47+
case *ast.BinaryExpr:
48+
newLeftExpr := unreplaceVariablesInExpr(stack, e.Left)
49+
newRightExpr := unreplaceVariablesInExpr(stack, e.Right)
50+
e.Left = newLeftExpr.(ast.Expr)
51+
e.Right = newRightExpr.(ast.Expr)
52+
case *ast.ComparisonExpr:
53+
newLeftExpr := unreplaceVariablesInExpr(stack, e.Left)
54+
newRightExpr := unreplaceVariablesInExpr(stack, e.Right)
55+
e.Left = newLeftExpr.(ast.Expr)
56+
e.Right = newRightExpr.(ast.Expr)
57+
case *ast.FuncExpr:
58+
for i := range e.Exprs {
59+
newExpr := unreplaceVariablesInExpr(stack, e.Exprs[i])
60+
e.Exprs[i] = newExpr.(ast.SelectExpr)
61+
}
62+
case *ast.NotExpr:
63+
newExpr := unreplaceVariablesInExpr(stack, e.Expr)
64+
e.Expr = newExpr.(ast.Expr)
65+
case *ast.Set:
66+
for _, setExpr := range e.Exprs {
67+
newExpr := unreplaceVariablesInExpr(stack, setExpr.Expr)
68+
setExpr.Expr = newExpr.(ast.Expr)
69+
}
70+
case *ast.SQLVal:
71+
if oldVal, ok := stack.replaceMap[expr]; ok {
72+
return oldVal
73+
}
74+
}
75+
return expr
76+
}
77+
4278
func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLNode, error) {
4379
switch e := expr.(type) {
4480
case *ast.AliasedExpr:
@@ -69,18 +105,39 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN
69105
}
70106
e.Left = newLeftExpr.(ast.Expr)
71107
e.Right = newRightExpr.(ast.Expr)
108+
case *ast.FuncExpr:
109+
for i := range e.Exprs {
110+
newExpr, err := replaceVariablesInExpr(stack, e.Exprs[i])
111+
if err != nil {
112+
return nil, err
113+
}
114+
e.Exprs[i] = newExpr.(ast.SelectExpr)
115+
}
72116
case *ast.NotExpr:
73117
newExpr, err := replaceVariablesInExpr(stack, e.Expr)
74118
if err != nil {
75119
return nil, err
76120
}
77121
e.Expr = newExpr.(ast.Expr)
122+
case *ast.Set:
123+
for _, setExpr := range e.Exprs {
124+
newExpr, err := replaceVariablesInExpr(stack, setExpr.Expr)
125+
if err != nil {
126+
return nil, err
127+
}
128+
err = stack.SetVariable(nil, setExpr.Name.String(), newExpr)
129+
if err != nil {
130+
return nil, err
131+
}
132+
}
78133
case *ast.ColName:
79134
iv := stack.GetVariable(e.Name.String())
80135
if iv == nil {
81136
return expr, nil
82137
}
83-
return iv.ToAST(), nil
138+
newExpr := iv.ToAST()
139+
stack.replaceMap[newExpr] = e
140+
return newExpr, nil
84141
}
85142
return expr, nil
86143
}
@@ -149,15 +206,19 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
149206
return nil, err
150207
}
151208
rowIters = append(rowIters, rowIter)
209+
210+
for i := range selectStmt.SelectExprs {
211+
newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i])
212+
selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr)
213+
}
214+
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
215+
152216
case OpCode_Declare:
153217
declareStmt := operation.PrimaryData.(*ast.Declare)
154218
for _, decl := range declareStmt.Variables.Names {
155-
var varType sql.Type
156-
switch declareStmt.Variables.VarType.Type {
157-
case "int":
158-
varType = types.Int32
159-
default:
160-
panic("unimplemented type")
219+
varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType)
220+
if err != nil {
221+
return nil, err
161222
}
162223
varName := strings.ToLower(decl.String())
163224
if declareStmt.Variables.VarType.Default != nil {
@@ -166,15 +227,61 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
166227
stack.NewVariable(varName, varType)
167228
}
168229
}
230+
case OpCode_Set:
231+
selectStmt := operation.PrimaryData.(*ast.Select)
232+
if selectStmt.SelectExprs == nil {
233+
panic("select stmt with no select exprs")
234+
}
235+
for i := range selectStmt.SelectExprs {
236+
newNode, err := replaceVariablesInExpr(&stack, selectStmt.SelectExprs[i])
237+
if err != nil {
238+
return nil, err
239+
}
240+
selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr)
241+
}
242+
_, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil)
243+
if err != nil {
244+
return nil, err
245+
}
246+
row, err := rowIter.Next(ctx)
247+
if err != nil {
248+
return nil, err
249+
}
250+
if _, err = rowIter.Next(ctx); err != io.EOF {
251+
return nil, err
252+
}
253+
if err = rowIter.Close(ctx); err != nil {
254+
return nil, err
255+
}
256+
257+
err = stack.SetVariable(nil, strings.ToLower(operation.Target), row[0])
258+
if err != nil {
259+
return nil, err
260+
}
261+
262+
for i := range selectStmt.SelectExprs {
263+
newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i])
264+
selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr)
265+
}
266+
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
267+
169268
case OpCode_Exception:
170269
// TODO: implement
171270
case OpCode_Execute:
172271
// TODO: replace variables
173-
rowIter, err := query(ctx, runner, operation.PrimaryData)
272+
stmt, err := replaceVariablesInExpr(&stack, operation.PrimaryData)
273+
if err != nil {
274+
return nil, err
275+
}
276+
rowIter, err := query(ctx, runner, stmt.(ast.Statement))
174277
if err != nil {
175278
return nil, err
176279
}
177280
rowIters = append(rowIters, rowIter)
281+
282+
stmt = unreplaceVariablesInExpr(&stack, stmt)
283+
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
284+
178285
case OpCode_Goto:
179286
// We must compare to the index - 1, so that the increment hits our target
180287
if counter <= operation.Index {
@@ -233,6 +340,13 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
233340
if !cond {
234341
counter = operation.Index - 1 // index of the else block, offset by 1
235342
}
343+
344+
for i := range selectStmt.SelectExprs {
345+
newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i])
346+
selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr)
347+
}
348+
stack.replaceMap = map[ast.SQLNode]ast.SQLNode{}
349+
236350
case OpCode_ScopeBegin:
237351
stack.PushScope()
238352
case OpCode_ScopeEnd:

sql/procedures/interpreter_operation.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ type OpCode uint16
1919
const (
2020
OpCode_Select OpCode = iota
2121
OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html
22+
OpCode_Set
2223
OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING
2324
OpCode_Execute // Everything that's not a SELECT
2425
OpCode_Goto // All control-flow structures can be represented using Goto

sql/procedures/interpreter_stack.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ type InterpreterScopeDetails struct {
123123
// general purpose.
124124
type InterpreterStack struct {
125125
stack *Stack[*InterpreterScopeDetails]
126+
replaceMap map[ast.SQLNode]ast.SQLNode
126127
}
127128

128129
// NewInterpreterStack creates a new InterpreterStack.
@@ -134,6 +135,7 @@ func NewInterpreterStack() InterpreterStack {
134135
})
135136
return InterpreterStack{
136137
stack: stack,
138+
replaceMap: map[ast.SQLNode]ast.SQLNode{},
137139
}
138140
}
139141

sql/procedures/parse.go

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,33 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast
5353
}
5454
*ops = append(*ops, declareOp)
5555

56+
case *ast.Set:
57+
if len(s.Exprs) != 1 {
58+
panic("unexpected number of set expressions")
59+
}
60+
setExpr := s.Exprs[0]
61+
var setOp *InterpreterOperation
62+
if len(setExpr.Scope) != 0 {
63+
setOp = &InterpreterOperation{
64+
OpCode: OpCode_Execute,
65+
PrimaryData: s,
66+
}
67+
} else {
68+
selectStmt := &ast.Select{
69+
SelectExprs: ast.SelectExprs{
70+
&ast.AliasedExpr{
71+
Expr: setExpr.Expr,
72+
},
73+
},
74+
}
75+
setOp = &InterpreterOperation{
76+
OpCode: OpCode_Set,
77+
PrimaryData: selectStmt,
78+
Target: setExpr.Name.String(),
79+
}
80+
}
81+
*ops = append(*ops, setOp)
82+
5683
case *ast.IfStatement:
5784
// TODO: assume exactly one condition for now
5885
ifCond := s.Conditions[0]
@@ -80,14 +107,63 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast
80107
}
81108
*ops = append(*ops, gotoOp)
82109

83-
ifOp.Index = len(*ops)
110+
ifOp.Index = len(*ops) // start of else block
84111
for _, elseStmt := range s.Else {
85112
if err := ConvertStmt(ops, stack, elseStmt); err != nil {
86113
return err
87114
}
88115
}
89116

90-
gotoOp.Index = len(*ops)
117+
gotoOp.Index = len(*ops) // end of if statement
118+
119+
case *ast.CaseStatement:
120+
var caseGotoOps []*InterpreterOperation
121+
for _, caseStmt := range s.Cases {
122+
caseExpr := caseStmt.Case
123+
if s.Expr != nil {
124+
caseExpr = &ast.ComparisonExpr{
125+
Operator: ast.EqualStr,
126+
Left: s.Expr,
127+
Right: caseExpr,
128+
}
129+
}
130+
caseCond := &ast.Select{
131+
SelectExprs: ast.SelectExprs{
132+
&ast.AliasedExpr{
133+
Expr: caseExpr,
134+
},
135+
},
136+
}
137+
caseOp := &InterpreterOperation{
138+
OpCode: OpCode_If,
139+
PrimaryData: caseCond,
140+
}
141+
*ops = append(*ops, caseOp)
142+
143+
for _, ifStmt := range caseStmt.Statements {
144+
if err := ConvertStmt(ops, stack, ifStmt); err != nil {
145+
return err
146+
}
147+
}
148+
gotoOp := &InterpreterOperation{
149+
OpCode: OpCode_Goto,
150+
}
151+
caseGotoOps = append(caseGotoOps, gotoOp)
152+
*ops = append(*ops, gotoOp)
153+
154+
caseOp.Index = len(*ops) // start of next case
155+
}
156+
if s.Else != nil {
157+
for _, elseStmt := range s.Else {
158+
if err := ConvertStmt(ops, stack, elseStmt); err != nil {
159+
return err
160+
}
161+
}
162+
}
163+
164+
for _, gotoOp := range caseGotoOps {
165+
gotoOp.Index = len(*ops) // end of case block
166+
}
91167

92168
case *ast.While:
93169
loopStart := len(*ops)
@@ -117,7 +193,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast
117193
}
118194
*ops = append(*ops, gotoOp)
119195

120-
whileOp.Index = len(*ops)
196+
whileOp.Index = len(*ops) // end of while block
121197

122198
case *ast.Repeat:
123199
loopStart := len(*ops)
@@ -148,8 +224,40 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast
148224
}
149225
*ops = append(*ops, gotoOp)
150226

151-
repeatOp.Index = len(*ops)
227+
repeatOp.Index = len(*ops) // end of repeat block
228+
229+
case *ast.Loop:
230+
loopStart := len(*ops)
231+
for _, loopStmt := range s.Statements {
232+
if err := ConvertStmt(ops, stack, loopStmt); err != nil {
233+
return err
234+
}
235+
}
236+
gotoOp := &InterpreterOperation{
237+
OpCode: OpCode_Goto,
238+
Index: loopStart,
239+
}
240+
*ops = append(*ops, gotoOp)
241+
242+
// perform second pass over loop statements to add labels
243+
for idx := loopStart; idx < len(*ops); idx++ {
244+
op := (*ops)[idx]
245+
switch op.OpCode {
246+
case OpCode_Goto:
247+
if op.Target == s.Label {
248+
(*ops)[idx].Index = len(*ops)
249+
}
250+
default:
251+
continue
252+
}
253+
}
152254

255+
case *ast.Leave:
256+
leaveOp := &InterpreterOperation{
257+
OpCode: OpCode_Goto,
258+
Target: s.Label, // hacky? way to signal a leave
259+
}
260+
*ops = append(*ops, leaveOp)
153261
default:
154262
execOp := &InterpreterOperation{
155263
OpCode: OpCode_Execute,

0 commit comments

Comments
 (0)