Skip to content
71 changes: 71 additions & 0 deletions mlir/python/mlir/dialects/transform/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,77 @@ def __init__(
)


@_ods_cext.register_operation(_Dialect, replace=True)
class FuseOp(FuseOp):
"""Specialization for FuseOp class."""

@overload
def __init__(
self,
loop_types: Union[Type, Sequence[Type]],
target: Union[Operation, Value, OpView],
*,
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
loc=None,
ip=None,
):
...

@overload
def __init__(
self,
target: Union[Operation, Value, OpView],
*,
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
loc=None,
ip=None,
):
...

def __init__(
self,
loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
target_or_none: Optional[Union[Operation, Value, OpView]] = None,
*,
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
tile_interchange: OptionalIntList = None,
apply_cleanup: Optional[bool] = False,
loc=None,
ip=None,
):
tile_sizes = tile_sizes if tile_sizes else []
tile_interchange = tile_interchange if tile_interchange else []
_, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
_, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)

if isinstance(loop_types_or_target, (Operation, Value, OpView)):
loop_types = [transform.AnyOpType.get()] * num_loops
target = loop_types_or_target
assert target_or_none is None, "Cannot construct FuseOp with two targets."
else:
loop_types = (
([loop_types_or_target] * num_loops)
if isinstance(loop_types_or_target, Type)
else loop_types_or_target
)
target = target_or_none
super().__init__(
target.type,
loop_types,
target,
tile_sizes=tile_sizes,
tile_interchange=tile_interchange,
apply_cleanup=apply_cleanup,
loc=loc,
ip=ip,
)


@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
"""Specialization for GeneralizeOp class."""
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/python/dialects/transform_structured_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,42 @@ def testFuseIntoContainingOpCompact(target):
# CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)


@run
@create_sequence
def testFuseOpCompact(target):
structured.FuseOp(
target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
# CHECK-SAME: interchange [0, 1] apply_cleanup = true
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)


@run
@create_sequence
def testFuseOpNoArg(target):
structured.FuseOp(target)
# CHECK-LABEL: TEST: testFuseOpNoArg
# CHECK: transform.sequence
# CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
# CHECK-SAME: (!transform.any_op) -> !transform.any_op


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: can we add a test where sizes are not constant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added support for arrayAttr.

@run
@create_sequence
def testFuseOpAttributes(target):
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
# CHECK-LABEL: TEST: testFuseOpAttributes
# CHECK: transform.sequence
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
# CHECK-SAME: interchange [0, 1]
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)


@run
@create_sequence
def testGeneralize(target):
Expand Down
Loading