2525 maybe_cast ,
2626 get_result_or_results ,
2727 get_user_code_loc ,
28+ region_adder ,
2829)
2930
3031logger = logging .getLogger (__name__ )
@@ -109,19 +110,24 @@ def yield_(*args):
109110 return results [0 ]
110111
111112
112- def _if (cond , results_ = None , * , has_else = False , loc = None , ip = None ):
113- if results_ is None :
114- results_ = []
115- if results_ :
113+ def _if (cond , results = None , * , has_else = False , loc = None , ip = None ):
114+ if results is None :
115+ results = []
116+ if results :
116117 has_else = True
117118 if loc is None :
118119 loc = get_user_code_loc ()
119- return IfOp (cond , results_ , hasElse = has_else , loc = loc , ip = ip )
120+ return IfOp (cond , results , hasElse = has_else , loc = loc , ip = ip )
120121
121122
122123if_ = region_op (_if , terminator = yield__ )
123124
124125
126+ @region_adder (terminator = yield__ )
127+ def else_ (ifop ):
128+ return ifop .regions [1 ]
129+
130+
125131class IpStack :
126132 # __current_if_op: list[IfOp]
127133 __if_ips : list [InsertionPoint ]
@@ -179,23 +185,6 @@ def unstack_else_if(prev_ips_ifop: tuple[IpStack, IfOp], cond: Value, results_=N
179185 return prev_ips + next_if_ip , next_if_op
180186
181187
182- def get_last_statement (original_node ):
183- statements = m .findall (original_node , m .SimpleStatementLine ())
184- assert len (statements ), "no statements...?"
185- return statements [- 1 ]
186-
187-
188- def insert_in_deep_last_statement (
189- original_node : cst .CSTNode ,
190- new_node : cst .CSTNode ,
191- ) -> cst .CSTNode :
192- last_statement = get_last_statement (original_node )
193- new_last_statement = last_statement .with_changes (
194- body = list (last_statement .body ) + [cst .Expr (new_node )]
195- )
196- return original_node .deep_replace (last_statement , new_last_statement )
197-
198-
199188class ReplaceYieldWithSCFYield (StrictTransformer ):
200189 @m .call_if_inside (m .If ())
201190 @m .leave (m .Yield (value = m .Tuple ()))
@@ -226,41 +215,35 @@ def single_yield(self, original_node: cst.Yield, _updated_node: cst.Yield):
226215 return ast_call (yield_ .__name__ , args )
227216
228217
229- def maybe_insert_yield_at_end_or_deep (node ):
218+ def maybe_insert_yield_at_end (node ):
230219 maybe_last_statement = node .body [- 1 ]
231220 if m .matches (maybe_last_statement , m .SimpleStatementLine ()):
232221 if len (m .findall (maybe_last_statement , m .Yield ())) > 0 :
233222 return node
234223
235- # if last thing in body is a simplestatement then you can talk the yield (with a semicolon)
224+ # if last thing in body is a simplestatement then you can tack the yield (with a semicolon)
236225 # onto the end
237- new_maybe_last_statement = insert_in_deep_last_statement (
238- maybe_last_statement , cst .Yield ()
226+ new_maybe_last_statement = maybe_last_statement . with_changes (
227+ body = list ( maybe_last_statement . body ) + [ cst .Expr ( cst . Yield ())]
239228 )
240- node = node .deep_replace (maybe_last_statement , new_maybe_last_statement )
229+ return node .deep_replace (maybe_last_statement , new_maybe_last_statement )
241230 else :
242- # this branch is different (i.e., doesn't check for a match)
243- # because if the last thing is an indented block, there's no way the user could've intentionally placed
244- # a yield there that handles this conditional (even if they placed a yield to handle a conditional in that
245- # last block)
246- node = insert_in_deep_last_statement (node , cst .Yield ())
247-
248- return node
231+ raise RuntimeError ("primitive must have statement as last line" )
249232
250233
251234class InsertEmptyYield (StrictTransformer ):
252235 @m .leave (m .If ())
253236 def leave_if (self , _original_node : cst .If , updated_node : cst .If ) -> cst .If :
254- new_body = maybe_insert_yield_at_end_or_deep (updated_node .body )
237+ new_body = maybe_insert_yield_at_end (updated_node .body )
255238 new_orelse = updated_node .orelse
256239 if new_orelse :
257- new_orelse_body = maybe_insert_yield_at_end_or_deep (new_orelse .body )
240+ new_orelse_body = maybe_insert_yield_at_end (new_orelse .body )
258241 new_orelse = new_orelse .with_changes (body = new_orelse_body )
259242 return updated_node .with_changes (body = new_body , orelse = new_orelse )
260243
261244 @m .leave (m .For ())
262245 def leave_for (self , _original_node : cst .For , updated_node : cst .For ) -> cst .For :
263- new_body = maybe_insert_yield_at_end_or_deep (updated_node .body )
246+ new_body = maybe_insert_yield_at_end (updated_node .body )
264247 return updated_node .with_changes (body = new_body )
265248
266249
@@ -269,57 +252,15 @@ class CheckMatchingYields(StrictTransformer):
269252 def leave_ (self , original_node : cst .If , _updated_node : cst .If ) -> cst .If :
270253 n_ifs = len (m .findall (original_node , m .If ()))
271254 n_elses = len (m .findall (original_node , m .Else ()))
255+ n_fors = len (m .findall (original_node , m .For ()))
272256 n_yields = len (m .findall (original_node , m .Call (func = m .Name (yield_ .__name__ ))))
273- if n_ifs + n_elses < = n_yields :
274- warnings . warn (
275- f"unmatched if/elses and yields: { n_ifs = } { n_elses = } { n_yields = } ; line { self .get_pos (original_node ).start .line } "
257+ if n_ifs + n_elses + n_fors ! = n_yields :
258+ raise RuntimeError (
259+ f"unmatched if/elses and yields: { n_ifs = } { n_elses = } { n_fors = } { n_yields = } ; line { self .get_pos (original_node ).start .line } "
276260 )
277261 return original_node
278262
279263
280- def check_unstack_if (original_node , metadata_resolver ):
281- return m .matches (
282- original_node ,
283- m .If (
284- test = m .NamedExpr (
285- target = m .MatchMetadataIfTrue (
286- QualifiedNameProvider ,
287- lambda qualnames : any (
288- unstack_if .__name__ in n .name
289- or unstack_else_if .__name__ in n .name
290- for n in qualnames
291- ),
292- )
293- )
294- ),
295- metadata_resolver = metadata_resolver ,
296- )
297-
298-
299- class CanonicalizeElIfTests (StrictTransformer ):
300- @m .call_if_inside (m .If (orelse = m .If ()))
301- @m .leave (m .If ())
302- def leave_last_elif (self , original_node : cst .If , updated_node : cst .If ) -> cst .If :
303- assert check_unstack_if (
304- original_node , self
305- ), f"if must already have had test replaced with unstack_if"
306- parent = self .get_parent (original_node )
307- if (
308- not check_unstack_if (parent , self )
309- # you need this because call_if_inside matches self as well as parent
310- or parent .orelse != original_node
311- ):
312- return updated_node
313-
314- test = updated_node .test
315- new_test_call = ast_call (
316- unstack_else_if .__name__ ,
317- args = [cst .Arg (parent .test .target )] + list (updated_node .test .value .args ),
318- )
319- new_test = test .with_changes (value = new_test_call )
320- return updated_node .with_changes (test = new_test )
321-
322-
323264class ReplaceSCFCond (StrictTransformer ):
324265 @m .leave (
325266 m .If (
@@ -337,16 +278,18 @@ def insert_with_results(
337278 def leave_if (self , original_node : cst .If , updated_node : cst .If ) -> cst .If :
338279 indented_block = updated_node .body
339280 last_statement = indented_block .body [- 1 ]
340- results = []
341- if m .matches (last_statement , m .SimpleStatementLine ()):
342- yield_expr = m .findall (last_statement , m .Call (func = m .Name (yield_ .__name__ )))
343- assert len (
344- yield_expr
345- ), f"conditional must explicitly { yield_ .__name__ } on last line: { yield_expr } "
346- yield_expr = yield_expr [0 ]
347- results = [cst .Element (ast_call (T ._placeholder_opaque_t .__name__ ))] * len (
348- yield_expr .args
349- )
281+
282+ assert m .matches (
283+ last_statement , m .SimpleStatementLine ()
284+ ), f"conditional must explicitly end with statement"
285+ yield_expr = m .findall (last_statement , m .Call (func = m .Name (yield_ .__name__ )))
286+ assert len (
287+ yield_expr
288+ ), f"conditional must explicitly { yield_ .__name__ } on last line: { yield_expr } "
289+ yield_expr = yield_expr [0 ]
290+ results = [cst .Element (ast_call (T ._placeholder_opaque_t .__name__ ))] * len (
291+ yield_expr .args
292+ )
350293 results = cst .Tuple (results )
351294
352295 test = original_node .test
@@ -362,22 +305,35 @@ def leave_if(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
362305 return updated_node .with_changes (test = new_test )
363306
364307
365- def in_last_statement_maybe_interleave_with_yields (node , new_node ):
366- last_statement = get_last_statement (node )
367- last_statement_body = list (last_statement .body )
368- for i , b in enumerate (last_statement_body [:- 1 ]):
369- next_b = last_statement_body [i + 1 ]
370- # two adjacent yields (this happens when InsertEmptyYield inserts a yield in a deep statement
371- if m .matches (b , m .Expr (m .Call (func = m .Name (yield_ .__name__ )))) and m .matches (
372- next_b , m .Expr (m .Call (func = m .Name (yield_ .__name__ )))
373- ):
374- last_statement_body .insert (i + 1 , new_node )
375- break
308+ def insert_end_if_in_body (node , assign ):
309+ maybe_last_statement = node .body [- 1 ]
310+ if m .matches (maybe_last_statement , m .SimpleStatementLine ()):
311+ # if last thing in body is a simplestatement then you can talk the yield (with a semicolon)
312+ # onto the end
313+ new_maybe_last_statement = maybe_last_statement .with_changes (
314+ body = list (maybe_last_statement .body ) + [assign ]
315+ )
316+ return node .deep_replace (maybe_last_statement , new_maybe_last_statement )
376317 else :
377- last_statement_body .append (new_node )
378- return node .deep_replace (
379- last_statement ,
380- last_statement .with_changes (body = last_statement_body ),
318+ raise RuntimeError ("if statement must have yield" )
319+
320+
321+ def check_unstack_if (original_node , metadata_resolver ):
322+ return m .matches (
323+ original_node ,
324+ m .If (
325+ test = m .NamedExpr (
326+ target = m .MatchMetadataIfTrue (
327+ QualifiedNameProvider ,
328+ lambda qualnames : any (
329+ unstack_if .__name__ in n .name
330+ or unstack_else_if .__name__ in n .name
331+ for n in qualnames
332+ ),
333+ )
334+ )
335+ ),
336+ metadata_resolver = metadata_resolver ,
381337 )
382338
383339
@@ -386,7 +342,7 @@ class InsertEndIfs(StrictTransformer):
386342 def leave_if (self , original_node : cst .If , updated_node : cst .If ) -> cst .If :
387343 assert check_unstack_if (
388344 original_node , self
389- ), f"if must already have had test replaced with unstack_if"
345+ ), f"if must already have had test replaced with unstack_if before endifs can be inserted "
390346
391347 assign = cst .Assign (
392348 targets = [cst .AssignTarget (updated_node .test .target )],
@@ -395,41 +351,12 @@ def leave_if(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
395351 ),
396352 )
397353
398- new_body = in_last_statement_maybe_interleave_with_yields (
399- updated_node .body , assign
400- )
354+ new_body = insert_end_if_in_body (updated_node .body , assign )
401355
402- new_orelse = None
356+ new_orelse = updated_node . orelse
403357 if updated_node .orelse :
404- new_orelse = in_last_statement_maybe_interleave_with_yields (
405- updated_node .orelse , assign
406- )
407- parent = self .get_parent (original_node )
408- if not check_unstack_if (parent , self ) or parent .orelse != original_node :
409- return updated_node .with_changes (body = new_body , orelse = new_orelse )
410-
411- # basically adds a yield for scf.elseif that yields the correct result (i.e., whatever is yielded in the inner
412- # block
413- maybe_assigned_yield_in_body = ast_call (yield_ .__name__ )
414- last_statement_in_body = updated_node .body .body [- 1 ]
415-
416- # if the inner block yields a named result, "re-yield" it
417- if m .matches (last_statement_in_body , m .SimpleStatementLine ()) and m .matches (
418- last_statement_in_body .body [0 ],
419- m .Assign (value = m .Call (func = m .Name (yield_ .__name__ ))),
420- ):
421- maybe_assigned_yield_in_body = last_statement_in_body .body [0 ]
422- # re-yield but you don't need to name it, i.e. it doesn't need to be visible at the python/frontend level
423- # i.e., if a user sets a breakpoint
424- maybe_assigned_yield_in_body = ast_call (
425- yield_ .__name__ ,
426- [cst .Arg (t .target ) for t in maybe_assigned_yield_in_body .targets ],
427- )
428-
429- maybe_assigned_yield_in_body = cst .Expr (maybe_assigned_yield_in_body )
430- new_orelse = in_last_statement_maybe_interleave_with_yields (
431- new_orelse , maybe_assigned_yield_in_body
432- )
358+ new_orelse_body = insert_end_if_in_body (new_orelse .body , assign )
359+ new_orelse = new_orelse .with_changes (body = new_orelse_body )
433360 return updated_node .with_changes (body = new_body , orelse = new_orelse )
434361
435362
@@ -444,7 +371,15 @@ def leave_if_else(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
444371 targets = [cst .AssignTarget (updated_node .test .target )],
445372 value = ast_call (unstack_else .__name__ , [cst .Arg (updated_node .test .target )]),
446373 )
447- new_body = insert_in_deep_last_statement (updated_node .body , assign )
374+
375+ last_statement = updated_node .body .body [- 1 ]
376+ assert m .matches (
377+ last_statement , m .SimpleStatementLine ()
378+ ), f"conditional must explicitly end with statement"
379+ new_last_statement = last_statement .with_changes (
380+ body = list (last_statement .body ) + [cst .Expr (assign )]
381+ )
382+ new_body = updated_node .body .deep_replace (last_statement , new_last_statement )
448383 return updated_node .with_changes (body = new_body )
449384
450385
@@ -505,7 +440,6 @@ class SCFCanonicalizer(Canonicalizer):
505440 ReplaceYieldWithSCFYield ,
506441 CheckMatchingYields ,
507442 ReplaceSCFCond ,
508- CanonicalizeElIfTests ,
509443 InsertEndIfs ,
510444 InsertPreElses ,
511445 ]
0 commit comments