Skip to content

Commit 9379d82

Browse files
committed
fixup! fixup! fixup! [MLIR][Python] Add structured.fuseop to generator.
1 parent 9e10ee1 commit 9379d82

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,25 +149,28 @@ def __init__(
149149
self,
150150
target: Union[Operation, Value, OpView],
151151
*,
152-
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
153-
interchange: OptionalIntList = None,
152+
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
153+
tile_interchange: OptionalIntList = None,
154154
loc=None,
155155
ip=None,
156156
):
157157
...
158158

159159
def __init__(
160160
self,
161-
loop_types_or_target: Union[Type, List[Type], Operation, Value],
161+
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
162162
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
163163
*,
164-
sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
165-
interchange: OptionalIntList = None,
164+
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
165+
tile_interchange: OptionalIntList = None,
166166
loc=None,
167167
ip=None,
168168
):
169-
sizes = sizes if sizes else []
170-
num_loops = sum(v if v == 0 else 1 for v in sizes)
169+
tile_sizes = tile_sizes if tile_sizes else []
170+
tile_interchange = tile_interchange if tile_interchange else []
171+
_, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
172+
_, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
173+
num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)
171174

172175
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
173176
loop_types = [transform.AnyOpType.get()] * num_loops
@@ -184,8 +187,8 @@ def __init__(
184187
target.type,
185188
loop_types,
186189
target,
187-
tile_sizes=sizes,
188-
tile_interchange=interchange,
190+
tile_sizes=tile_sizes,
191+
tile_interchange=tile_interchange,
189192
loc=loc,
190193
ip=ip,
191194
)

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def testFuseIntoContainingOpCompact(target):
104104
@run
105105
@create_sequence
106106
def testFuseOpCompact(target):
107-
structured.FuseOp(target, sizes=[4, 8], interchange=[0, 1])
107+
structured.FuseOp(target, tile_sizes=[4, 8], tile_interchange=[0, 1])
108108
# CHECK-LABEL: TEST: testFuseOpCompact
109109
# CHECK: transform.sequence
110110
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
@@ -122,6 +122,19 @@ def testFuseOpNoArg(target):
122122
# CHECK-SAME: (!transform.any_op) -> !transform.any_op
123123

124124

125+
@run
126+
@create_sequence
127+
def testFuseOpAttributes(target):
128+
attr = DenseI64ArrayAttr.get([4, 8])
129+
ichange = DenseI64ArrayAttr.get([0, 1])
130+
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
131+
# CHECK-LABEL: TEST: testFuseOpAttributes
132+
# CHECK: transform.sequence
133+
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
134+
# CHECK-SAME: interchange [0, 1]
135+
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
136+
137+
125138
@run
126139
@create_sequence
127140
def testGeneralize(target):

0 commit comments

Comments
 (0)