Commit 95a8b63

Author: pytorchbot
Committed: 2025-06-23 nightly release (4cb71a0)
Parent: 400d0df


63 files changed (+886, -367 lines)

README.md

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ It supports a wide range of models including LLMs (Large Language Models), CV (C
 Platform Support:
 - Operating Systems:
   - iOS
-  - Mac
+  - MacOS (ARM64)
   - Android
   - Linux
   - Microcontrollers

backends/arm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,8 @@ if(NOT EXECUTORCH_ROOT)
   set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
 endif()
 
+add_compile_options("-Wall" "-Werror")
+
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 
 set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -5,9 +5,10 @@
 
 
 from . import arm_pass_utils  # noqa
+from .arm_pass import ArmPass  # noqa # usort: skip
+from .add_bias_pass import AddBiasPass  # noqa
 from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
-from .arm_pass import ArmPass  # noqa
 from .broadcast_args_pass import BroadcastArgsPass  # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass  # noqa
 from .cast_to_int32_pass import CastToInt32Pass  # noqa

backends/arm/_passes/add_bias_pass.py (new file)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.transforms.utils import create_constant_placeholder
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+from torch.export.graph_signature import InputKind
+
+
+class AddBiasPass(ArmPass):
+    """TOSA requires convolution nodes to have a bias input.
+    This pass adds a bias input to convolution nodes that do not have one.
+    The bias is set to zero.
+    """
+
+    targeted_ops = (exir_ops.edge.aten.convolution.default,)
+
+    def call(self, graph_module):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target not in self.targeted_ops:
+                continue
+
+            if len(node.all_input_nodes) < 3:
+                modified = True
+                # bias is missing
+                weight_node = node.all_input_nodes[1]
+                output_channels = get_first_fake_tensor(weight_node).shape[0]
+                # add a node containing zeros
+                # if quantized, use int32, otherwise use float32
+                if (
+                    "output_qparams" in node.meta
+                    and len(node.meta["output_qparams"]) > 0
+                ):
+                    bias_data = torch.zeros(size=(output_channels,), dtype=torch.int32)
+                else:
+                    bias_data = torch.zeros(
+                        size=(output_channels,), dtype=torch.float32
+                    )
+
+                with graph_module.graph.inserting_after(weight_node):
+                    bias_node = create_constant_placeholder(
+                        self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        data=bias_data,
+                        persistent_buffer=True,
+                        name=f"{node.name}_bias",
+                    )
+                    node.update_arg(2, bias_node)
+
+        if modified:
+            graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, modified)
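
For context, the case this new pass handles can be reproduced with a convolution module exported without a bias: the resulting convolution node carries only an input and a weight, so len(node.all_input_nodes) < 3 and AddBiasPass inserts a zero-filled parameter. The module, shapes, and print below are an illustrative sketch, not part of the commit:

import torch


class NoBiasConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # bias=False is the case AddBiasPass targets after export/lowering
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False)

    def forward(self, x):
        return self.conv(x)


# The exported aten.convolution call carries None for its bias argument; once the
# Arm pass pipeline runs AddBiasPass, a zero-filled "<node>_bias" parameter is
# inserted as the third input instead.
ep = torch.export.export(NoBiasConv(), (torch.randn(1, 3, 16, 16),))
print(ep.graph)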

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions

@@ -7,6 +7,7 @@
 
 # pyre-unsafe
 from executorch.backends.arm._passes import (
+    AddBiasPass,
     AnnotateChannelsLastDimOrder,
     AnnotateDecomposedMatmulPass,
     BroadcastArgsPass,
@@ -134,6 +135,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
 
         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
+        self.add_pass(AddBiasPass(exported_program))
 
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
@@ -194,6 +196,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
 
         self.add_pass(FuseViewCopyTransform())
         self.add_pass(FuseConstantArgsPass(exported_program))
+        self.add_pass(AddBiasPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())

backends/arm/_passes/match_where_self_arg_dtype_pass.py

Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ def call(self, graph_module: torch.fx.GraphModule):
 
             input_dtype = input_.meta["val"].dtype
             other_dtype = other_.meta["val"].dtype
-            target_dtype = torch.float32
+            target_dtype = input_dtype
            if input_dtype != other_dtype:
                 target_dtype = get_largest_dtype(input_dtype, other_dtype)
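
The one-line change above means where.self no longer forces float32 when both value operands already share a dtype. A plain-torch illustration of the intended dtype behavior (not the pass's own helper):

import torch

cond = torch.tensor([True, False, True])
a = torch.tensor([1, 2, 3], dtype=torch.int32)
b = torch.tensor([4, 5, 6], dtype=torch.int32)

# Matching int32 operands keep int32; the old default of target_dtype = torch.float32
# would have rewritten the graph to cast both operands toward float32 first.
out = torch.where(cond, a, b)
assert out.dtype == torch.int32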

backends/arm/arm_backend.py

Lines changed: 34 additions & 4 deletions

@@ -10,12 +10,15 @@
 # backends. Converts via TOSA as an intermediate form supported by AoT and
 # JIT compiler flows.
 #
-
 from typing import List, Optional
 
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_specification import (  # type: ignore[import-not-found]
+    TosaSpecification,
+)
 
-from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.compile_spec_schema import (  # type: ignore[import-not-found]
+    CompileSpec,
+)
 
 
 class ArmCompileSpecBuilder:
@@ -28,6 +31,7 @@ def __init__(self):
 
     def vgf_compile_spec(
         self,
+        tosa_spec: TosaSpecification = None,  # type: ignore[assignment]
         compiler_flags: Optional[str] = "",
     ) -> "ArmCompileSpecBuilder":
         """
@@ -40,7 +44,33 @@
         self.compiler_flags = [
             compiler_flags,
         ]
-        self.tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+MI")
+
+        if tosa_spec is None:
+            tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
+
+        tosa_version = tosa_spec.version  # type: ignore[attr-defined]
+        tosa_profiles = tosa_spec.profiles  # type: ignore[attr-defined]
+
+        if tosa_version.major != 1:
+            raise ValueError(
+                "Arm backend only supports converter-backend for TOSA version 1. "
+                f"Invalid TOSA version: {tosa_version}"
+            )
+
+        if "FP" not in tosa_profiles and "INT" not in tosa_profiles:
+            raise ValueError(
+                "Arm backend only supports converter-backend for FP or INT. "
+                f"Invalid TOSA profile: {tosa_profiles}"
+            )
+
+        if len(tosa_profiles) != 1:
+            raise ValueError(
+                "For now Arm backend only supports converter-backend for either FP or INT. "
+                f"Invalid TOSA profile: {tosa_profiles}"
+            )
+
+        self.tosa_spec = tosa_spec
+
         return self
 
     def ethosu_compile_spec(
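
A hedged usage sketch of the new parameter, assuming ArmCompileSpecBuilder.build() exists as elsewhere in this builder API and that "TOSA-1.0+INT" is an accepted spec string; only the defaulting and validation shown in this diff are relied on:

from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
from executorch.backends.arm.tosa_specification import TosaSpecification

# No argument: vgf_compile_spec now defaults to TOSA-1.0+FP.
fp_specs = ArmCompileSpecBuilder().vgf_compile_spec().build()

# Explicit single-profile TOSA 1.x spec; a 0.80 spec or a multi-profile spec
# would raise ValueError per the checks added above.
int_specs = (
    ArmCompileSpecBuilder()
    .vgf_compile_spec(TosaSpecification.create_from_string("TOSA-1.0+INT"))
    .build()
)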

backends/arm/operators/op_conv2d.py

Lines changed: 0 additions & 36 deletions

@@ -109,24 +109,6 @@ def define_node(
             local_bound=False,
         )
 
-        # Non-bias case.
-        if len(node.all_input_nodes) == 2:
-            # Create a zero bias tensor if not presented
-            out_channels = weight.shape[0]
-            bias_name = "bias" + node.name.split("default", 1)[1]
-            bias_type = output.dtype
-            if output.dtype == ts.DType.INT8:
-                # Conv is quantized to int8, but the TOSA operator has
-                # output type int32, and the bias must be the same type
-                # as the TOSA output type
-                bias_type = ts.DType.INT32
-            bias = tosa_graph.addConst(
-                [out_channels],
-                bias_type,
-                [0] * out_channels,
-                name=bias_name,
-            )
-
         # The output type is int32 when input type is int8.
         conv2d_output_name = output.name
         if output.dtype == ts.DType.INT8:
@@ -313,24 +295,6 @@ def define_node(
             name=f"{conv2d_output_name}_weight_zp",
         )
 
-        # Non-bias case.
-        if len(node.all_input_nodes) == 2:
-            # Create a zero bias tensor if not presented
-            out_channels = weight.shape[0]
-            bias_name = f"{conv2d_output_name}_bias"
-            bias_type = output.dtype
-            if output.dtype == ts.DType.INT8:
-                # Conv is quantized to int8, but the TOSA operator has
-                # output type int32, and the bias must be the same type
-                # as the TOSA output type
-                bias_type = ts.DType.INT32
-            bias = tosa_graph.addConst(
-                [out_channels],
-                bias_type,
-                [0] * out_channels,
-                name=bias_name,
-            )
-
         # Given input.shape is (N, Ci, H, W), and weight.shape is (Co, Ci/G, H, W)
         in_channels = input.shape[1]
         out_channels = weight.shape[0]

backends/arm/test/ops/test_where.py

Lines changed: 7 additions & 0 deletions

@@ -121,6 +121,12 @@ def scalar_condition(input: torch.Tensor):
     scalar_condition,
 )
 
+int32_scalar_cond = Where(
+    1,
+    torch.int32,
+    scalar_condition,
+)
+
 test_modules_common = {
     "two_dim_tensor_cond": lambda: two_dim_tensor_cond,
     "three_dim_tensor_cond": lambda: three_dim_tensor_cond,
@@ -134,6 +140,7 @@ def scalar_condition(input: torch.Tensor):
     **test_modules_common,
     "float32_tensor_cond_tuple_dtype": lambda: float32_tensor_cond_tuple_dtype,
     "float32_tensor_cond_tuple_dtype_bool": lambda: float32_tensor_cond_tuple_dtype_bool,
+    "int32_scalar_cond": lambda: int32_scalar_cond,
 }
 
 test_modules_BI = {

backends/cadence/aot/fuse_ops.py

Lines changed: 5 additions & 23 deletions

@@ -712,32 +712,14 @@ def _create_requantize_node(
     out_dtype: torch.dtype,
     graph: torch.fx.Graph,
 ) -> torch.fx.Node:
-    in_scale_tensor = graph.call_function(
-        exir_ops.edge.aten.full.default, args=((1,), in_scale)
-    )
-    in_zero_point_tensor = graph.call_function(
-        exir_ops.edge.aten.full.default,
-        args=((1,), in_zero_point),
-        kwargs={"dtype": torch.int32},
-    )
-    out_scale_tensor = graph.call_function(
-        exir_ops.edge.aten.full.default, args=((1,), out_scale)
-    )
-    out_zero_point_tensor = graph.call_function(
-        exir_ops.edge.aten.full.default,
-        args=((1,), out_zero_point),
-        kwargs={"dtype": torch.int32},
-    )
-    # cadence::requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype) -> Tensor Y
-    # TODO(hardiksharma): Add support for per-tensor requantize.
     return graph.call_function(
-        exir_ops.edge.cadence.requantize.default,
+        exir_ops.edge.cadence.requantize.per_tensor,
        args=(
             in_tensor,
-            in_scale_tensor,
-            in_zero_point_tensor,
-            out_scale_tensor,
-            out_zero_point_tensor,
+            in_scale,
+            in_zero_point,
+            out_scale,
+            out_zero_point,
             out_dtype,
         ),
     )
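
The effect of this change is purely on the emitted graph shape: the per_tensor overload takes the scale/zero-point scalars directly, so the four aten.full nodes that previously wrapped them as 1-element tensors disappear. A self-contained FX sketch of the new single-node pattern, using a plain Python stand-in for the Cadence requantize kernel (the callable and its math are illustrative assumptions):

import torch
import torch.fx as fx


def requantize_per_tensor(x, in_scale, in_zp, out_scale, out_zp, out_dtype):
    # Stand-in for exir_ops.edge.cadence.requantize.per_tensor; approximate math only.
    return ((x.to(torch.float32) - in_zp) * in_scale / out_scale + out_zp).to(out_dtype)


graph = fx.Graph()
x = graph.placeholder("x")
# One call_function node with scalar args, instead of four aten.full nodes feeding
# the tensor-variant requantize overload.
out = graph.call_function(
    requantize_per_tensor, args=(x, 0.05, 0, 0.1, 128, torch.uint8)
)
graph.output(out)
print(graph)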
