Merged
27 commits
877119f
Register clone_dim_order op; add test for op replacement
keyprocedure Jul 29, 2025
f75845d
Rename clone_dim_order op registration test
keyprocedure Jul 29, 2025
83d8c75
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Jul 29, 2025
cff39c9
Add graph level and end to end tests for _clone_dim_order op
keyprocedure Aug 3, 2025
1fe461f
Remove _clone_dim_order op registration (moved to PR #12974)
keyprocedure Aug 3, 2025
95db027
Register _clone_dim_order op
keyprocedure Aug 6, 2025
63d45e7
Merge branch 'main' into add-dim-order-clone-aot
keyprocedure Aug 11, 2025
d9a181c
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Aug 12, 2025
c7caa27
Register _clone_dim_order as no-op in CoreML
keyprocedure Aug 13, 2025
e54605f
Remove redundant _clone_dim_order graph check
keyprocedure Aug 13, 2025
246bc44
Add _clone_dim_order to RemoveClonePass and update op name in tests
keyprocedure Aug 13, 2025
c48467c
Register _clone_dim_order under TOSA support check
keyprocedure Aug 13, 2025
e262a36
Merge branch 'main' into add-dim-order-clone-aot
digantdesai Aug 14, 2025
5546360
Add clone_dim_order_support to TOSA operator support list
keyprocedure Aug 16, 2025
5c5e65a
Register node visitor for _clone_dim_order
keyprocedure Aug 16, 2025
fe7dd11
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Aug 19, 2025
7a0bc6a
Remove visitor node registration for _clone_dim_order
keyprocedure Aug 25, 2025
74e2cce
Remove aten.clone check from RemoveClonePass
keyprocedure Aug 25, 2025
f9f9515
Remove input dtype gating and add memory_format check
keyprocedure Aug 25, 2025
8d0cb06
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Aug 25, 2025
6839212
Add Core ML test for _clone_dim_order
keyprocedure Aug 25, 2025
e8ceb5d
Merge branch 'main' into add-dim-order-clone-aot
keyprocedure Sep 3, 2025
b76cf7b
Refactor clone_dim_order_support to existing clone_support file
keyprocedure Sep 3, 2025
5bfd58b
Replace edge.aten.clone with edge.dim_order_ops._clone_dim_order
keyprocedure Sep 3, 2025
f8ab347
Merge branch 'main' into add-dim-order-clone-aot
keyprocedure Sep 4, 2025
4a2383b
Fix formatting
keyprocedure Sep 4, 2025
30a1d13
Merge branch 'main' into add-dim-order-clone-aot
Gasoonjia Sep 4, 2025
23 changes: 23 additions & 0 deletions backends/apple/coreml/compiler/torch_ops.py
@@ -15,6 +15,7 @@
from coremltools.converters.mil.frontend.torch.ops import (
_get_inputs,
_get_kwinputs,
noop,
NUM_TO_NUMPY_DTYPE,
NUM_TO_TORCH_DTYPE,
split,
@@ -91,6 +92,28 @@ def _to_dim_order_copy(context, node):
to(context, node)


@register_torch_op(

torch_alias=[
"dim_order_ops::_clone_dim_order",
"dim_order_ops._clone_dim_order",
],
override=False,
)
def _clone_dim_order(context, node):
dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
node.kwinputs.pop("dim_order")

# In CoreML, dim_order.val will be an ndarray, so convert it to a list to check the memory format.
dim_order = [int(d) for d in dim_order.val]
memory_format = get_memory_format(dim_order)
assert (
memory_format == _torch.contiguous_format
), "Only contiguous memory format is supported in CoreML"

# Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone.
noop(context, node)


# https://github.com/apple/coremltools/pull/2558
@register_torch_op(
torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
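Note: the converter above hinges on mapping a dim_order permutation back to a torch memory format before deciding to treat the op as a no-op clone. A minimal sketch of that mapping, for illustration only (the helper name dim_order_to_memory_format is hypothetical; the PR itself relies on the existing get_memory_format utility):

import torch

# Hypothetical helper mirroring the two 4-D cases that matter here:
# the identity permutation is contiguous, and [0, 2, 3, 1] is channels_last.
def dim_order_to_memory_format(dim_order):
    if dim_order == list(range(len(dim_order))):  # e.g. [0, 1, 2, 3]
        return torch.contiguous_format
    if dim_order == [0, 2, 3, 1]:  # NCHW tensor stored in NHWC order
        return torch.channels_last
    raise ValueError(f"Unrecognized dim_order: {dim_order}")

# The CoreML converter only accepts the contiguous case and then lowers
# _clone_dim_order as a plain no-op clone.
assert dim_order_to_memory_format([0, 1, 2, 3]) == torch.contiguous_format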
23 changes: 23 additions & 0 deletions backends/apple/coreml/test/test_torch_ops.py
@@ -268,6 +268,28 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)

def test__clone_dim_order_contiguous(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.ops.dim_order_ops._clone_dim_order(
x, dim_order=[0, 1, 2, 3]
)

model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),)
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
ep,
partitioner=[self._coreml_partitioner()],
)
for node in delegated_program.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"executorch_call_delegate",
"getitem",
], f"Got unexpected node target after delegation: {node.target.__name__}"
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)


if __name__ == "__main__":
test_runner = TestTorchOps()
@@ -280,3 +302,4 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
test_runner.test_dequantize_codebook_linear_per_grouped_row()
test_runner.test_dequantize_codebook_embedding_per_grouped_col()
test_runner.test_dequantize_codebook_embedding_per_grouped_row()
test_runner.test__clone_dim_order_contiguous()
2 changes: 1 addition & 1 deletion backends/arm/_passes/remove_clone_pass.py
@@ -18,7 +18,7 @@ class RemoveClonePass(ExportPass):
"""Remove all clones from graph_module"""

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.clone.default:
if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
return super().call_operator(op, args, kwargs, meta)

if len(args) != 1:
55 changes: 54 additions & 1 deletion backends/arm/operator_support/clone_support.py
@@ -5,6 +5,7 @@

import logging

import torch
import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
@@ -18,7 +19,7 @@

@register_tosa_support_check
class CloneSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.clone.default]
targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
@@ -28,10 +29,62 @@ class CloneSupported(SupportedTOSAOperatorCheck):
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
if node.target not in self.targets:
self.reporter.report_reject(node, f"Target {node.target} is not supported.")
return False

input_node = node.args[0]
if not isinstance(input_node, fx.Node):
self.reporter.report_reject(node, "Non-tensor clones are not supported")
return False

# Check input node
if len(node.all_input_nodes) != 1:
self.reporter.report_reject(
node, f"Expected 1 input node, got {len(node.all_input_nodes)}"
)
return False

input_val = node.all_input_nodes[0].meta["val"]
if not isinstance(input_val, torch._subclasses.FakeTensor):
self.reporter.report_reject(node, "Expected input to be a FakeTensor.")
return False

input_dtype = input_val.dtype

# Check output node
output_val = node.meta["val"]
if not isinstance(output_val, torch._subclasses.FakeTensor):
self.reporter.report_reject(node, "Expected output to be a FakeTensor.")
return False

if output_val.dtype != input_dtype:
self.reporter.report_reject(
node,
f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
)
return False

# Check memory format
if "memory_format" in node.kwargs:
if node.kwargs["memory_format"] in (torch.preserve_format,):
self.reporter.report_reject(
node,
f"Argument 'memory_format' is not supported for "
f"{node.target} right now.",
)
return False

# Check dim_order
if "dim_order" in node.kwargs:
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
self.reporter.report_reject(
node,
f"Argument {dim_order=} is not supported for "
f"{node.target} right now.",
)
return False

return True
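For context, the memory_format and dim_order gating above determines which clones stay eligible for delegation. A small sketch (an illustration, not code from the PR; passes_kwarg_checks is a hypothetical mirror of the checks in is_node_tosa_supported):

import torch

# Hypothetical mirror of the memory_format / dim_order gating above,
# applied directly to a node's kwargs dict.
def passes_kwarg_checks(kwargs):
    if kwargs.get("memory_format") == torch.preserve_format:
        return False
    dim_order = kwargs.get("dim_order")
    if dim_order is not None and list(dim_order) != list(range(len(dim_order))):
        return False
    return True

assert passes_kwarg_checks({"dim_order": [0, 1, 2, 3]})      # identity dim_order: supported
assert not passes_kwarg_checks({"dim_order": [0, 2, 3, 1]})  # channels_last dim_order: rejected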
@@ -38,7 +38,7 @@
]
linear_residual_exir_op: list[str] = [
"executorch_exir_dialects_edge__ops_aten_gelu_default",
"executorch_exir_dialects_edge__ops_aten_clone_default",
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
"executorch_exir_dialects_edge__ops_aten_linear_default",
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
]
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_clone.py
@@ -19,7 +19,7 @@
)

aten_op = "torch.ops.aten.clone.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default"
exir_op = "executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"

input_t = Tuple[torch.Tensor]

6 changes: 4 additions & 2 deletions backends/arm/test/passes/test_remove_clone_pass.py
@@ -35,9 +35,11 @@ def test_remove_clone_tosa_INT():
module.get_inputs(),
quantize=True,
ops_before_pass={
"executorch_exir_dialects_edge__ops_aten_clone_default": 1,
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default": 1,
},
ops_not_after_pass=["executorch_exir_dialects_edge__ops_aten_clone_default"],
ops_not_after_pass=[
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default"
],
pass_list=[RemoveClonePass],
)
pipeline.run()
2 changes: 1 addition & 1 deletion backends/qualcomm/_passes/convert_bmm_to_matmul.py
@@ -21,7 +21,7 @@ class ConvertBmmToMatmul(ExportPass):

view_copy = exir_ops.edge.aten.view_copy.default
expand_copy = exir_ops.edge.aten.expand_copy.default
clone = exir_ops.edge.aten.clone.default
clone = exir_ops.edge.dim_order_ops._clone_dim_order.default
bmm = exir_ops.edge.aten.bmm.default
matmul = exir_ops.edge.aten.matmul.default
patterns = [
2 changes: 1 addition & 1 deletion backends/qualcomm/_passes/remove_redundancy.py
@@ -19,7 +19,7 @@ def __init__(self, quantization_capture=False):
self.redundant_ops_general = {
torch.clone: self._default_condition,
torch.ops.aten.clone.default: self._default_condition,
exir_ops.edge.aten.clone.default: self._default_condition,
exir_ops.edge.dim_order_ops._clone_dim_order.default: self._default_condition,
torch.ops.aten.alias.default: self._default_condition,
exir_ops.edge.aten.alias.default: self._default_condition,
exir_ops.edge.aten.alias_copy.default: self._default_condition,
2 changes: 1 addition & 1 deletion backends/qualcomm/partition/common_defs.py
@@ -10,7 +10,7 @@
from executorch.exir.dialects._ops import ops as exir_ops

not_supported_operator = [
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
]

2 changes: 1 addition & 1 deletion backends/qualcomm/tests/utils.py
@@ -660,7 +660,7 @@ def _insert_clone(
users = list(node.users.keys())
inserted_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten.clone.default,
exir_ops.edge.dim_order_ops._clone_dim_order.default,
(node,),
)
inserted_node.meta["val"] = node.meta["val"]
19 changes: 19 additions & 0 deletions exir/passes/dim_order_ops_registry.py
@@ -28,6 +28,14 @@
"_empty_dim_order.out(int[] size, *, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
"_clone_dim_order(Tensor self, *, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
)

lib.define(
"_clone_dim_order.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)


def _op_impl(target, *args, **kwargs):
kwargs["memory_format"] = get_memory_format(kwargs.get("dim_order", None))
@@ -57,12 +65,23 @@ def _empty_dim_order_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.empty.out, *args, **kwargs)


@impl(lib, "_clone_dim_order", "CompositeImplicitAutograd")
def _clone_dim_order_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.clone.default, *args, **kwargs)


@impl(lib, "_clone_dim_order.out", "CompositeImplicitAutograd")
def _clone_dim_order_out_impl(*args, **kwargs):
return _op_impl(torch.ops.aten.clone.out, *args, **kwargs)


"""
Defines a map of edge ops to the corresponding dim_order ops for quick lookup
"""
DimOrderOpsMap = {
exir_ops.edge.aten._to_copy.default: exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
exir_ops.edge.aten.empty.memory_format: exir_ops.edge.dim_order_ops._empty_dim_order.default,
exir_ops.edge.aten.clone.default: exir_ops.edge.dim_order_ops._clone_dim_order.default,
}

"""
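With these registrations, _clone_dim_order should behave in eager mode like aten.clone with the memory format recovered from dim_order. A minimal sketch of that equivalence (an assumption-level example; it presumes that importing the registry module is enough to register the op library outside the export flow):

import torch
import executorch.exir.passes.dim_order_ops_registry  # noqa: F401  registers dim_order_ops

x = torch.randn(2, 3, 4, 5)

# dim_order [0, 2, 3, 1] corresponds to channels_last for a 4-D tensor,
# so the new op should match a clone with that memory format.
y = torch.ops.dim_order_ops._clone_dim_order(x, dim_order=[0, 2, 3, 1])
ref = x.clone(memory_format=torch.channels_last)

assert torch.equal(y, ref)
assert y.is_contiguous(memory_format=torch.channels_last)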
52 changes: 52 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
@@ -27,7 +27,10 @@
AmbiguousDimOrderError,
MemoryFormatOpsPassTestUtils,
MemoryFormatTestSet,
PropagateToCloneChannelsLastModule,
PropagateToCopyChannalsLastModule,
SimpleCloneChannelsLastModule,
SimpleCloneContiguousModule,
SimpleEmptyChannelLastModule,
SimpleEmptyContiguoustModule,
SimpleToCopyChannelsLastModule,
@@ -91,6 +94,36 @@ def test_op_empty_replacement_contiguous(self) -> None:
),
)

def test_op_clone_replacement_contiguous(self) -> None:
model = SimpleCloneContiguousModule()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten.clone.default,
sample_input=(
torch.randn((3, 4, 5, 6)).to(memory_format=torch.channels_last),
),
target_memory_format=torch.contiguous_format,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_clone_replacement_channels_last(self) -> None:
model = SimpleCloneChannelsLastModule()
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=model.eval(),
op=torch.ops.aten.clone.default,
sample_input=(
torch.randn((3, 4, 5, 6)).to(memory_format=torch.contiguous_format),
),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

def test_op_dim_order_update(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
@@ -128,6 +161,25 @@ def test_op_dim_order_propagation(self) -> None:
check_unambiguous_dim_order=True,
)

def test_op_clone_dim_order_propagation(self) -> None:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=PropagateToCloneChannelsLastModule().eval(),
op=torch.ops.aten.clone.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 2, 2, 2]),
dtype=torch.float32,
memory_format=torch.contiguous_format,
),
),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
check_unambiguous_dim_order=True,
)

def test_op_dim_order_propagation_ambiguous(self) -> None:
try:
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
30 changes: 30 additions & 0 deletions exir/tests/test_memory_format_ops_pass_utils.py
@@ -38,6 +38,10 @@
"torch.ops.aten.empty.memory_format",
"executorch_exir_dialects_edge__ops_dim_order_ops__empty_dim_order_default",
),
torch.ops.aten.clone.default: (
"torch.ops.aten.clone.default",
"executorch_exir_dialects_edge__ops_dim_order_ops__clone_dim_order_default",
),
}


@@ -70,6 +74,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(dtype=torch.double, memory_format=torch.channels_last)


class SimpleCloneContiguousModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.clone(memory_format=torch.contiguous_format)


class SimpleCloneChannelsLastModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.clone(memory_format=torch.channels_last)


class SimpleEmptyContiguoustModule(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -102,6 +122,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return t1 * t2


class PropagateToCloneChannelsLastModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
t1 = x.clone(memory_format=torch.channels_last)
t2 = t1 + t1
return t1 * t2


class AmbiguousDimOrderError(RuntimeError):
pass
