Commit ad7aa5f
Update on "Arm backend: Add 16A8W support and test for mul operation"
Add 16A8W quantization support and test for the mul operation in ExecutorTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations. Changes: - Add INT16 dtype validation support in op_mul.py - Add test_mul_tensor_16a8w_tosa_INT test function - Enable test_mul.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/) cc digantdesai freddan80 per zingo oscarandersson8218 [ghstack-poisoned]
2 parents 70f236c + d029491 commit ad7aa5f
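Note: as a rough, illustrative sketch of what the 16A8W scheme means numerically for an elementwise mul. The helpers symmetric_qparams and fake_quant below are made up for illustration only; they are not the Arm backend's quantizer API.

import torch

def symmetric_qparams(x: torch.Tensor, num_bits: int) -> float:
    # Symmetric scale so x maps into [-2^(n-1), 2^(n-1)-1].
    qmax = 2 ** (num_bits - 1) - 1
    return float(x.abs().max()) / qmax

def fake_quant(x: torch.Tensor, scale: float, num_bits: int) -> torch.Tensor:
    # Quantize-dequantize: round onto the integer grid, then map back to float.
    qmax = 2 ** (num_bits - 1) - 1
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
    return q * scale

a = torch.randn(4, 8)  # activation operand: quantized to int16 (the "16A")
w = torch.randn(4, 8)  # weight/constant operand: quantized to int8 (the "8W")

a_q = fake_quant(a, symmetric_qparams(a, 16), 16)
w_q = fake_quant(w, symmetric_qparams(w, 8), 8)

# Elementwise mul on the (de)quantized operands; 16-bit activations keep the
# quantization error much smaller than an all-int8 configuration would.
print((a * w - a_q * w_q).abs().max())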

163 files changed, +1627 −1696 lines changed

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 23 additions & 0 deletions

@@ -15,6 +15,7 @@
 from coremltools.converters.mil.frontend.torch.ops import (
     _get_inputs,
     _get_kwinputs,
+    noop,
     NUM_TO_NUMPY_DTYPE,
     NUM_TO_TORCH_DTYPE,
     split,
@@ -91,6 +92,28 @@ def _to_dim_order_copy(context, node):
     to(context, node)
 
 
+@register_torch_op(
+    torch_alias=[
+        "dim_order_ops::_clone_dim_order",
+        "dim_order_ops._clone_dim_order",
+    ],
+    override=False,
+)
+def _clone_dim_order(context, node):
+    dim_order = _get_kwinputs(context, node, "dim_order", default=[None])[0]
+    node.kwinputs.pop("dim_order")
+
+    # In CoreML, dim_order.val will be a ndarray, so we convert it to a list to check memory format.
+    dim_order = [int(d) for d in dim_order.val]
+    memory_format = get_memory_format(dim_order)
+    assert (
+        memory_format == _torch.contiguous_format
+    ), "Only contiguous memory format is supported in CoreML"
+
+    # Since CoreML only supports contiguous format, no dim_order preservation is needed. Treat this as a no-op clone.
+    noop(context, node)
+
+
 # https://github.com/apple/coremltools/pull/2558
 @register_torch_op(
     torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 23 additions & 0 deletions

@@ -268,6 +268,28 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
         et_prog = delegated_program.to_executorch()
         self._compare_outputs(et_prog, model, example_inputs)
 
+    def test__clone_dim_order_contiguous(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.dim_order_ops._clone_dim_order(
+                    x, dim_order=[0, 1, 2, 3]
+                )
+
+        model, example_inputs = Model(), (torch.randn(1, 3, 8, 8),)
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
 
 if __name__ == "__main__":
     test_runner = TestTorchOps()
@@ -280,3 +302,4 @@ def test_dequantize_codebook_embedding_per_grouped_row(self):
     test_runner.test_dequantize_codebook_linear_per_grouped_row()
     test_runner.test_dequantize_codebook_embedding_per_grouped_col()
     test_runner.test_dequantize_codebook_embedding_per_grouped_row()
+    test_runner.test__clone_dim_order_contiguous()

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 1 deletion

@@ -91,7 +91,8 @@
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
-from executorch.backends.arm.tosa_specification import (
+
+from executorch.backends.arm.tosa.specification import (
     TosaLoweringContext,
     TosaSpecification,
 )

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 5 additions & 1 deletion

@@ -9,7 +9,6 @@
 import torch
 from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
-from executorch.backends.arm.operator_support.pool_2d_support import AvgPool2dSupported
 from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -67,6 +66,11 @@ def __init__(self, graph_module, tosa_spec):
         super().__init__()
         self._graph_module = graph_module
         self._tosa_spec = tosa_spec
+        # Lazy import to avoid circular dependency with operator_support
+        from executorch.backends.arm.operator_support.pool_2d_support import (
+            AvgPool2dSupported,
+        )
+
         self._avg_pool_checker = AvgPool2dSupported(
             self._tosa_spec, WhyNoPartitionReporter()
         )

backends/arm/_passes/remove_clone_pass.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ class RemoveClonePass(ExportPass):
    """Remove all clones from graph_module"""
 
    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.clone.default:
+        if op != exir_ops.edge.dim_order_ops._clone_dim_order.default:
            return super().call_operator(op, args, kwargs, meta)
 
        if len(args) != 1:

backends/arm/arm_backend.py

Lines changed: 1 addition & 3 deletions

@@ -13,9 +13,7 @@
 from enum import Enum
 from typing import List, Optional
 
-from executorch.backends.arm.tosa_specification import (  # type: ignore[import-not-found]
-    TosaSpecification,
-)
+from executorch.backends.arm.tosa import TosaSpecification
 
 from executorch.exir.backend.compile_spec_schema import (  # type: ignore[import-not-found]
     CompileSpec,

backends/arm/ethosu/backend.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 
 from executorch.backends.arm.arm_vela import vela_compile
 
-from executorch.backends.arm.tosa_backend import TOSABackend
+from executorch.backends.arm.tosa.backend import TOSABackend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export.exported_program import ExportedProgram

backends/arm/ethosu/partitioner.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
     is_ethosu,
 )  # usort: skip
 from executorch.backends.arm.ethosu import EthosUBackend
-from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
+from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import DelegationSpec
 from torch.fx.passes.operator_support import OperatorSupportBase

backends/arm/operator_support/clone_support.py

Lines changed: 55 additions & 2 deletions

@@ -5,20 +5,21 @@
 
 import logging
 
+import torch
 import torch.fx as fx
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
 )
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa import TosaSpecification
 from executorch.exir.dialects._ops import ops as exir_ops
 
 logger = logging.getLogger(__name__)
 
 
 @register_tosa_support_check
 class CloneSupported(SupportedTOSAOperatorCheck):
-    targets = [exir_ops.edge.aten.clone.default]
+    targets = [exir_ops.edge.dim_order_ops._clone_dim_order.default]
 
     tosa_specs = [
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
@@ -28,10 +29,62 @@ class CloneSupported(SupportedTOSAOperatorCheck):
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
+        if node.target not in self.targets:
+            self.reporter.report_reject(node, f"Target {node.target} is not supported.")
+            return False
 
         input_node = node.args[0]
         if not isinstance(input_node, fx.Node):
             self.reporter.report_reject(node, "Non tensor clones are not supported")
             return False
 
+        # Check input node
+        if len(node.all_input_nodes) != 1:
+            self.reporter.report_reject(
+                node, f"Expected 1 input node, got {len(node.all_input_nodes)}"
+            )
+            return False
+
+        input_val = node.all_input_nodes[0].meta["val"]
+        if not isinstance(input_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(node, "Expected input to be a FakeTensor.")
+            return False
+
+        input_dtype = input_val.dtype
+
+        # Check output node
+        output_val = node.meta["val"]
+        if not isinstance(output_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(node, "Expected output to be a FakeTensor.")
+            return False
+
+        if output_val.dtype != input_dtype:
+            self.reporter.report_reject(
+                node,
+                f"Input dtype {input_val.dtype} does not match {output_val.dtype}.",
+            )
+            return False
+
+        # Check memory format
+        if "memory_format" in node.kwargs:
+            if node.kwargs["memory_format"] in (torch.preserve_format,):
+                self.reporter.report_reject(
+                    node,
+                    f"Argument 'memory_format' is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
+        # Check dim_order
+        if "dim_order" in node.kwargs:
+            dim_order = node.kwargs["dim_order"]
+            # pyre-ignore[6]
+            if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
+                self.reporter.report_reject(
+                    node,
+                    f"Argument {dim_order=} is not supported for "
+                    f"{node.target} right now.",
+                )
+                return False
+
         return True
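For readers skimming the dim_order check added above: only the identity ordering (a plain contiguous tensor) passes. A tiny standalone sketch of that predicate; the helper name is hypothetical and not part of the commit:

def is_contiguous_dim_order(dim_order):
    # A contiguous tensor has dim_order [0, 1, ..., rank-1].
    return list(dim_order) == list(range(len(dim_order)))

print(is_contiguous_dim_order([0, 1, 2, 3]))  # True  -> clone is supported
print(is_contiguous_dim_order([0, 2, 3, 1]))  # False -> channels_last-style order is rejected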

backends/arm/operator_support/convolution_support.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
 )
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa import TosaSpecification
 
 from executorch.exir.dialects._ops import ops as exir_ops
