
Commit 84de45c

Merge remote-tracking branch 'origin/main' into jni-layer-llama-1

2 parents 08a9f96 + 686bb71

File tree: 152 files changed, +2006 -595 lines


backends/apple/coreml/compiler/torch_ops.py

Lines changed: 23 additions & 0 deletions

@@ -15,6 +15,7 @@
 from coremltools.converters.mil.frontend.torch.ops import (
     _get_inputs,
     _get_kwinputs,
+    noop,
     NUM_TO_NUMPY_DTYPE,
     NUM_TO_TORCH_DTYPE,
     split,
@@ -91,6 +92,28 @@ def _to_dim_order_copy(context, node):
     to(context, node)


+@register_torch_op(
+    torch_alias=[
+        "dim_order_ops::_clone_dim_order",
+        "dim_order_ops._clone_dim_order",
+    ],
+    override=False,
+)
+def _clone_dim_order(context, node):
+    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
+    node.kwinputs.pop("dim_order")
+
+    # In CoreML, dim_order.val will be a ndarray, so we convert it to a list to check memory format.
+    dim_order = [int(d) for d in dim_order.val]
+    memory_format = get_memory_format(dim_order)
+    assert (
+        memory_format == _torch.contiguous_format
+    ), "Only contiguous memory format is supported in CoreML"
+
+    # Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone.
+    noop(context, node)
+
+
 # https://github.com/apple/coremltools/pull/2558
 @register_torch_op(
     torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 23 additions & 0 deletions

@@ -268,6 +268,28 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)

+    def test__clone_dim_order_contiguous(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.dim_order_ops._clone_dim_order(
+                    x, dim_order=[0, 1, 2, 3]
+                )
+
+        model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),)
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+

 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -280,3 +302,4 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
     test_runner.test_dequantize_codebook_linear_per_grouped_row()
     test_runner.test_dequantize_codebook_embedding_per_grouped_col()
     test_runner.test_dequantize_codebook_embedding_per_grouped_row()
+    test_runner.test__clone_dim_order_contiguous()

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 1 deletion

@@ -91,7 +91,8 @@
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
-from executorch.backends.arm.tosa_specification import (
+
+from executorch.backends.arm.tosa.specification import (
     TosaLoweringContext,
     TosaSpecification,
 )

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 5 additions & 1 deletion

@@ -9,7 +9,6 @@
 import torch
 from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
-from executorch.backends.arm.operator_support.pool_2d_support import AvgPool2dSupported
 from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -67,6 +66,11 @@ def __init__(self, graph_module, tosa_spec):
         super().__init__()
         self._graph_module = graph_module
         self._tosa_spec = tosa_spec
+        # Lazy import to avoid circular dependency with operator_support
+        from executorch.backends.arm.operator_support.pool_2d_support import (
+            AvgPool2dSupported,
+        )
+
         self._avg_pool_checker = AvgPool2dSupported(
             self._tosa_spec, WhyNoPartitionReporter()
         )

backends/arm/_passes/remove_clone_pass.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ class RemoveClonePass(ExportPass):
     """Remove all clones from graph_module"""

     def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.clone.default:
+        if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
             return super().call_operator(op, args, kwargs, meta)

         if len(args) != 1:

backends/arm/arm_backend.py

Lines changed: 1 addition & 3 deletions

@@ -13,9 +13,7 @@
 from enum import Enum
 from typing import List, Optional

-from executorch.backends.arm.tosa_specification import (  # type: ignore[import-not-found]
-    TosaSpecification,
-)
+from executorch.backends.arm.tosa import TosaSpecification

 from executorch.exir.backend.compile_spec_schema import (  # type: ignore[import-not-found]
     CompileSpec,

backends/arm/ethosu/backend.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@

 from executorch.backends.arm.arm_vela import vela_compile

-from executorch.backends.arm.tosa_backend import TOSABackend
+from executorch.backends.arm.tosa.backend import TOSABackend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export.exported_program import ExportedProgram

backends/arm/ethosu/partitioner.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
     is_ethosu,
 )  # usort: skip
 from executorch.backends.arm.ethosu import EthosUBackend
-from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
+from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import DelegationSpec
 from torch.fx.passes.operator_support import OperatorSupportBase

backends/arm/operator_support/clone_support.py

Lines changed: 55 additions & 2 deletions

@@ -5,20 +5,21 @@

 import logging

+import torch
 import torch.fx as fx
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
 )
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa import TosaSpecification
 from executorch.exir.dialects._ops import ops as exir_ops

 logger = logging.getLogger(__name__)


 @register_tosa_support_check
 class CloneSupported(SupportedTOSAOperatorCheck):
-    targets = [exir_ops.edge.aten.clone.default]
+    targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default]

     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
@@ -28,10 +29,62 @@ class CloneSupported(SupportedTOSAOperatorCheck):
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
+        if node.target not in self.targets:
+            self.reporter.report_reject(node, f"Target {node.target} is not supported.")
+            return False

         input_node = node.args[0]
         if not isinstance(input_node, fx.Node):
             self.reporter.report_reject(node, "Non tensor clones are not supported")
             return False

+        # Check input node
+        if len(node.all_input_nodes) != 1:
+            self.reporter.report_reject(
+                node, f"Expected 1 input node, got {len(node.all_input_nodes)}"
+            )
+            return False
+
+        input_val = node.all_input_nodes[0].meta["val"]
+        if not isinstance(input_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(node, "Expected input to be a FakeTensor.")
+            return False
+
+        input_dtype = input_val.dtype
+
+        # Check output node
+        output_val = node.meta["val"]
+        if not isinstance(output_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(node, "Expected output to be a FakeTensor.")
+            return False
+
+        if output_val.dtype != input_dtype:
+            self.reporter.report_reject(
+                node,
+                f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
+            )
+            return False
+
+        # Check memory format
+        if "memory_format" in node.kwargs:
+            if node.kwargs["memory_format"] in (torch.preserve_format,):
+                self.reporter.report_reject(
+                    node,
+                    f"Argument 'memory_format' is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
+        # Check dim_order
+        if "dim_order" in node.kwargs:
+            dim_order = node.kwargs["dim_order"]
+            # pyre-ignore[6]
+            if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
+                self.reporter.report_reject(
+                    node,
+                    f"Argument {dim_order=} is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
         return True
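For illustration, the dim_order gate above only accepts the identity permutation; a minimal standalone sketch of that rule (the helper name is hypothetical, not part of this commit):

def is_identity_dim_order(dim_order):
    # Mirrors the check in CloneSupported: only dim_order == [0, 1, ..., n-1] is partitioned.
    return list(dim_order) == list(range(len(dim_order)))

assert is_identity_dim_order([0, 1, 2, 3])       # contiguous NCHW order -> supported
assert not is_identity_dim_order([0, 2, 3, 1])   # channels-last order -> rejected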

backends/arm/operator_support/convolution_support.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
 )
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa import TosaSpecification

 from executorch.exir.dialects._ops import ops as exir_ops
