Skip to content

Commit a1f449d

Browse files
committed
remove craziness
1 parent 1294bbe commit a1f449d

File tree

5 files changed

+851
-845
lines changed

5 files changed

+851
-845
lines changed

mlir_utils/dialects/ext/scf.py

Lines changed: 78 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
maybe_cast,
2626
get_result_or_results,
2727
get_user_code_loc,
28+
region_adder,
2829
)
2930

3031
logger = logging.getLogger(__name__)
@@ -109,19 +110,24 @@ def yield_(*args):
109110
return results[0]
110111

111112

112-
def _if(cond, results_=None, *, has_else=False, loc=None, ip=None):
113-
if results_ is None:
114-
results_ = []
115-
if results_:
113+
def _if(cond, results=None, *, has_else=False, loc=None, ip=None):
114+
if results is None:
115+
results = []
116+
if results:
116117
has_else = True
117118
if loc is None:
118119
loc = get_user_code_loc()
119-
return IfOp(cond, results_, hasElse=has_else, loc=loc, ip=ip)
120+
return IfOp(cond, results, hasElse=has_else, loc=loc, ip=ip)
120121

121122

122123
if_ = region_op(_if, terminator=yield__)
123124

124125

126+
@region_adder(terminator=yield__)
127+
def else_(ifop):
128+
return ifop.regions[1]
129+
130+
125131
class IpStack:
126132
# __current_if_op: list[IfOp]
127133
__if_ips: list[InsertionPoint]
@@ -179,23 +185,6 @@ def unstack_else_if(prev_ips_ifop: tuple[IpStack, IfOp], cond: Value, results_=N
179185
return prev_ips + next_if_ip, next_if_op
180186

181187

182-
def get_last_statement(original_node):
183-
statements = m.findall(original_node, m.SimpleStatementLine())
184-
assert len(statements), "no statements...?"
185-
return statements[-1]
186-
187-
188-
def insert_in_deep_last_statement(
189-
original_node: cst.CSTNode,
190-
new_node: cst.CSTNode,
191-
) -> cst.CSTNode:
192-
last_statement = get_last_statement(original_node)
193-
new_last_statement = last_statement.with_changes(
194-
body=list(last_statement.body) + [cst.Expr(new_node)]
195-
)
196-
return original_node.deep_replace(last_statement, new_last_statement)
197-
198-
199188
class ReplaceYieldWithSCFYield(StrictTransformer):
200189
@m.call_if_inside(m.If())
201190
@m.leave(m.Yield(value=m.Tuple()))
@@ -226,41 +215,35 @@ def single_yield(self, original_node: cst.Yield, _updated_node: cst.Yield):
226215
return ast_call(yield_.__name__, args)
227216

228217

229-
def maybe_insert_yield_at_end_or_deep(node):
218+
def maybe_insert_yield_at_end(node):
230219
maybe_last_statement = node.body[-1]
231220
if m.matches(maybe_last_statement, m.SimpleStatementLine()):
232221
if len(m.findall(maybe_last_statement, m.Yield())) > 0:
233222
return node
234223

235-
# if last thing in body is a simplestatement then you can talk the yield (with a semicolon)
224+
# if last thing in body is a simplestatement then you can tack the yield (with a semicolon)
236225
# onto the end
237-
new_maybe_last_statement = insert_in_deep_last_statement(
238-
maybe_last_statement, cst.Yield()
226+
new_maybe_last_statement = maybe_last_statement.with_changes(
227+
body=list(maybe_last_statement.body) + [cst.Expr(cst.Yield())]
239228
)
240-
node = node.deep_replace(maybe_last_statement, new_maybe_last_statement)
229+
return node.deep_replace(maybe_last_statement, new_maybe_last_statement)
241230
else:
242-
# this branch is different (i.e., doesn't check for a match)
243-
# because if the last thing is an indented block, there's no way the user could've intentionally placed
244-
# a yield there that handles this conditional (even if they placed a yield to handle a conditional in that
245-
# last block)
246-
node = insert_in_deep_last_statement(node, cst.Yield())
247-
248-
return node
231+
raise RuntimeError("primitive must have statement as last line")
249232

250233

251234
class InsertEmptyYield(StrictTransformer):
252235
@m.leave(m.If())
253236
def leave_if(self, _original_node: cst.If, updated_node: cst.If) -> cst.If:
254-
new_body = maybe_insert_yield_at_end_or_deep(updated_node.body)
237+
new_body = maybe_insert_yield_at_end(updated_node.body)
255238
new_orelse = updated_node.orelse
256239
if new_orelse:
257-
new_orelse_body = maybe_insert_yield_at_end_or_deep(new_orelse.body)
240+
new_orelse_body = maybe_insert_yield_at_end(new_orelse.body)
258241
new_orelse = new_orelse.with_changes(body=new_orelse_body)
259242
return updated_node.with_changes(body=new_body, orelse=new_orelse)
260243

261244
@m.leave(m.For())
262245
def leave_for(self, _original_node: cst.For, updated_node: cst.For) -> cst.For:
263-
new_body = maybe_insert_yield_at_end_or_deep(updated_node.body)
246+
new_body = maybe_insert_yield_at_end(updated_node.body)
264247
return updated_node.with_changes(body=new_body)
265248

266249

@@ -269,57 +252,15 @@ class CheckMatchingYields(StrictTransformer):
269252
def leave_(self, original_node: cst.If, _updated_node: cst.If) -> cst.If:
270253
n_ifs = len(m.findall(original_node, m.If()))
271254
n_elses = len(m.findall(original_node, m.Else()))
255+
n_fors = len(m.findall(original_node, m.For()))
272256
n_yields = len(m.findall(original_node, m.Call(func=m.Name(yield_.__name__))))
273-
if n_ifs + n_elses <= n_yields:
274-
warnings.warn(
275-
f"unmatched if/elses and yields: {n_ifs=} {n_elses=} {n_yields=}; line {self.get_pos(original_node).start.line}"
257+
if n_ifs + n_elses + n_fors != n_yields:
258+
raise RuntimeError(
259+
f"unmatched if/elses and yields: {n_ifs=} {n_elses=} {n_fors=} {n_yields=}; line {self.get_pos(original_node).start.line}"
276260
)
277261
return original_node
278262

279263

280-
def check_unstack_if(original_node, metadata_resolver):
281-
return m.matches(
282-
original_node,
283-
m.If(
284-
test=m.NamedExpr(
285-
target=m.MatchMetadataIfTrue(
286-
QualifiedNameProvider,
287-
lambda qualnames: any(
288-
unstack_if.__name__ in n.name
289-
or unstack_else_if.__name__ in n.name
290-
for n in qualnames
291-
),
292-
)
293-
)
294-
),
295-
metadata_resolver=metadata_resolver,
296-
)
297-
298-
299-
class CanonicalizeElIfTests(StrictTransformer):
300-
@m.call_if_inside(m.If(orelse=m.If()))
301-
@m.leave(m.If())
302-
def leave_last_elif(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
303-
assert check_unstack_if(
304-
original_node, self
305-
), f"if must already have had test replaced with unstack_if"
306-
parent = self.get_parent(original_node)
307-
if (
308-
not check_unstack_if(parent, self)
309-
# you need this because call_if_inside matches self as well as parent
310-
or parent.orelse != original_node
311-
):
312-
return updated_node
313-
314-
test = updated_node.test
315-
new_test_call = ast_call(
316-
unstack_else_if.__name__,
317-
args=[cst.Arg(parent.test.target)] + list(updated_node.test.value.args),
318-
)
319-
new_test = test.with_changes(value=new_test_call)
320-
return updated_node.with_changes(test=new_test)
321-
322-
323264
class ReplaceSCFCond(StrictTransformer):
324265
@m.leave(
325266
m.If(
@@ -337,16 +278,18 @@ def insert_with_results(
337278
def leave_if(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
338279
indented_block = updated_node.body
339280
last_statement = indented_block.body[-1]
340-
results = []
341-
if m.matches(last_statement, m.SimpleStatementLine()):
342-
yield_expr = m.findall(last_statement, m.Call(func=m.Name(yield_.__name__)))
343-
assert len(
344-
yield_expr
345-
), f"conditional must explicitly {yield_.__name__} on last line: {yield_expr}"
346-
yield_expr = yield_expr[0]
347-
results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len(
348-
yield_expr.args
349-
)
281+
282+
assert m.matches(
283+
last_statement, m.SimpleStatementLine()
284+
), f"conditional must explicitly end with statement"
285+
yield_expr = m.findall(last_statement, m.Call(func=m.Name(yield_.__name__)))
286+
assert len(
287+
yield_expr
288+
), f"conditional must explicitly {yield_.__name__} on last line: {yield_expr}"
289+
yield_expr = yield_expr[0]
290+
results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len(
291+
yield_expr.args
292+
)
350293
results = cst.Tuple(results)
351294

352295
test = original_node.test
@@ -362,22 +305,35 @@ def leave_if(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
362305
return updated_node.with_changes(test=new_test)
363306

364307

365-
def in_last_statement_maybe_interleave_with_yields(node, new_node):
366-
last_statement = get_last_statement(node)
367-
last_statement_body = list(last_statement.body)
368-
for i, b in enumerate(last_statement_body[:-1]):
369-
next_b = last_statement_body[i + 1]
370-
# two adjacent yields (this happens when InsertEmptyYield inserts a yield in a deep statement
371-
if m.matches(b, m.Expr(m.Call(func=m.Name(yield_.__name__)))) and m.matches(
372-
next_b, m.Expr(m.Call(func=m.Name(yield_.__name__)))
373-
):
374-
last_statement_body.insert(i + 1, new_node)
375-
break
308+
def insert_end_if_in_body(node, assign):
309+
maybe_last_statement = node.body[-1]
310+
if m.matches(maybe_last_statement, m.SimpleStatementLine()):
311+
# if last thing in body is a simplestatement then you can talk the yield (with a semicolon)
312+
# onto the end
313+
new_maybe_last_statement = maybe_last_statement.with_changes(
314+
body=list(maybe_last_statement.body) + [assign]
315+
)
316+
return node.deep_replace(maybe_last_statement, new_maybe_last_statement)
376317
else:
377-
last_statement_body.append(new_node)
378-
return node.deep_replace(
379-
last_statement,
380-
last_statement.with_changes(body=last_statement_body),
318+
raise RuntimeError("if statement must have yield")
319+
320+
321+
def check_unstack_if(original_node, metadata_resolver):
322+
return m.matches(
323+
original_node,
324+
m.If(
325+
test=m.NamedExpr(
326+
target=m.MatchMetadataIfTrue(
327+
QualifiedNameProvider,
328+
lambda qualnames: any(
329+
unstack_if.__name__ in n.name
330+
or unstack_else_if.__name__ in n.name
331+
for n in qualnames
332+
),
333+
)
334+
)
335+
),
336+
metadata_resolver=metadata_resolver,
381337
)
382338

383339

@@ -386,7 +342,7 @@ class InsertEndIfs(StrictTransformer):
386342
def leave_if(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
387343
assert check_unstack_if(
388344
original_node, self
389-
), f"if must already have had test replaced with unstack_if"
345+
), f"if must already have had test replaced with unstack_if before endifs can be inserted"
390346

391347
assign = cst.Assign(
392348
targets=[cst.AssignTarget(updated_node.test.target)],
@@ -395,41 +351,12 @@ def leave_if(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
395351
),
396352
)
397353

398-
new_body = in_last_statement_maybe_interleave_with_yields(
399-
updated_node.body, assign
400-
)
354+
new_body = insert_end_if_in_body(updated_node.body, assign)
401355

402-
new_orelse = None
356+
new_orelse = updated_node.orelse
403357
if updated_node.orelse:
404-
new_orelse = in_last_statement_maybe_interleave_with_yields(
405-
updated_node.orelse, assign
406-
)
407-
parent = self.get_parent(original_node)
408-
if not check_unstack_if(parent, self) or parent.orelse != original_node:
409-
return updated_node.with_changes(body=new_body, orelse=new_orelse)
410-
411-
# basically adds a yield for scf.elseif that yields the correct result (i.e., whatever is yielded in the inner
412-
# block
413-
maybe_assigned_yield_in_body = ast_call(yield_.__name__)
414-
last_statement_in_body = updated_node.body.body[-1]
415-
416-
# if the inner block yields a named result, "re-yield" it
417-
if m.matches(last_statement_in_body, m.SimpleStatementLine()) and m.matches(
418-
last_statement_in_body.body[0],
419-
m.Assign(value=m.Call(func=m.Name(yield_.__name__))),
420-
):
421-
maybe_assigned_yield_in_body = last_statement_in_body.body[0]
422-
# re-yield but you don't need to name it, i.e. it doesn't need to be visible at the python/frontend level
423-
# i.e., if a user sets a breakpoint
424-
maybe_assigned_yield_in_body = ast_call(
425-
yield_.__name__,
426-
[cst.Arg(t.target) for t in maybe_assigned_yield_in_body.targets],
427-
)
428-
429-
maybe_assigned_yield_in_body = cst.Expr(maybe_assigned_yield_in_body)
430-
new_orelse = in_last_statement_maybe_interleave_with_yields(
431-
new_orelse, maybe_assigned_yield_in_body
432-
)
358+
new_orelse_body = insert_end_if_in_body(new_orelse.body, assign)
359+
new_orelse = new_orelse.with_changes(body=new_orelse_body)
433360
return updated_node.with_changes(body=new_body, orelse=new_orelse)
434361

435362

@@ -444,7 +371,15 @@ def leave_if_else(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
444371
targets=[cst.AssignTarget(updated_node.test.target)],
445372
value=ast_call(unstack_else.__name__, [cst.Arg(updated_node.test.target)]),
446373
)
447-
new_body = insert_in_deep_last_statement(updated_node.body, assign)
374+
375+
last_statement = updated_node.body.body[-1]
376+
assert m.matches(
377+
last_statement, m.SimpleStatementLine()
378+
), f"conditional must explicitly end with statement"
379+
new_last_statement = last_statement.with_changes(
380+
body=list(last_statement.body) + [cst.Expr(assign)]
381+
)
382+
new_body = updated_node.body.deep_replace(last_statement, new_last_statement)
448383
return updated_node.with_changes(body=new_body)
449384

450385

@@ -505,7 +440,6 @@ class SCFCanonicalizer(Canonicalizer):
505440
ReplaceYieldWithSCFYield,
506441
CheckMatchingYields,
507442
ReplaceSCFCond,
508-
CanonicalizeElIfTests,
509443
InsertEndIfs,
510444
InsertPreElses,
511445
]

0 commit comments

Comments
 (0)