@@ -39,6 +39,42 @@ type Parameter struct {
3939 Value any
4040}
4141
42+ func unreplaceVariablesInExpr (stack * InterpreterStack , expr ast.SQLNode ) ast.SQLNode {
43+ switch e := expr .(type ) {
44+ case * ast.AliasedExpr :
45+ newExpr := unreplaceVariablesInExpr (stack , e .Expr )
46+ e .Expr = newExpr .(ast.Expr )
47+ case * ast.BinaryExpr :
48+ newLeftExpr := unreplaceVariablesInExpr (stack , e .Left )
49+ newRightExpr := unreplaceVariablesInExpr (stack , e .Right )
50+ e .Left = newLeftExpr .(ast.Expr )
51+ e .Right = newRightExpr .(ast.Expr )
52+ case * ast.ComparisonExpr :
53+ newLeftExpr := unreplaceVariablesInExpr (stack , e .Left )
54+ newRightExpr := unreplaceVariablesInExpr (stack , e .Right )
55+ e .Left = newLeftExpr .(ast.Expr )
56+ e .Right = newRightExpr .(ast.Expr )
57+ case * ast.FuncExpr :
58+ for i := range e .Exprs {
59+ newExpr := unreplaceVariablesInExpr (stack , e .Exprs [i ])
60+ e .Exprs [i ] = newExpr .(ast.SelectExpr )
61+ }
62+ case * ast.NotExpr :
63+ newExpr := unreplaceVariablesInExpr (stack , e .Expr )
64+ e .Expr = newExpr .(ast.Expr )
65+ case * ast.Set :
66+ for _ , setExpr := range e .Exprs {
67+ newExpr := unreplaceVariablesInExpr (stack , setExpr .Expr )
68+ setExpr .Expr = newExpr .(ast.Expr )
69+ }
70+ case * ast.SQLVal :
71+ if oldVal , ok := stack .replaceMap [expr ]; ok {
72+ return oldVal
73+ }
74+ }
75+ return expr
76+ }
77+
4278func replaceVariablesInExpr (stack * InterpreterStack , expr ast.SQLNode ) (ast.SQLNode , error ) {
4379 switch e := expr .(type ) {
4480 case * ast.AliasedExpr :
@@ -69,18 +105,39 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN
69105 }
70106 e .Left = newLeftExpr .(ast.Expr )
71107 e .Right = newRightExpr .(ast.Expr )
108+ case * ast.FuncExpr :
109+ for i := range e .Exprs {
110+ newExpr , err := replaceVariablesInExpr (stack , e .Exprs [i ])
111+ if err != nil {
112+ return nil , err
113+ }
114+ e .Exprs [i ] = newExpr .(ast.SelectExpr )
115+ }
72116 case * ast.NotExpr :
73117 newExpr , err := replaceVariablesInExpr (stack , e .Expr )
74118 if err != nil {
75119 return nil , err
76120 }
77121 e .Expr = newExpr .(ast.Expr )
122+ case * ast.Set :
123+ for _ , setExpr := range e .Exprs {
124+ newExpr , err := replaceVariablesInExpr (stack , setExpr .Expr )
125+ if err != nil {
126+ return nil , err
127+ }
128+ err = stack .SetVariable (nil , setExpr .Name .String (), newExpr )
129+ if err != nil {
130+ return nil , err
131+ }
132+ }
78133 case * ast.ColName :
79134 iv := stack .GetVariable (e .Name .String ())
80135 if iv == nil {
81136 return expr , nil
82137 }
83- return iv .ToAST (), nil
138+ newExpr := iv .ToAST ()
139+ stack .replaceMap [newExpr ] = e
140+ return newExpr , nil
84141 }
85142 return expr , nil
86143}
@@ -149,15 +206,19 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
149206 return nil , err
150207 }
151208 rowIters = append (rowIters , rowIter )
209+
210+ for i := range selectStmt .SelectExprs {
211+ newNode := unreplaceVariablesInExpr (& stack , selectStmt .SelectExprs [i ])
212+ selectStmt .SelectExprs [i ] = newNode .(ast.SelectExpr )
213+ }
214+ stack .replaceMap = map [ast.SQLNode ]ast.SQLNode {}
215+
152216 case OpCode_Declare :
153217 declareStmt := operation .PrimaryData .(* ast.Declare )
154218 for _ , decl := range declareStmt .Variables .Names {
155- var varType sql.Type
156- switch declareStmt .Variables .VarType .Type {
157- case "int" :
158- varType = types .Int32
159- default :
160- panic ("unimplemented type" )
219+ varType , err := types .ColumnTypeToType (& declareStmt .Variables .VarType )
220+ if err != nil {
221+ return nil , err
161222 }
162223 varName := strings .ToLower (decl .String ())
163224 if declareStmt .Variables .VarType .Default != nil {
@@ -166,15 +227,61 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
166227 stack .NewVariable (varName , varType )
167228 }
168229 }
230+ case OpCode_Set :
231+ selectStmt := operation .PrimaryData .(* ast.Select )
232+ if selectStmt .SelectExprs == nil {
233+ panic ("select stmt with no select exprs" )
234+ }
235+ for i := range selectStmt .SelectExprs {
236+ newNode , err := replaceVariablesInExpr (& stack , selectStmt .SelectExprs [i ])
237+ if err != nil {
238+ return nil , err
239+ }
240+ selectStmt .SelectExprs [i ] = newNode .(ast.SelectExpr )
241+ }
242+ _ , rowIter , _ , err := runner .QueryWithBindings (ctx , "" , selectStmt , nil , nil )
243+ if err != nil {
244+ return nil , err
245+ }
246+ row , err := rowIter .Next (ctx )
247+ if err != nil {
248+ return nil , err
249+ }
250+ if _ , err = rowIter .Next (ctx ); err != io .EOF {
251+ return nil , err
252+ }
253+ if err = rowIter .Close (ctx ); err != nil {
254+ return nil , err
255+ }
256+
257+ err = stack .SetVariable (nil , strings .ToLower (operation .Target ), row [0 ])
258+ if err != nil {
259+ return nil , err
260+ }
261+
262+ for i := range selectStmt .SelectExprs {
263+ newNode := unreplaceVariablesInExpr (& stack , selectStmt .SelectExprs [i ])
264+ selectStmt .SelectExprs [i ] = newNode .(ast.SelectExpr )
265+ }
266+ stack .replaceMap = map [ast.SQLNode ]ast.SQLNode {}
267+
169268 case OpCode_Exception :
170269 // TODO: implement
171270 case OpCode_Execute :
172271 // TODO: replace variables
173- rowIter , err := query (ctx , runner , operation .PrimaryData )
272+ stmt , err := replaceVariablesInExpr (& stack , operation .PrimaryData )
273+ if err != nil {
274+ return nil , err
275+ }
276+ rowIter , err := query (ctx , runner , stmt .(ast.Statement ))
174277 if err != nil {
175278 return nil , err
176279 }
177280 rowIters = append (rowIters , rowIter )
281+
282+ stmt = unreplaceVariablesInExpr (& stack , stmt )
283+ stack .replaceMap = map [ast.SQLNode ]ast.SQLNode {}
284+
178285 case OpCode_Goto :
179286 // We must compare to the index - 1, so that the increment hits our target
180287 if counter <= operation .Index {
@@ -233,6 +340,13 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er
233340 if ! cond {
234341 counter = operation .Index - 1 // index of the else block, offset by 1
235342 }
343+
344+ for i := range selectStmt .SelectExprs {
345+ newNode := unreplaceVariablesInExpr (& stack , selectStmt .SelectExprs [i ])
346+ selectStmt .SelectExprs [i ] = newNode .(ast.SelectExpr )
347+ }
348+ stack .replaceMap = map [ast.SQLNode ]ast.SQLNode {}
349+
236350 case OpCode_ScopeBegin :
237351 stack .PushScope ()
238352 case OpCode_ScopeEnd :
0 commit comments