@@ -77,16 +77,36 @@ def range_(
7777 yield iv , iter_args [0 ]
7878 else :
7979 yield iv
80- if len (iter_args ):
81- previous_frame = inspect .currentframe ().f_back
82- replacements = tuple (map (maybe_cast , for_op .results_ ))
83- _update_caller_vars (previous_frame , iter_args , replacements )
8480
8581
8682def yield_ (* args ):
8783 if len (args ) == 1 and isinstance (args [0 ], OpResultList ):
8884 args = list (args [0 ])
89- yield__ (args )
85+ y = yield__ (args )
86+ parent_op = y .operation .parent .opview
87+ if len (parent_op .results_ ):
88+ results = get_result_or_results (parent_op )
89+ assert (
90+ isinstance (results , (OpResult , OpResultList ))
91+ or isinstance (results , list )
92+ and all (isinstance (r , OpResult ) for r in results )
93+ ), f"api has changed: { results = } "
94+ if isinstance (results , OpResult ):
95+ results = [results ]
96+ unpacked_args = args
97+ if any (isinstance (a , OpResultList ) for a in unpacked_args ):
98+ assert len (unpacked_args ) == 1
99+ unpacked_args = list (unpacked_args [0 ])
100+
101+ assert len (results ) == len (unpacked_args ), f"{ results = } , { unpacked_args = } "
102+ for i , r in enumerate (results ):
103+ if r .type == T ._placeholder_opaque_t ():
104+ r .set_type (unpacked_args [i ].type )
105+
106+ results = maybe_cast (results )
107+ if len (results ) > 1 :
108+ return results
109+ return results [0 ]
90110
91111
92112def _if (cond , results_ = None , * , has_else = False , loc = None , ip = None ):
@@ -184,22 +204,51 @@ def stack_if(*args, **kwargs):
184204 return IfStack .push (* args , ** kwargs )
185205
186206
187- def stack_yield (* args ):
188- return IfStack .yield_ (* args )
207+ def unstack_if (cond : Value , results_ = None , has_else = False ):
208+ if results_ is None :
209+ results_ = []
210+ if results_ :
211+ has_else = True
212+ assert isinstance (cond , Value ), f"cond must be a mlir.Value: { cond = } "
213+ if_op = _if (cond , results_ , has_else = has_else )
214+ cond .owner .move_before (if_op )
215+
216+ ip = InsertionPoint (if_op .then_block )
217+ ip .__enter__ ()
218+
219+ return ip , if_op
189220
190221
191222def end_branch ():
192223 IfStack .pop_branch ()
193224
194225
226+ def unstack_end_branch (ip ):
227+ ip .__exit__ (None , None , None )
228+
229+
195230def else_ ():
196231 return IfStack .push_else ()
197232
198233
234+ def unstack_else_if (if_op ):
235+ assert len (
236+ if_op .regions [1 ].blocks
237+ ), f"can't have else without bb in second region of { if_op = } "
238+
239+ ip = InsertionPoint (if_op .else_block )
240+ ip .__enter__ ()
241+ return ip
242+
243+
199244def end_if ():
200245 IfStack .pop ()
201246
202247
248+ def unstack_end_if (ip ):
249+ ip .__exit__ (None , None , None )
250+
251+
203252def insert_body_maybe_semicolon (
204253 node : cst .CSTNode , index : int , new_node : cst .CSTNode , before = False
205254):
@@ -233,125 +282,121 @@ def insert_body_maybe_semicolon(
233282
234283
235284class ReplaceYieldWithSCFYield (StrictTransformer ):
236- @m .call_if_inside (m .If (test = m . NamedExpr ( value = m . Comparison ()) ))
285+ @m .call_if_inside (m .If ())
237286 @m .leave (m .Yield (value = m .Tuple ()))
238287 def tuple_yield_inside_conditional (
239- self , original_node : cst .Yield , updated_node : cst .Yield
288+ self , original_node : cst .Yield , _updated_node : cst .Yield
240289 ):
241290 args = [cst .Arg (e .value ) for e in original_node .value .elements ]
242- return ast_call (stack_yield .__name__ , args )
291+ return ast_call (yield_ .__name__ , args )
243292
244- @m .call_if_inside (m .If (test = m . NamedExpr ( value = m . Comparison ()) ))
293+ @m .call_if_inside (m .If ())
245294 @m .leave (m .Yield (value = ~ m .Tuple ()))
246295 def single_yield_inside_conditional (
247- self , original_node : cst .Yield , updated_node : cst .Yield
296+ self , original_node : cst .Yield , _updated_node : cst .Yield
248297 ):
249298 args = [cst .Arg (original_node .value )] if original_node .value else []
250- return ast_call (stack_yield .__name__ , args )
299+ return ast_call (yield_ .__name__ , args )
251300
252- @m .call_if_not_inside (m .If (test = m . NamedExpr ( value = m . Comparison ()) ))
301+ @m .call_if_not_inside (m .If ())
253302 @m .leave (m .Yield (value = m .Tuple ()))
254- def tuple_yield (self , original_node : cst .Yield , updated_node : cst .Yield ):
303+ def tuple_yield (self , original_node : cst .Yield , _updated_node : cst .Yield ):
255304 args = [cst .Arg (e .value ) for e in original_node .value .elements ]
256305 return ast_call (yield_ .__name__ , args )
257306
258- @m .call_if_not_inside (m .If (test = m . NamedExpr ( value = m . Comparison ()) ))
307+ @m .call_if_not_inside (m .If ())
259308 @m .leave (m .Yield (value = ~ m .Tuple ()))
260- def single_yield (self , original_node : cst .Yield , updated_node : cst .Yield ):
309+ def single_yield (self , original_node : cst .Yield , _updated_node : cst .Yield ):
261310 args = [cst .Arg (original_node .value )] if original_node .value else []
262311 return ast_call (yield_ .__name__ , args )
263312
264313
265- class InsertEmptySCFYield (StrictTransformer ):
314+ class InsertEmptyYield (StrictTransformer ):
266315 @m .leave (m .If () | m .Else ())
267316 def leave_ (
268317 self , _original_node : cst .If | cst .Else , updated_node : cst .If | cst .Else
269318 ) -> cst .If | cst .Else :
270319 indented_block = updated_node .body
271320 last_statement = indented_block .body [- 1 ]
272- if not m .matches (last_statement , m .SimpleStatementLine ([m .Expr (m .Yield ())])):
273- return insert_body_maybe_semicolon (
274- updated_node , - 1 , ast_call (yield_ .__name__ )
275- )
321+ if not m .matches (last_statement , m .SimpleStatementLine ()):
322+ return insert_body_maybe_semicolon (updated_node , - 1 , cst .Yield ())
323+ elif m .matches (last_statement , m .SimpleStatementLine ()) and not m .findall (
324+ last_statement , m .Yield ()
325+ ):
326+ return insert_body_maybe_semicolon (updated_node , - 1 , cst .Yield ())
276327 # VERY IMPORTANT: you have to return the updated node if you believe
277328 # at any point there was a mutation anywhere in the tree below
278329 return updated_node
279330
280331
281332class CanonicalizeElIfs (StrictTransformer ):
282- @m .leave (m .If (orelse = m .If (test = m . NamedExpr () )))
283- def leave_if_with_elif_named (
333+ @m .leave (m .If (orelse = m .If ()))
334+ def leave_if_with_elif (
284335 self , _original_node : cst .If , updated_node : cst .If
285336 ) -> cst .If :
286- return updated_node .with_changes (
287- orelse = cst .Else (
288- cst .IndentedBlock (
289- [
290- updated_node .orelse ,
291- cst .SimpleStatementLine (
292- [cst .Expr (cst .Yield (updated_node .orelse .test .target ))]
337+ indented_block = updated_node .orelse .body
338+ last_statement = indented_block .body [- 1 ]
339+ if m .matches (last_statement , m .SimpleStatementLine ()) and m .matches (
340+ last_statement .body [- 1 ], m .Assign (value = m .Yield ())
341+ ):
342+ assign_targets = last_statement .body [- 1 ].targets
343+ last_statement = cst .SimpleStatementLine (
344+ [
345+ cst .Assign (
346+ targets = assign_targets ,
347+ value = cst .Yield (
348+ cst .Tuple ([cst .Element (a .target ) for a in assign_targets ])
349+ if len (assign_targets ) > 1
350+ else assign_targets [0 ].target
293351 ),
294- ]
295- )
352+ )
353+ ]
296354 )
297- )
298-
299- @m .leave (m .If (orelse = m .If (test = ~ m .NamedExpr ())))
300- def leave_if_with_elif (
301- self , _original_node : cst .If , updated_node : cst .If
302- ) -> cst .If :
303- return updated_node .with_changes (
304- orelse = cst .Else (cst .IndentedBlock ([updated_node .orelse ]))
305- )
355+ body = [updated_node .orelse , last_statement ]
356+ else :
357+ body = [updated_node .orelse ]
358+ return updated_node .with_changes (orelse = cst .Else (cst .IndentedBlock (body )))
306359
307360
308361class ReplaceSCFCond (StrictTransformer ):
309- @m .leave (m .If (test = m .NamedExpr ( value = m . Call (func = m .Name (stack_if .__name__ ) ))))
362+ @m .leave (m .If (test = m .Call (func = m .Name (stack_if .__name__ ))))
310363 def insert_with_results (
311364 self , original_node : cst .If , _updated_node : cst .If
312365 ) -> cst .If :
313366 return original_node
314367
315- @m .leave (m .If (test = m . NamedExpr ( value = m .Comparison ( ))))
368+ @m .leave (m .If (test = ~ m . Call ( func = m .Name ( stack_if . __name__ ))))
316369 def insert_with_results (
317370 self , original_node : cst .If , updated_node : cst .If
318371 ) -> cst .If :
319372 indented_block = updated_node .body
320373 last_statement = indented_block .body [- 1 ]
321374 assert m .matches (
322375 last_statement , m .SimpleStatementLine ()
323- ), f"conditional with := must explicitly yield on last line"
324- yield_expr = last_statement .body [0 ]
325- if m .matches (yield_expr .value , m .Call (func = m .Name (stack_yield .__name__ ))):
326- results = [cst .Element (ast_call (T ._placeholder_opaque_t .__name__ ))] * len (
327- yield_expr .value .args
328- )
329- elif m .matches (yield_expr .value .value , m .Name ()):
330- results = [cst .Element (ast_call (T ._placeholder_opaque_t .__name__ ))]
331- elif m .matches (yield_expr .value .value , m .Tuple ()):
332- results = [cst .Element (ast_call (T ._placeholder_opaque_t .__name__ ))] * len (
333- yield_expr .value .value .elements
334- )
376+ ), f"conditional must end with a statement"
377+ yield_expr = m .findall (last_statement , m .Call (func = m .Name (yield_ .__name__ )))
378+ assert (
379+ len (yield_expr ) == 1
380+ ), f"conditional with must explicitly { yield_ .__name__ } on last line: { yield_expr } "
381+ yield_expr = yield_expr [0 ]
382+ results = [cst .Element (ast_call (T ._placeholder_opaque_t .__name__ ))] * len (
383+ yield_expr .args
384+ )
335385 results = cst .Tuple (results )
336386
337387 test = original_node .test
338- compare = test .value
339- assert m .matches (
340- compare , m .Comparison ()
341- ), f"expected cst.Compare from { compare = } "
342- new_compare = ast_call (
343- stack_if .__name__ , args = [cst .Arg (compare ), cst .Arg (results )]
388+ new_test = ast_call (
389+ stack_if .__name__ ,
390+ args = [
391+ cst .Arg (test ),
392+ cst .Arg (results ),
393+ cst .Arg (
394+ cst .Name (str (bool (original_node .orelse ))),
395+ keyword = cst .Name ("has_else" ),
396+ ),
397+ ],
344398 )
345- new_test = test .deep_replace (compare , new_compare )
346- return updated_node .with_changes (test = new_test )
347-
348- @m .leave (m .If (test = m .Comparison ()))
349- def insert_no_results (self , original_node : cst .If , updated_node : cst .If ) -> cst .If :
350- test = original_node .test
351- args = [cst .Arg (test )]
352- if original_node .orelse :
353- args += [cst .Arg (cst .Tuple ([])), cst .Arg (cst .Name (str (True )))]
354- new_test = ast_call (stack_if .__name__ , args = args )
399+ new_test = test .deep_replace (test , new_test )
355400 return updated_node .with_changes (test = new_test )
356401
357402
@@ -424,7 +469,6 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
424469 f .__globals__ [end_branch .__name__ ] = end_branch
425470 f .__globals__ [end_if .__name__ ] = end_if
426471 f .__globals__ [stack_if .__name__ ] = stack_if
427- f .__globals__ [stack_yield .__name__ ] = stack_yield
428472 f .__globals__ [yield_ .__name__ ] = yield_
429473 f .__globals__ [T ._placeholder_opaque_t .__name__ ] = T ._placeholder_opaque_t
430474 return code
@@ -433,7 +477,7 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
433477class SCFCanonicalizer (Canonicalizer ):
434478 cst_transformers = [
435479 CanonicalizeElIfs ,
436- InsertEmptySCFYield ,
480+ InsertEmptyYield ,
437481 ReplaceYieldWithSCFYield ,
438482 ReplaceSCFCond ,
439483 InsertEndIfs ,
0 commit comments