
Commit 130cafc

keyprocedure authored, with co-authors Gasoonjia and digantdesai
[EXIR] Register _clone_dim_order op and map aten.clone (#12971)
### Summary

This is PR 2 of 3 implementing a dim order aware clone op. This PR registers the new `_clone_dim_order` op and maps `aten.clone` ops to `dim_order_ops._clone_dim_order` in EXIR during export to preserve memory layout changes (contiguous/channels_last). It also updates the Core ML and Arm backends to handle the new clone op.

Related PRs:
- PR 1: [#12974](#12974) - Add `_clone_dim_order` portable kernel
- PR 3: [#12976](#12976) - Update RemoveCloneOpsTransform to be dim order aware

Fixes #12645

### Test plan

- Operator level tests to verify kernel behavior for layout preservation and changes.
- Graph level checks to confirm that clone mapping occurs.
- End to end tests to validate that functional clone behavior is unchanged.

All tests pass via:
- `python -m unittest exir.tests.test_memory_format_ops_pass`
- `python -m unittest backends.apple.coreml.test.test_torch_ops`
- `pytest backends/arm/test/ops/test_clone.py`
- `pytest backends/arm/test/passes/test_remove_clone_pass.py`

Co-authored-by: Gasoonjia <[email protected]>
Co-authored-by: Digant Desai <[email protected]>
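For illustration, here is a minimal sketch (not code from this PR) of the graph-level check described in the test plan: exporting a module that clones into channels_last should yield a `dim_order_ops._clone_dim_order` node in the edge graph instead of `aten.clone`, assuming dim order ops are enabled in the edge compile config (the default). The module name, shapes, and string-based node check are made up for this example.

```python
# Hypothetical check of the aten.clone -> dim_order_ops._clone_dim_order mapping.
# Assumes dim order ops are enabled (the default EdgeCompileConfig behavior).
import torch
from executorch.exir import to_edge


class CloneToChannelsLast(torch.nn.Module):
    def forward(self, x):
        return x.clone(memory_format=torch.channels_last)


example_inputs = (torch.randn(1, 3, 8, 8),)
edge = to_edge(torch.export.export(CloneToChannelsLast().eval(), example_inputs))

# Collect call_function targets from the edge graph and verify the mapping.
op_names = [
    getattr(node.target, "__name__", str(node.target))
    for node in edge.exported_program().graph.nodes
    if node.op == "call_function"
]
assert any("_clone_dim_order" in name for name in op_names), op_names
assert not any("aten.clone" in name for name in op_names), op_names
```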
1 parent 18151e4 commit 130cafc

File tree: 11 files changed, +231 −5 lines

backends/apple/coreml/compiler/torch_ops.py (23 additions, 0 deletions)

@@ -15,6 +15,7 @@
 from coremltools.converters.mil.frontend.torch.ops import (
     _get_inputs,
     _get_kwinputs,
+    noop,
     NUM_TO_NUMPY_DTYPE,
     NUM_TO_TORCH_DTYPE,
     split,
@@ -91,6 +92,28 @@ def _to_dim_order_copy(context, node):
     to(context, node)


+@register_torch_op(
+    torch_alias=[
+        "dim_order_ops::_clone_dim_order",
+        "dim_order_ops._clone_dim_order",
+    ],
+    override=False,
+)
+def _clone_dim_order(context, node):
+    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
+    node.kwinputs.pop("dim_order")
+
+    # In CoreML, dim_order.val will be a ndarray, so we convert it to a list to check memory format.
+    dim_order = [int(d) for d in dim_order.val]
+    memory_format = get_memory_format(dim_order)
+    assert (
+        memory_format == _torch.contiguous_format
+    ), "Only contiguous memory format is supported in CoreML"
+
+    # Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone.
+    noop(context, node)
+
+
 # https://github.com/apple/coremltools/pull/2558
 @register_torch_op(
     torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],

backends/apple/coreml/test/test_torch_ops.py (23 additions, 0 deletions)

@@ -221,6 +221,28 @@ def test_dequantize_codebook_embedding(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

+    def test__clone_dim_order_contiguous(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.dim_order_ops._clone_dim_order(
+                    x, dim_order=[0, 1, 2, 3]
+                )
+
+        model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),)
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+

 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -231,3 +253,4 @@ def test_dequantize_codebook_embedding(self):
     test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
     test_runner.test_dequantize_codebook_linear()
     test_runner.test_dequantize_codebook_embedding()
+    test_runner.test__clone_dim_order_contiguous()

backends/arm/_passes/remove_clone_pass.py (1 addition, 1 deletion)

@@ -14,7 +14,7 @@ class RemoveClonePass(ExportPass):
     """Remove all clones from graph_module"""

     def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.clone.default:
+        if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
             return super().call_operator(op, args, kwargs, meta)

         if len(args) != 1:

backends/arm/operator_support/__init__.py (1 addition, 0 deletions)

@@ -6,6 +6,7 @@
 # pyre-unsafe

 from . import (  # noqa
+    clone_dim_order_support,
     convolution_support,
     embedding_support,
     ethos_u55_support,
backends/arm/operator_support/clone_dim_order_support.py (new file, 76 additions)

@@ -0,0 +1,76 @@
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+import logging
+
+import torch
+import torch.fx as fx
+
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+
+
+@register_tosa_support_check
+class CloneDimOrderSupport(SupportedTOSAOperatorCheck):
+    targets = [
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
+    ]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT"),
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
+    ]
+
+    def is_node_tosa_supported(
+        self, node: fx.Node, tosa_spec: TosaSpecification
+    ) -> bool:
+        assert node.target in self.targets
+
+        # Check input type
+        assert len(node.all_input_nodes) == 1
+        input_val = node.all_input_nodes[0].meta["val"]
+        assert isinstance(input_val, torch._subclasses.FakeTensor)
+        input_dtype = input_val.dtype
+
+        # Check output type
+        output_val = node.meta["val"]
+        assert isinstance(output_val, torch._subclasses.FakeTensor)
+        if output_val.dtype != input_dtype:
+            self.reporter.report_reject(
+                node,
+                f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
+            )
+            return False
+
+        # Check memory format
+        if "memory_format" in node.kwargs:
+            if node.kwargs["memory_format"] in (torch.preserve_format,):
+                self.reporter.report_reject(
+                    node,
+                    f"Argument 'memory_format' is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
+        # Check dim_order
+        if "dim_order" in node.kwargs:
+            dim_order = node.kwargs["dim_order"]
+            # pyre-ignore[6]
+            if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
+                self.reporter.report_reject(
+                    node,
+                    f"Argument {dim_order=} is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
+        return True

backends/arm/test/misc/test_partition_decomposed_quantized_ops.py (1 addition, 1 deletion)

@@ -38,7 +38,7 @@
 ]
 linear_residual_exir_op: list[str] = [
     "executorch_exir_dialects_edge__ops_aten_gelu_default",
-    "executorch_exir_dialects_edge__ops_aten_clone_default",
+    "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
     "executorch_exir_dialects_edge__ops_aten_linear_default",
     "executorch_exir_dialects_edge__ops_aten_add_Tensor",
 ]

backends/arm/test/ops/test_clone.py (1 addition, 1 deletion)

@@ -23,7 +23,7 @@
 )

 aten_op = "torch.ops.aten.clone.default"
-exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default"
+exir_op = "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"

 input_t = Tuple[torch.Tensor]
backends/arm/test/passes/test_remove_clone_pass.py (4 additions, 2 deletions)

@@ -35,9 +35,11 @@ def test_remove_clone_tosa_INT():
         module.get_inputs(),
         quantize=True,
         ops_before_pass={
-            "executorch_exir_dialects_edge__ops_aten_clone_default": 1,
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
         },
-        ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_clone_default"],
+        ops_not_after_pass=[
+            "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
+        ],
         pass_list=[RemoveClonePass],
     )
     pipeline.run()

exir/passes/dim_order_ops_registry.py (19 additions, 0 deletions)

@@ -28,6 +28,14 @@
     "_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
 )

+lib.define(
+    "_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
+)
+
+lib.define(
+    "_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
+)
+

 def _op_impl(target, *args, **kwargs):
     kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -57,12 +65,23 @@ def _empty_dim_order_out_impl(*args, **kwargs):
     return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)


+@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd")
+def _clone_dim_order_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.clone.default, *args, **kwargs)
+
+
+@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd")
+def _clone_dim_order_out_impl(*args, **kwargs):
+    return _op_impl(torch.ops.aten.clone.out, *args, **kwargs)
+
+
 """
 Defines a map of edge ops to the corresponding dim_order ops for quick lookup
 """
 DimOrderOpsMap = {
     exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default,
+    exir_ops.edge.aten.clone.default: exir_ops.edge.dim_order_ops._clone_dim_order.default,
 }

 """
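Because the new schemas are registered with `CompositeImplicitAutograd` implementations that fall back to `aten.clone`, the op can also be exercised eagerly. Below is a minimal sketch, assuming the registry module has been imported so the ops are registered and that `[0, 2, 3, 1]` is the channels_last dim order for 4D tensors; it is illustrative only, not part of this diff.

```python
# Illustrative eager call of the registered op (assumption: importing the
# registry module is enough to register dim_order_ops with torch.ops).
import torch
from executorch.exir.passes import dim_order_ops_registry  # noqa: F401

x = torch.randn(2, 3, 4, 5)  # contiguous input

# dim_order [0, 2, 3, 1] corresponds to torch.channels_last for a 4D tensor,
# so the CompositeImplicitAutograd impl should dispatch to aten.clone with
# that memory format.
y = torch.ops.dim_order_ops._clone_dim_order(x, dim_order=[0, 2, 3, 1])

assert torch.equal(x, y)  # same values as a plain clone
assert y.is_contiguous(memory_format=torch.channels_last)  # layout follows dim_order
```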

exir/tests/test_memory_format_ops_pass.py (52 additions, 0 deletions)

@@ -27,7 +27,10 @@
     AmbiguousDimOrderError,
     MemoryFormatOpsPassTestUtils,
     MemoryFormatTestSet,
+    PropagateToCloneChannelsLastModule,
     PropagateToCopyChannalsLastModule,
+    SimpleCloneChannelsLastModule,
+    SimpleCloneContiguousModule,
     SimpleEmptyChannelLastModule,
     SimpleEmptyContiguoustModule,
     SimpleToCopyChannelsLastModule,
@@ -91,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None:
             ),
         )

+    def test_op_clone_replacement_contiguous(self) -> None:
+        model = SimpleCloneContiguousModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
+                ),
+                target_memory_format=torch.contiguous_format,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
+    def test_op_clone_replacement_channels_last(self) -> None:
+        model = SimpleCloneChannelsLastModule()
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=model.eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
+                ),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+        )
+
     def test_op_dim_order_update(self) -> None:
         MemoryFormatOpsPassTestUtils.memory_format_test_runner(
             self,
@@ -128,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None:
             check_unambiguous_dim_order=True,
         )

+    def test_op_clone_dim_order_propagation(self) -> None:
+        MemoryFormatOpsPassTestUtils.memory_format_test_runner(
+            self,
+            MemoryFormatTestSet(
+                module=PropagateToCloneChannelsLastModule().eval(),
+                op=torch.ops.aten.clone.default,
+                sample_input=(
+                    torch.rand_like(
+                        torch.zeros([2, 2, 2, 2]),
+                        dtype=torch.float32,
+                        memory_format=torch.contiguous_format,
+                    ),
+                ),
+                target_memory_format=torch.channels_last,
+                _load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
+            ),
+            check_unambiguous_dim_order=True,
+        )
+
     def test_op_dim_order_propagation_ambiguous(self) -> None:
         try:
             MemoryFormatOpsPassTestUtils.memory_format_test_runner(
