Skip to content

Commit e3cea52

Browse files
committed
refactor scfs to use = yield
1 parent 7af45b1 commit e3cea52

File tree

6 files changed

+442
-290
lines changed

6 files changed

+442
-290
lines changed

mlir_utils/ast/canonicalize.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,8 @@ def transform_cst(
4343

4444
logger.debug("[transformed code]\n\n%s", module_cst.code)
4545

46-
tree = ast.parse(module_cst.code, filename=inspect.getfile(f))
47-
tree = ast.increment_lineno(tree, f.__code__.co_firstlineno - 1)
48-
module_code_o = compile(tree, f.__code__.co_filename, "exec")
46+
code = "\n" * (f.__code__.co_firstlineno - 1) + module_cst.code
47+
module_code_o = compile(code, f.__code__.co_filename, "exec")
4948
new_f_code_o = next(
5049
c
5150
for c in module_code_o.co_consts

mlir_utils/dialects/ext/scf.py

Lines changed: 118 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,36 @@ def range_(
7777
yield iv, iter_args[0]
7878
else:
7979
yield iv
80-
if len(iter_args):
81-
previous_frame = inspect.currentframe().f_back
82-
replacements = tuple(map(maybe_cast, for_op.results_))
83-
_update_caller_vars(previous_frame, iter_args, replacements)
8480

8581

8682
def yield_(*args):
8783
if len(args) == 1 and isinstance(args[0], OpResultList):
8884
args = list(args[0])
89-
yield__(args)
85+
y = yield__(args)
86+
parent_op = y.operation.parent.opview
87+
if len(parent_op.results_):
88+
results = get_result_or_results(parent_op)
89+
assert (
90+
isinstance(results, (OpResult, OpResultList))
91+
or isinstance(results, list)
92+
and all(isinstance(r, OpResult) for r in results)
93+
), f"api has changed: {results=}"
94+
if isinstance(results, OpResult):
95+
results = [results]
96+
unpacked_args = args
97+
if any(isinstance(a, OpResultList) for a in unpacked_args):
98+
assert len(unpacked_args) == 1
99+
unpacked_args = list(unpacked_args[0])
100+
101+
assert len(results) == len(unpacked_args), f"{results=}, {unpacked_args=}"
102+
for i, r in enumerate(results):
103+
if r.type == T._placeholder_opaque_t():
104+
r.set_type(unpacked_args[i].type)
105+
106+
results = maybe_cast(results)
107+
if len(results) > 1:
108+
return results
109+
return results[0]
90110

91111

92112
def _if(cond, results_=None, *, has_else=False, loc=None, ip=None):
@@ -184,22 +204,51 @@ def stack_if(*args, **kwargs):
184204
return IfStack.push(*args, **kwargs)
185205

186206

187-
def stack_yield(*args):
188-
return IfStack.yield_(*args)
207+
def unstack_if(cond: Value, results_=None, has_else=False):
208+
if results_ is None:
209+
results_ = []
210+
if results_:
211+
has_else = True
212+
assert isinstance(cond, Value), f"cond must be a mlir.Value: {cond=}"
213+
if_op = _if(cond, results_, has_else=has_else)
214+
cond.owner.move_before(if_op)
215+
216+
ip = InsertionPoint(if_op.then_block)
217+
ip.__enter__()
218+
219+
return ip, if_op
189220

190221

191222
def end_branch():
192223
IfStack.pop_branch()
193224

194225

226+
def unstack_end_branch(ip):
227+
ip.__exit__(None, None, None)
228+
229+
195230
def else_():
196231
return IfStack.push_else()
197232

198233

234+
def unstack_else_if(if_op):
235+
assert len(
236+
if_op.regions[1].blocks
237+
), f"can't have else without bb in second region of {if_op=}"
238+
239+
ip = InsertionPoint(if_op.else_block)
240+
ip.__enter__()
241+
return ip
242+
243+
199244
def end_if():
200245
IfStack.pop()
201246

202247

248+
def unstack_end_if(ip):
249+
ip.__exit__(None, None, None)
250+
251+
203252
def insert_body_maybe_semicolon(
204253
node: cst.CSTNode, index: int, new_node: cst.CSTNode, before=False
205254
):
@@ -233,125 +282,121 @@ def insert_body_maybe_semicolon(
233282

234283

235284
class ReplaceYieldWithSCFYield(StrictTransformer):
236-
@m.call_if_inside(m.If(test=m.NamedExpr(value=m.Comparison())))
285+
@m.call_if_inside(m.If())
237286
@m.leave(m.Yield(value=m.Tuple()))
238287
def tuple_yield_inside_conditional(
239-
self, original_node: cst.Yield, updated_node: cst.Yield
288+
self, original_node: cst.Yield, _updated_node: cst.Yield
240289
):
241290
args = [cst.Arg(e.value) for e in original_node.value.elements]
242-
return ast_call(stack_yield.__name__, args)
291+
return ast_call(yield_.__name__, args)
243292

244-
@m.call_if_inside(m.If(test=m.NamedExpr(value=m.Comparison())))
293+
@m.call_if_inside(m.If())
245294
@m.leave(m.Yield(value=~m.Tuple()))
246295
def single_yield_inside_conditional(
247-
self, original_node: cst.Yield, updated_node: cst.Yield
296+
self, original_node: cst.Yield, _updated_node: cst.Yield
248297
):
249298
args = [cst.Arg(original_node.value)] if original_node.value else []
250-
return ast_call(stack_yield.__name__, args)
299+
return ast_call(yield_.__name__, args)
251300

252-
@m.call_if_not_inside(m.If(test=m.NamedExpr(value=m.Comparison())))
301+
@m.call_if_not_inside(m.If())
253302
@m.leave(m.Yield(value=m.Tuple()))
254-
def tuple_yield(self, original_node: cst.Yield, updated_node: cst.Yield):
303+
def tuple_yield(self, original_node: cst.Yield, _updated_node: cst.Yield):
255304
args = [cst.Arg(e.value) for e in original_node.value.elements]
256305
return ast_call(yield_.__name__, args)
257306

258-
@m.call_if_not_inside(m.If(test=m.NamedExpr(value=m.Comparison())))
307+
@m.call_if_not_inside(m.If())
259308
@m.leave(m.Yield(value=~m.Tuple()))
260-
def single_yield(self, original_node: cst.Yield, updated_node: cst.Yield):
309+
def single_yield(self, original_node: cst.Yield, _updated_node: cst.Yield):
261310
args = [cst.Arg(original_node.value)] if original_node.value else []
262311
return ast_call(yield_.__name__, args)
263312

264313

265-
class InsertEmptySCFYield(StrictTransformer):
314+
class InsertEmptyYield(StrictTransformer):
266315
@m.leave(m.If() | m.Else())
267316
def leave_(
268317
self, _original_node: cst.If | cst.Else, updated_node: cst.If | cst.Else
269318
) -> cst.If | cst.Else:
270319
indented_block = updated_node.body
271320
last_statement = indented_block.body[-1]
272-
if not m.matches(last_statement, m.SimpleStatementLine([m.Expr(m.Yield())])):
273-
return insert_body_maybe_semicolon(
274-
updated_node, -1, ast_call(yield_.__name__)
275-
)
321+
if not m.matches(last_statement, m.SimpleStatementLine()):
322+
return insert_body_maybe_semicolon(updated_node, -1, cst.Yield())
323+
elif m.matches(last_statement, m.SimpleStatementLine()) and not m.findall(
324+
last_statement, m.Yield()
325+
):
326+
return insert_body_maybe_semicolon(updated_node, -1, cst.Yield())
276327
# VERY IMPORTANT: you have to return the updated node if you believe
277328
# at any point there was a mutation anywhere in the tree below
278329
return updated_node
279330

280331

281332
class CanonicalizeElIfs(StrictTransformer):
282-
@m.leave(m.If(orelse=m.If(test=m.NamedExpr())))
283-
def leave_if_with_elif_named(
333+
@m.leave(m.If(orelse=m.If()))
334+
def leave_if_with_elif(
284335
self, _original_node: cst.If, updated_node: cst.If
285336
) -> cst.If:
286-
return updated_node.with_changes(
287-
orelse=cst.Else(
288-
cst.IndentedBlock(
289-
[
290-
updated_node.orelse,
291-
cst.SimpleStatementLine(
292-
[cst.Expr(cst.Yield(updated_node.orelse.test.target))]
337+
indented_block = updated_node.orelse.body
338+
last_statement = indented_block.body[-1]
339+
if m.matches(last_statement, m.SimpleStatementLine()) and m.matches(
340+
last_statement.body[-1], m.Assign(value=m.Yield())
341+
):
342+
assign_targets = last_statement.body[-1].targets
343+
last_statement = cst.SimpleStatementLine(
344+
[
345+
cst.Assign(
346+
targets=assign_targets,
347+
value=cst.Yield(
348+
cst.Tuple([cst.Element(a.target) for a in assign_targets])
349+
if len(assign_targets) > 1
350+
else assign_targets[0].target
293351
),
294-
]
295-
)
352+
)
353+
]
296354
)
297-
)
298-
299-
@m.leave(m.If(orelse=m.If(test=~m.NamedExpr())))
300-
def leave_if_with_elif(
301-
self, _original_node: cst.If, updated_node: cst.If
302-
) -> cst.If:
303-
return updated_node.with_changes(
304-
orelse=cst.Else(cst.IndentedBlock([updated_node.orelse]))
305-
)
355+
body = [updated_node.orelse, last_statement]
356+
else:
357+
body = [updated_node.orelse]
358+
return updated_node.with_changes(orelse=cst.Else(cst.IndentedBlock(body)))
306359

307360

308361
class ReplaceSCFCond(StrictTransformer):
309-
@m.leave(m.If(test=m.NamedExpr(value=m.Call(func=m.Name(stack_if.__name__)))))
362+
@m.leave(m.If(test=m.Call(func=m.Name(stack_if.__name__))))
310363
def insert_with_results(
311364
self, original_node: cst.If, _updated_node: cst.If
312365
) -> cst.If:
313366
return original_node
314367

315-
@m.leave(m.If(test=m.NamedExpr(value=m.Comparison())))
368+
@m.leave(m.If(test=~m.Call(func=m.Name(stack_if.__name__))))
316369
def insert_with_results(
317370
self, original_node: cst.If, updated_node: cst.If
318371
) -> cst.If:
319372
indented_block = updated_node.body
320373
last_statement = indented_block.body[-1]
321374
assert m.matches(
322375
last_statement, m.SimpleStatementLine()
323-
), f"conditional with := must explicitly yield on last line"
324-
yield_expr = last_statement.body[0]
325-
if m.matches(yield_expr.value, m.Call(func=m.Name(stack_yield.__name__))):
326-
results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len(
327-
yield_expr.value.args
328-
)
329-
elif m.matches(yield_expr.value.value, m.Name()):
330-
results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))]
331-
elif m.matches(yield_expr.value.value, m.Tuple()):
332-
results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len(
333-
yield_expr.value.value.elements
334-
)
376+
), f"conditional must end with a statement"
377+
yield_expr = m.findall(last_statement, m.Call(func=m.Name(yield_.__name__)))
378+
assert (
379+
len(yield_expr) == 1
380+
), f"conditional with must explicitly {yield_.__name__} on last line: {yield_expr}"
381+
yield_expr = yield_expr[0]
382+
results = [cst.Element(ast_call(T._placeholder_opaque_t.__name__))] * len(
383+
yield_expr.args
384+
)
335385
results = cst.Tuple(results)
336386

337387
test = original_node.test
338-
compare = test.value
339-
assert m.matches(
340-
compare, m.Comparison()
341-
), f"expected cst.Compare from {compare=}"
342-
new_compare = ast_call(
343-
stack_if.__name__, args=[cst.Arg(compare), cst.Arg(results)]
388+
new_test = ast_call(
389+
stack_if.__name__,
390+
args=[
391+
cst.Arg(test),
392+
cst.Arg(results),
393+
cst.Arg(
394+
cst.Name(str(bool(original_node.orelse))),
395+
keyword=cst.Name("has_else"),
396+
),
397+
],
344398
)
345-
new_test = test.deep_replace(compare, new_compare)
346-
return updated_node.with_changes(test=new_test)
347-
348-
@m.leave(m.If(test=m.Comparison()))
349-
def insert_no_results(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
350-
test = original_node.test
351-
args = [cst.Arg(test)]
352-
if original_node.orelse:
353-
args += [cst.Arg(cst.Tuple([])), cst.Arg(cst.Name(str(True)))]
354-
new_test = ast_call(stack_if.__name__, args=args)
399+
new_test = test.deep_replace(test, new_test)
355400
return updated_node.with_changes(test=new_test)
356401

357402

@@ -424,7 +469,6 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
424469
f.__globals__[end_branch.__name__] = end_branch
425470
f.__globals__[end_if.__name__] = end_if
426471
f.__globals__[stack_if.__name__] = stack_if
427-
f.__globals__[stack_yield.__name__] = stack_yield
428472
f.__globals__[yield_.__name__] = yield_
429473
f.__globals__[T._placeholder_opaque_t.__name__] = T._placeholder_opaque_t
430474
return code
@@ -433,7 +477,7 @@ def patch_bytecode(self, code: ConcreteBytecode, f):
433477
class SCFCanonicalizer(Canonicalizer):
434478
cst_transformers = [
435479
CanonicalizeElIfs,
436-
InsertEmptySCFYield,
480+
InsertEmptyYield,
437481
ReplaceYieldWithSCFYield,
438482
ReplaceSCFCond,
439483
InsertEndIfs,

mlir_utils/util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_result_or_results(
3636
if isinstance(op, Value):
3737
return op
3838
return (
39-
get_op_results_or_values(op)
39+
list(get_op_results_or_values(op))
4040
if len(op.operation.results) > 1
4141
else get_op_result_or_value(op)
4242
if len(op.operation.results) > 0
@@ -94,14 +94,17 @@ def get_value_caster(typeid: TypeID):
9494
return __VALUE_CASTERS[typeid]
9595

9696

97-
def maybe_cast(val: Value):
97+
def maybe_cast(val: Value | list[Value]):
9898
"""Maybe cast an ir.Value to one of Tensor, Scalar.
9999
100100
Args:
101101
val: The ir.Value to maybe cast.
102102
"""
103103
from mlir_utils.dialects.ext.arith import Scalar
104104

105+
if isinstance(val, list):
106+
return list(map(maybe_cast, val))
107+
105108
if not isinstance(val, Value):
106109
return val
107110

tests/test_location_tracking.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,12 @@ def test_if_replace_yield_5(ctx: MLIRContext):
3434
def iffoo():
3535
one = constant(1.0)
3636
two = constant(2.0)
37-
if res := stack_if(one < two, (T.f64_t, T.f64_t, T.f64_t)):
37+
if stack_if(one < two, (T.f64_t, T.f64_t, T.f64_t)):
3838
three = constant(3.0)
39-
yield three, three, three
39+
res1, res2, res3 = yield three, three, three
4040
else:
4141
four = constant(4.0)
42-
yield four, four, four
42+
res1, res2, res3 = yield four, four, four
4343
return
4444

4545
iffoo()
@@ -49,24 +49,24 @@ def iffoo():
4949
module {{
5050
%cst = arith.constant 1.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:35:10
5151
%cst_0 = arith.constant 2.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:36:10
52-
%0 = arith.cmpf olt, %cst, %cst_0 : f64 THIS_DIR{sep}test_location_tracking.py:37:23
52+
%0 = arith.cmpf olt, %cst, %cst_0 : f64 THIS_DIR{sep}test_location_tracking.py:37:16
5353
%1:3 = scf.if %0 -> (f64, f64, f64) {{
5454
%cst_1 = arith.constant 3.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:38:16
55-
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:39:8
55+
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:39:27
5656
}} else {{
5757
%cst_1 = arith.constant 4.000000e+00 : f64 THIS_DIR{sep}test_location_tracking.py:41:24
58-
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:42:8
59-
}} THIS_DIR{sep}test_location_tracking.py:37:14
58+
scf.yield %cst_1, %cst_1, %cst_1 : f64, f64, f64 THIS_DIR{sep}test_location_tracking.py:42:27
59+
}} THIS_DIR{sep}test_location_tracking.py:37:7
6060
}} [unknown]
6161
#loc = [unknown]
6262
#loc1 = THIS_DIR{sep}test_location_tracking.py:35:10
6363
#loc2 = THIS_DIR{sep}test_location_tracking.py:36:10
64-
#loc3 = THIS_DIR{sep}test_location_tracking.py:37:23
65-
#loc4 = THIS_DIR{sep}test_location_tracking.py:37:14
64+
#loc3 = THIS_DIR{sep}test_location_tracking.py:37:16
65+
#loc4 = THIS_DIR{sep}test_location_tracking.py:37:7
6666
#loc5 = THIS_DIR{sep}test_location_tracking.py:38:16
67-
#loc6 = THIS_DIR{sep}test_location_tracking.py:39:8
67+
#loc6 = THIS_DIR{sep}test_location_tracking.py:39:27
6868
#loc7 = THIS_DIR{sep}test_location_tracking.py:41:24
69-
#loc8 = THIS_DIR{sep}test_location_tracking.py:42:8
69+
#loc8 = THIS_DIR{sep}test_location_tracking.py:42:27
7070
"""
7171
)
7272
asm = get_asm(ctx.module.operation)

0 commit comments

Comments
 (0)