Skip to content

Commit dd7406b

Browse files
committed
[MLIR][Python] Add structured.fuseop to generator.
1 parent 3eca15c commit dd7406b

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,58 @@ def __init__(
139139
ip=ip,
140140
)
141141

142+
@_ods_cext.register_operation(_Dialect, replace=True)
143+
class FuseOp(FuseOp):
144+
"""Specialization for FuseOp class."""
145+
146+
@overload
147+
def __init__(
148+
self,
149+
target: Union[Operation, Value, OpView],
150+
*,
151+
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
152+
interchange: OptionalIntList = None,
153+
loc=None,
154+
ip=None,
155+
):
156+
...
157+
158+
def __init__(
159+
self,
160+
loop_types_or_target: Union[Type, List[Type], Operation, Value],
161+
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
162+
*,
163+
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
164+
interchange: OptionalIntList = None,
165+
loc=None,
166+
ip=None,
167+
):
168+
sizes = sizes if sizes else []
169+
num_loops = sum(v if v == 0 else 1 for v in sizes)
170+
171+
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
172+
loop_types = [transform.AnyOpType.get()] * num_loops
173+
target = loop_types_or_target
174+
assert (
175+
target_or_none is None
176+
), "Cannot construct FuseOp with two targets."
177+
else:
178+
loop_types = (
179+
([loop_types_or_target] * num_loops)
180+
if isinstance(loop_types_or_target, Type)
181+
else loop_types_or_target
182+
)
183+
target = target_or_none
184+
super().__init__(
185+
target.type,
186+
loop_types,
187+
target,
188+
tile_sizes=sizes,
189+
tile_interchange=interchange,
190+
loc=loc,
191+
ip=ip,
192+
)
193+
142194

143195
@_ods_cext.register_operation(_Dialect, replace=True)
144196
class GeneralizeOp(GeneralizeOp):

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,27 @@ def testFuseIntoContainingOpCompact(target):
101101
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
102102

103103

104+
@run
105+
@create_sequence
106+
def testFuseOpCompact(target):
107+
structured.FuseOp(target, sizes=[4, 8], interchange=[0, 1])
108+
# CHECK-LABEL: TEST: testFuseOpCompact
109+
# CHECK: transform.sequence
110+
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
111+
# CHECK-SAME: interchange [0, 1]
112+
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
113+
114+
115+
@run
116+
@create_sequence
117+
def testFuseOpNoArg(target):
118+
structured.FuseOp(target)
119+
# CHECK-LABEL: TEST: testFuseOpNoArg
120+
# CHECK: transform.sequence
121+
# CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
122+
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
123+
124+
104125
@run
105126
@create_sequence
106127
def testGeneralize(target):

0 commit comments

Comments
 (0)