diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py
index 81306c9a2fd..e8969c6a7bd 100644
--- a/backends/apple/coreml/compiler/torch_ops.py
+++ b/backends/apple/coreml/compiler/torch_ops.py
@@ -15,6 +15,7 @@
 from coremltools.converters.mil.frontend.torch.ops import (
     _get_inputs,
     _get_kwinputs,
+    noop,
     NUM_TO_NUMPY_DTYPE,
     NUM_TO_TORCH_DTYPE,
     split,
@@ -67,6 +68,28 @@ def _to_dim_order_copy(context, node):
     to(context, node)
 
 
+@register_torch_op(
+    torch_alias=[
+        "dim_order_ops::_clone_dim_order",
+        "dim_order_ops._clone_dim_order",
+    ],
+    override=False,
+)
+def _clone_dim_order(context, node):
+    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
+    node.kwinputs.pop("dim_order")
+
+    # In CoreML, dim_order.val will be a ndarray, so we convert it to a list to check memory format.
+    dim_order = [int(d) for d in dim_order.val]
+    memory_format = get_memory_format(dim_order)
+    assert (
+        memory_format == _torch.contiguous_format
+    ), "Only contiguous memory format is supported in CoreML"
+
+    # Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone.
+    noop(context, node)
+
+
 # https://github.com/apple/coremltools/pull/2558
 @register_torch_op(
     torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py
index 0d6b581ee72..25691777aec 100644
--- a/backends/apple/coreml/test/test_torch_ops.py
+++ b/backends/apple/coreml/test/test_torch_ops.py
@@ -221,6 +221,28 @@ def test_dequantize_codebook_embedding(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
 
+    def test__clone_dim_order_contiguous(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.dim_order_ops._clone_dim_order(
+                    x, dim_order=[0, 1, 2, 3]
+                )
+
+        model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),)
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
 
 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -231,3 +253,4 @@ def test_dequantize_codebook_embedding(self):
     test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
     test_runner.test_dequantize_codebook_linear()
     test_runner.test_dequantize_codebook_embedding()
+    test_runner.test__clone_dim_order_contiguous()
diff --git a/backends/arm/_passes/remove_clone_pass.py b/backends/arm/_passes/remove_clone_pass.py
index a2822c7378e..896d3f54673 100644
--- a/backends/arm/_passes/remove_clone_pass.py
+++ b/backends/arm/_passes/remove_clone_pass.py
@@ -14,7 +14,7 @@ class RemoveClonePass(ExportPass):
     """Remove all clones from graph_module"""
 
     def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.clone.default:
+        if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
             return super().call_operator(op, args, kwargs, meta)
 
         if len(args) != 1:
diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py
index 2075e0f554f..5557a2116c6 100644
--- a/backends/arm/operator_support/__init__.py
+++ b/backends/arm/operator_support/__init__.py
@@ -6,6 +6,7 @@
 # pyre-unsafe
 
 from . import (  # noqa
+    clone_dim_order_support,
     convolution_support,
     embedding_support,
     ethos_u55_support,
diff --git a/backends/arm/operator_support/clone_dim_order_support.py b/backends/arm/operator_support/clone_dim_order_support.py
new file mode 100644
index 00000000000..7269f7e7932
--- /dev/null
+++ b/backends/arm/operator_support/clone_dim_order_support.py
@@ -0,0 +1,76 @@
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+
+import torch
+import torch.fx as fx
+
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+
+
+@register_tosa_support_check
+class CloneDimOrderSupport(SupportedTOSAOperatorCheck):
+    targets = [
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
+    ]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:
+        assert node.target in self.targets
+
+        # Check input type
+        assert len(node.all_input_nodes) == 1
+        input_val = node.all_input_nodes[0].meta["val"]
+        assert isinstance(input_val, torch._subclasses.FakeTensor)
+        input_dtype = input_val.dtype
+
+        # Check output type
+        output_val = node.meta["val"]
+        assert isinstance(output_val, torch._subclasses.FakeTensor)
+        if output_val.dtype != input_dtype:
+            self.reporter.report_reject(
+                node,
+                f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
+            )
+            return False
+
+        # Check memory format
+        if "memory_format" in node.kwargs:
+            if node.kwargs["memory_format"] in (torch.preserve_format,):
+                self.reporter.report_reject(
+                    node,
+                    f"Argument 'memory_format' is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
+        # Check dim_order
+        if "dim_order" in node.kwargs:
+            dim_order = node.kwargs["dim_order"]
+            # pyre-ignore[6]
+            if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
+                self.reporter.report_reject(
+                    node,
+                    f"Argument {dim_order=} is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
+        return True
diff --git a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
index 1aaa2950337..04ecd57e7b1 100644
--- a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
+++ b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
@@ -38,7 +38,7 @@
 ]
 linear_residual_exir_op: list[str] = [
     "executorch_exir_dialects_edge__ops_aten_gelu_default",
-    "executorch_exir_dialects_edge__ops_aten_clone_default",
+    "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
     "executorch_exir_dialects_edge__ops_aten_linear_default",
     "executorch_exir_dialects_edge__ops_aten_add_Tensor",
 ]
diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py
index 7a24848697e..5c5f5e9979a 100644
--- a/backends/arm/test/ops/test_clone.py
+++ b/backends/arm/test/ops/test_clone.py
@@ -23,7 +23,7 @@
 )
 
 aten_op = "torch.ops.aten.clone.default"
-exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default"
+exir_op = "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
 
 input_t = Tuple[torch.Tensor]
diff --git a/backends/arm/test/passes/test_remove_clone_pass.py b/backends/arm/test/passes/test_remove_clone_pass.py
index dea0bb06f5e..5c2171795f7 100755
--- a/backends/arm/test/passes/test_remove_clone_pass.py
+++ b/backends/arm/test/passes/test_remove_clone_pass.py
@@ -35,9 +35,11 @@ def test_remove_clone_tosa_INT():
         module.get_inputs(),
         quantize=True,
         ops_before_pass={
-            "executorch_exir_dialects_edge__ops_aten_clone_default": 1,
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
        },
-        ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_clone_default"],
+        ops_not_after_pass=[
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+        ],
         pass_list=[RemoveClonePass],
     )
     pipeline.run()
diff --git a/exir/passes/dim_order_ops_registry.py b/exir/passes/dim_order_ops_registry.py
index f3fc009f109..7a5dff387c1 100644
--- a/exir/passes/dim_order_ops_registry.py
+++ b/exir/passes/dim_order_ops_registry.py
@@ -28,6 +28,14 @@
     "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
 )
 
+lib.define(
+    "_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
+)
+
+lib.define(
+    "_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
+)
+
 
 def _op_impl(target, *args, **kwargs):
     kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -57,12 +65,23 @@ def _empty_dim_order_out_impl(*args, **kwargs):
     return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)
 
 
+@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd")
+def _clone_dim_order_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.clone.default, *args, **kwargs)
+
+
+@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd")
+def _clone_dim_order_out_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.clone.out, *args, **kwargs)
+
+
 """
 Defines a map of edge ops to the corresponding dim_order ops for quick lookup
 """
 DimOrderOpsMap = {
     exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default,
+    exir_ops.edge.aten.clone.default: exir_ops.edge.dim_order_ops._clone_dim_order.default,
 }
 
 """
diff --git a/exir/tests/test_memory_format_ops_pass.py b/exir/tests/test_memory_format_ops_pass.py
index 84cd0faa485..2384f6123a9 100644
--- a/exir/tests/test_memory_format_ops_pass.py
+++ b/exir/tests/test_memory_format_ops_pass.py
@@ -27,7 +27,10 @@
     AmbiguousDimOrderError,
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
+    PropagateToCloneChannelsLastModule,
     PropagateToCopyChannalsLastModule,
+    SimpleCloneChannelsLastModule,
+    SimpleCloneContiguousModule,
     SimpleEmptyChannelLastModule,
     SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
@@ -91,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None:
             ),
         )
 
+    def test_op_clone_replacement_contiguous(self) -> None:
+        model = SimpleCloneContiguousModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
+                ),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_clone_replacement_channels_last(self) -> None:
+        model = SimpleCloneChannelsLastModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
+                ),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
@@ -128,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None:
             check_unambiguous_dim_order=True,
         )
 
+    def test_op_clone_dim_order_propagation(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=PropagateToCloneChannelsLastModule().eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.rand_like(
+                        torch.zeros([2, 2, 2, 2]),
+                        dtype=torch.float32,
+                        memory_format=torch.contiguous_format,
+                    ),
+                ),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+            check_unambiguous_dim_order=True,
+        )
+
     def test_op_dim_order_propagation_ambiguous(self) -> None:
         try:
             MemoryFormatOpsPassTestUtils.memory_format_test_runner(
diff --git a/exir/tests/test_memory_format_ops_pass_utils.py b/exir/tests/test_memory_format_ops_pass_utils.py
index 6daf38b187f..f5a786c6f74 100644
--- a/exir/tests/test_memory_format_ops_pass_utils.py
+++ b/exir/tests/test_memory_format_ops_pass_utils.py
@@ -38,6 +38,10 @@
         "torch.ops.aten.empty.memory_format",
         "executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
     ),
+    torch.ops.aten.clone.default: (
+        "torch.ops.aten.clone.default",
+        "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
+    ),
 }
 
 
@@ -70,6 +74,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.to(dtype=torch.double, memory_format=torch.channels_last)
 
 
+class SimpleCloneContiguousModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.clone(memory_format=torch.contiguous_format)
+
+
+class SimpleCloneChannelsLastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.clone(memory_format=torch.channels_last)
+
+
 class SimpleEmptyContiguoustModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -102,6 +122,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return t1 * t2
 
 
+class PropagateToCloneChannelsLastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        t1 = x.clone(memory_format=torch.channels_last)
+        t2 = t1 + t1
+        return t1 * t2
+
+
 class AmbiguousDimOrderError(RuntimeError):
     pass
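
For reference, a minimal usage sketch (illustrative only, not part of the patch) of the behavior the new tests above exercise. It mirrors SimpleCloneChannelsLastModule and assumes dim-order lowering is enabled in the edge compile config via EdgeCompileConfig(_skip_dim_order=False); after to_edge, a plain aten.clone captured by torch.export is expected to show up as the dim_order_ops._clone_dim_order edge op.

# Illustrative sketch; assumes an ExecuTorch install with the dim-order pass enabled.
import torch
from executorch.exir import EdgeCompileConfig, to_edge


class CloneChannelsLast(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone(memory_format=torch.channels_last)


ep = torch.export.export(CloneChannelsLast().eval(), (torch.randn(1, 3, 8, 8),))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# The edge graph is expected to contain the dim-order variant of clone, i.e.
# executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default,
# rather than aten.clone.
print(edge.exported_program().graph_module.code)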