@@ -101,6 +101,7 @@ func (s *Smither) makePLpgSQLStatements(scope plpgsqlBlockScope, maxCount int) [
101
101
102
102
func (s * Smither ) makePLpgSQLIf (scope plpgsqlBlockScope ) * ast.If {
103
103
const maxBranchStmts = 3
104
+ scope .scopeMetas = append (scope .scopeMetas , scopeMeta {typ : ifScope })
104
105
ifStmt := & ast.If {
105
106
Condition : s .makePLpgSQLCond (scope ),
106
107
ThenBody : s .makePLpgSQLStatements (scope , maxBranchStmts ),
@@ -144,6 +145,8 @@ var (
144
145
{5 , makePLpgSQLNull },
145
146
{10 , makePLpgSQLAssign },
146
147
{10 , makePLpgSQLExecSQL },
148
+ {2 , makePLpgSQLExit },
149
+ {2 , makePLpgSQLContinue },
147
150
}
148
151
)
149
152
@@ -226,6 +229,30 @@ func makePLpgSQLNull(_ *Smither, _ plpgsqlBlockScope) (stmt ast.Statement, ok bo
226
229
return & ast.Null {}, true
227
230
}
228
231
232
+ func makePLpgSQLExit (s * Smither , scope plpgsqlBlockScope ) (stmt ast.Statement , ok bool ) {
233
+ if ! scope .inLoop () {
234
+ // EXIT statements can only be used within loops.
235
+ return nil , false
236
+ }
237
+ res := & ast.Exit {
238
+ // TODO(#106368): optionally add a label.
239
+ Condition : s .makePLpgSQLCond (scope ),
240
+ }
241
+ return res , true
242
+ }
243
+
244
+ func makePLpgSQLContinue (s * Smither , scope plpgsqlBlockScope ) (stmt ast.Statement , ok bool ) {
245
+ if ! scope .inLoop () {
246
+ // CONTINUE statements can only be used within loops.
247
+ return nil , false
248
+ }
249
+ res := & ast.Continue {
250
+ // TODO(#106368): optionally add a label.
251
+ Condition : s .makePLpgSQLCond (scope ),
252
+ }
253
+ return res , true
254
+ }
255
+
229
256
func makePLpgSQLForLoop (s * Smither , scope plpgsqlBlockScope ) (stmt ast.Statement , ok bool ) {
230
257
// TODO(#105246): add support for other query and cursor FOR loops.
231
258
control := ast.IntForLoopControl {
@@ -239,6 +266,7 @@ func makePLpgSQLForLoop(s *Smither, scope plpgsqlBlockScope) (stmt ast.Statement
239
266
newScope := scope .makeChild (1 /* numNewVars */ )
240
267
loopVarName := s .makePLpgSQLVarName ("loop" , newScope )
241
268
newScope .addVariable (string (loopVarName ), types .Int , false /* constant */ )
269
+ newScope .scopeMetas = append (newScope .scopeMetas , scopeMeta {typ : loopScope , name : string (loopVarName )})
242
270
const maxLoopStmts = 3
243
271
return & ast.ForLoop {
244
272
// TODO(#106368): optionally add a label.
@@ -249,14 +277,35 @@ func makePLpgSQLForLoop(s *Smither, scope plpgsqlBlockScope) (stmt ast.Statement
249
277
}
250
278
251
279
func makePLpgSQLWhile (s * Smither , scope plpgsqlBlockScope ) (stmt ast.Statement , ok bool ) {
280
+ newScope := scope .makeChild (1 /* numNewVars */ )
281
+ loopVarName := s .makePLpgSQLVarName ("loop" , newScope )
282
+ newScope .scopeMetas = append (newScope .scopeMetas , scopeMeta {typ : loopScope , name : string (loopVarName )})
252
283
const maxLoopStmts = 3
253
284
return & ast.While {
254
285
// TODO(#106368): optionally add a label.
255
- Condition : s .makePLpgSQLCond (scope ),
256
- Body : s .makePLpgSQLStatements (scope , maxLoopStmts ),
286
+ Condition : s .makePLpgSQLCond (newScope ),
287
+ Body : s .makePLpgSQLStatements (newScope , maxLoopStmts ),
257
288
}, true
258
289
}
259
290
291
+ // scopeType is a type that represents the type of scope that the current block
292
+ // is nested within.
293
+ type scopeType int
294
+
295
+ const (
296
+ loopScope scopeType = iota
297
+ ifScope
298
+ )
299
+
300
+ // scopeMeta is a name-tagged scopeType. It is used to determine if certain
301
+ // statements are allowed in the current scope, such as EXIT and CONTINUE,
302
+ // which can only be used within loops.
303
+ type scopeMeta struct {
304
+ typ scopeType
305
+ // // TODO(#106368): propagate `name` with the loop label.
306
+ name string
307
+ }
308
+
260
309
// plpgsqlBlockScope holds the information needed to ensure that generated
261
310
// statements obey PL/pgSQL syntax and scoping rules.
262
311
type plpgsqlBlockScope struct {
@@ -270,6 +319,12 @@ type plpgsqlBlockScope struct {
270
319
// current block.
271
320
vars []string
272
321
322
+ // scopeMetas tracks the nested scopes of the current block. For example,
323
+ // entering a FOR loop adds a loopScope with the loop label name, while
324
+ // entering an IF statement adds an ifScope without a name. Used to validate
325
+ // statements like EXIT and CONTINUE, which are only allowed in loops.
326
+ scopeMetas []scopeMeta
327
+
273
328
// refs is the list of colRefs for every variable in the current scope. It
274
329
// could be rebuilt from the vars and varTypes fields, but is kept up-to-date
275
330
// here for convenience.
@@ -305,9 +360,20 @@ func (s *plpgsqlBlockScope) makeChild(numNewVars int) plpgsqlBlockScope {
305
360
}
306
361
newScope .vars = append (newScope .vars , s .vars ... )
307
362
newScope .refs = append (newScope .refs , s .refs ... )
363
+ newScope .scopeMetas = append (newScope .scopeMetas , s .scopeMetas ... )
308
364
return newScope
309
365
}
310
366
367
+ // inLoop returns true if the current scope is within a for/while loop.
368
+ func (s * plpgsqlBlockScope ) inLoop () bool {
369
+ for _ , m := range s .scopeMetas {
370
+ if m .typ == loopScope {
371
+ return true
372
+ }
373
+ }
374
+ return false
375
+ }
376
+
311
377
func (s * plpgsqlBlockScope ) hasVariable (name string ) bool {
312
378
return s .varTypes [name ] != nil
313
379
}
0 commit comments