Skip to content

Commit 82cbdf3

Browse files
authored
support mixed scf bounds (#127)
1 parent 9d0393f commit 82cbdf3

File tree

2 files changed

+76
-18
lines changed

2 files changed

+76
-18
lines changed

mlir/extras/dialects/ext/scf.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,33 @@
3838

3939
opaque = lambda dialect_namespace, buffer: OpaqueType.get(dialect_namespace, buffer)
4040

41-
range_ = for_
4241

42+
def canonicalize_start_stop_step(start, stop, step):
43+
if step is None:
44+
step = 1
45+
if stop is None:
46+
stop = start
47+
start = 0
48+
params = [start, stop, step]
49+
type = IndexType.get()
50+
maybe_types = {p.type for p in params if isinstance(p, Value)}
51+
if maybe_types:
52+
if len(maybe_types) > 1:
53+
raise ValueError(
54+
f"all {start=} and {stop=} and {step=} ir.Value objects must have the same type"
55+
)
56+
type = maybe_types.pop()
4357

44-
def placeholder_opaque_t():
45-
return opaque("scf", "placeholder")
58+
for i, p in enumerate(params):
59+
if isinstance(p, int):
60+
p = _ext_arith_constant(p, type=type)
61+
assert isinstance(p, Value)
62+
params[i] = p
63+
64+
return params[0], params[1], params[2]
4665

4766

48-
def _for(
67+
def _build_for(
4968
start,
5069
stop=None,
5170
step=None,
@@ -54,25 +73,39 @@ def _for(
5473
loc=None,
5574
ip=None,
5675
):
57-
if step is None:
58-
step = 1
59-
if stop is None:
60-
stop = start
61-
start = 0
62-
params = [start, stop, step]
63-
for i, p in enumerate(params):
64-
if isinstance(p, int):
65-
p = _ext_arith_constant(p, index=True)
66-
if not _is_index_type(p.type):
67-
p = index_cast(p)
68-
params[i] = p
76+
start, stop, step = canonicalize_start_stop_step(start, stop, step)
77+
return ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
6978

79+
80+
def range_(
81+
start,
82+
stop=None,
83+
step=None,
84+
iter_args: Optional[Sequence[Value]] = None,
85+
*,
86+
loc=None,
87+
ip=None,
88+
):
7089
if loc is None:
7190
loc = get_user_code_loc()
72-
return ForOp(*params, iter_args, loc=loc, ip=ip)
91+
92+
for_op = _build_for(start, stop, step, iter_args, loc=loc, ip=ip)
93+
iv = for_op.induction_variable
94+
iter_args = tuple(for_op.inner_iter_args)
95+
with InsertionPoint(for_op.body):
96+
if len(iter_args) > 1:
97+
yield iv, iter_args, for_op.results
98+
elif len(iter_args) == 1:
99+
yield iv, iter_args[0], for_op.results[0]
100+
else:
101+
yield iv
102+
103+
104+
def placeholder_opaque_t():
105+
return opaque("scf", "placeholder")
73106

74107

75-
for_ = region_op(_for, terminator=yield__)
108+
for_ = region_op(_build_for, terminator=yield__)
76109

77110

78111
@_cext.register_operation(_Dialect, replace=True)

tests/test_scf.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,31 @@ def test_for_bare(ctx: MLIRContext):
201201
filecheck(correct, ctx.module)
202202

203203

204+
def test_mixed_start_stop_for(ctx: MLIRContext):
205+
ten = constant(10)
206+
207+
for i in range_(0, ten):
208+
three = constant(3.0)
209+
four = constant(4.0)
210+
yield_()
211+
212+
ctx.module.operation.verify()
213+
correct = dedent(
214+
"""\
215+
module {
216+
%c10_i32 = arith.constant 10 : i32
217+
%c0_i32 = arith.constant 0 : i32
218+
%c1_i32 = arith.constant 1 : i32
219+
scf.for %arg0 = %c0_i32 to %c10_i32 step %c1_i32 : i32 {
220+
%cst = arith.constant 3.000000e+00 : f32
221+
%cst_0 = arith.constant 4.000000e+00 : f32
222+
}
223+
}
224+
"""
225+
)
226+
filecheck(correct, ctx.module)
227+
228+
204229
def test_scf_canonicalizer_with_implicit_yield(ctx: MLIRContext):
205230
@canonicalize(using=canonicalizer)
206231
def foo():

0 commit comments

Comments
 (0)