diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index 9121aa8e40237..bf40cc532065d 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -140,6 +140,77 @@ def __init__( ) +@_ods_cext.register_operation(_Dialect, replace=True) +class FuseOp(FuseOp): + """Specialization for FuseOp class.""" + + @overload + def __init__( + self, + loop_types: Union[Type, Sequence[Type]], + target: Union[Operation, Value, OpView], + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + tile_sizes = tile_sizes if tile_sizes else [] + tile_interchange = tile_interchange if tile_interchange else [] + _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes) + _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange) + num_loops = sum(0 if v == 0 else 1 for v in tile_sizes) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct FuseOp with two targets." + else: + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + super().__init__( + target.type, + loop_types, + target, + tile_sizes=tile_sizes, + tile_interchange=tile_interchange, + apply_cleanup=apply_cleanup, + loc=loc, + ip=ip, + ) + + @_ods_cext.register_operation(_Dialect, replace=True) class GeneralizeOp(GeneralizeOp): """Specialization for GeneralizeOp class.""" diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index fb4c75b533792..8785d6d360074 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -101,6 +101,42 @@ def testFuseIntoContainingOpCompact(target): # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) +@run +@create_sequence +def testFuseOpCompact(target): + structured.FuseOp( + target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True + ) + # CHECK-LABEL: TEST: testFuseOpCompact + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK-SAME: interchange [0, 1] apply_cleanup = true + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpNoArg(target): + structured.FuseOp(target) + # CHECK-LABEL: TEST: testFuseOpNoArg + # CHECK: transform.sequence + # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} : + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +@run +@create_sequence +def testFuseOpAttributes(target): + attr = DenseI64ArrayAttr.get([4, 8]) + ichange = DenseI64ArrayAttr.get([0, 1]) + structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange) + # CHECK-LABEL: TEST: testFuseOpAttributes + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK-SAME: interchange [0, 1] + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + @run @create_sequence def testGeneralize(target):