Skip to content

Commit f8b199d

Browse files
authored
Merge branch 'main' into toupstream/model_update
2 parents 735d97b + 12af535 commit f8b199d

File tree

63 files changed

+740
-230
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+740
-230
lines changed

CMakeLists.txt

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,21 +48,33 @@ project(executorch)
4848
# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION --------------------------------------------------
4949

5050
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
51+
include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake)
52+
include(CMakeDependentOption)
53+
include(ExternalProject)
5154

5255
if(NOT CMAKE_CXX_STANDARD)
5356
set(CMAKE_CXX_STANDARD 17)
5457
endif()
5558
announce_configured_options(CMAKE_CXX_STANDARD)
5659

60+
if(NOT CMAKE_SYSTEM_PROCESSOR)
61+
set(CMAKE_SYSTEM_PROCESSOR ${CMAKE_HOST_SYSTEM_PROCESSOR})
62+
endif()
63+
announce_configured_options(CMAKE_SYSTEM_PROCESSOR)
64+
5765
if(NOT CMAKE_BUILD_TYPE)
5866
set(CMAKE_BUILD_TYPE Debug)
5967
endif()
6068
announce_configured_options(CMAKE_BUILD_TYPE)
6169

70+
if(NOT PYTHON_EXECUTABLE)
71+
resolve_python_executable()
72+
endif()
73+
announce_configured_options(PYTHON_EXECUTABLE)
74+
6275
announce_configured_options(CMAKE_CXX_COMPILER_ID)
6376
announce_configured_options(CMAKE_TOOLCHAIN_FILE)
6477
announce_configured_options(BUCK2)
65-
announce_configured_options(PYTHON_EXECUTABLE)
6678

6779
load_build_preset()
6880
include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
@@ -72,10 +84,6 @@ print_configured_options()
7284

7385
# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION ----------------------------------------------------
7486

75-
include(tools/cmake/Utils.cmake)
76-
include(CMakeDependentOption)
77-
include(ExternalProject)
78-
7987
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
8088

8189
# Setup RPATH.
@@ -251,11 +259,6 @@ if(EXECUTORCH_BUILD_TESTS)
251259
include(CTest)
252260
endif()
253261

254-
if(NOT PYTHON_EXECUTABLE)
255-
resolve_python_executable()
256-
endif()
257-
message(STATUS "Using python executable '${PYTHON_EXECUTABLE}'")
258-
259262
# TODO(dbort): Fix these warnings and remove this flag.
260263
set(_common_compile_options -Wno-deprecated-declarations -fPIC)
261264

backends/apple/mps/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ endif()
1818

1919
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2020

21-
if(NOT PYTHON_EXECUTABLE)
22-
resolve_python_executable()
23-
endif()
24-
2521
set(_common_compile_options -Wno-deprecated-declarations)
2622
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
2723

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2525
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
2626
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
27+
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
2728
from .decompose_linear_pass import DecomposeLinearPass # noqa
2829
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
2930
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DecomposeLayerNormPass,
3030
DecomposeLeakyReLUPass,
3131
DecomposeLinearPass,
32+
DecomposeLinearVectorNormPass,
3233
DecomposeMeanDimPass,
3334
DecomposeNotEqualPass,
3435
DecomposeSelectPass,
@@ -86,6 +87,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8687
self.add_pass(ConvertSplitToSlicePass())
8788
self.add_pass(ConvertMmToBmmPass())
8889
self.add_pass(DecomposeLinearPass())
90+
self.add_pass(DecomposeLinearVectorNormPass())
8991
self.add_pass(DecomposeMeanDimPass())
9092
self.add_pass(ConvertFullLikeToFullPass())
9193
self.add_pass(ConvertToClampPass())
@@ -133,6 +135,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
133135
self.add_pass(FuseBatchnorm2DPass(exported_program))
134136
self.add_pass(ConvertMmToBmmPass())
135137
self.add_pass(DecomposeLinearPass())
138+
self.add_pass(DecomposeLinearVectorNormPass())
136139
self.add_pass(DecomposeLeakyReLUPass())
137140
self.add_pass(DecomposeBatchNormPass())
138141
self.add_pass(DecomposeLayerNormPass())
@@ -207,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
207210
self.add_pass(DecomposeCosineSimilarityPass())
208211
self.add_pass(DecomposeDivPass())
209212
self.add_pass(DecomposeLeakyReLUPass())
213+
self.add_pass(DecomposeLinearVectorNormPass())
210214
self.add_pass(DecomposeSqrtPass())
211215
self.add_pass(DecomposeSiluPass())
212216

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
10+
class DecomposeLinearVectorNormPass(ExportPass):
11+
"""
12+
This pass decomposes aten.linalg_vector_norm.default into more primitive ops.
13+
We need to add this pass before quantization for graph annotation.
14+
By default, aten.linalg_vector_norm op is decomposed during legalization to Edge IR.
15+
16+
The decomposition is as follows:
17+
18+
For p == 1:
19+
out = REDUCE_SUM(ABS(x), dims, keepdim)
20+
21+
For p == 2:
22+
out = SQRT(REDUCE_SUM(MUL(x, x), dims, keepdim))
23+
24+
For arbitrary p:
25+
We dont support arbitrary p, because our decomposition looks like
26+
out = POW(REDUCE_SUM(POW(ABS(x), p), dims, keepdim), 1/p)
27+
In this case we need to wrap p into Tensor and we need to know
28+
dtype prior, but we dont know this from FX graph.
29+
"""
30+
31+
torch_linalg_vector_norm = (torch.ops.aten.linalg_vector_norm.default,)
32+
33+
def call_operator(self, op, args, kwargs, meta):
34+
if op not in self.torch_linalg_vector_norm:
35+
return super().call_operator(op, args, kwargs, meta)
36+
37+
# Extract inputs and optional arguments.
38+
# Expected args:
39+
# args[0]: input tensor
40+
# args[1]: norm order 'p' (optional, default: 2.0)
41+
# args[2]: dimensions to reduce (should be provided)
42+
# args[3]: keepdim flag (optional, default: False)
43+
input_tensor = args[0]
44+
norm_order = args[1] if len(args) > 1 else 2.0
45+
norm_dim = args[2] if len(args) > 2 else None
46+
keepdim = args[3] if len(args) > 3 else False
47+
48+
if norm_order not in (1, 2):
49+
raise ValueError(
50+
f"The order of {norm_order}\n"
51+
f"is not supported for linalg_vector_norm operator"
52+
)
53+
54+
if norm_dim is None:
55+
raise ValueError("The norm_dim for linalg_vector_norm is None.")
56+
57+
dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim)
58+
59+
# Decomposition based on norm order.
60+
if norm_order == 1:
61+
op1 = super().call_operator(
62+
torch.ops.aten.abs.default, (input_tensor,), {}, meta
63+
)
64+
op2 = super().call_operator(
65+
torch.ops.aten.sum.dim_IntList, (op1, dims, keepdim), {}, meta
66+
)
67+
return op2
68+
69+
elif norm_order == 2:
70+
# For p == 2, decomposition is sqrt(sum(x * x, dims, keepdim))
71+
op1 = super().call_operator(
72+
torch.ops.aten.mul.Tensor, (input_tensor, input_tensor), {}, meta
73+
)
74+
op2 = super().call_operator(
75+
torch.ops.aten.sum.dim_IntList, (op1, dims, keepdim), {}, meta
76+
)
77+
op3 = super().call_operator(torch.ops.aten.sqrt.default, (op2,), {}, meta)
78+
return op3

backends/arm/operator_support/to_copy_support.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7+
import copy
78
import logging
89

910
import torch
@@ -42,7 +43,9 @@ def _merge_supported_types(
4243
dtypes1: SupportedTypeDict,
4344
dtypes2: SupportedTypeDict,
4445
) -> SupportedTypeDict:
45-
merged_dtypes = dtypes1
46+
merged_dtypes = copy.deepcopy(
47+
dtypes1
48+
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
4649
for k, v in dtypes2.items():
4750
merged_dtypes[k] = merged_dtypes.get(k, []) + v
4851
return merged_dtypes

backends/arm/operators/op_abs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def define_node(
164164
scale_back = 1.0
165165
if inputs[0].dtype == ts.DType.INT8:
166166
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
167-
tosa_graph, inputs, node, self.tosa_specs
167+
tosa_graph, inputs, node, self.tosa_spec
168168
) # type: ignore[possibly-undefined]
169169
else:
170170
# input[0].dtype == ts.DType.INT32
@@ -192,7 +192,7 @@ def define_node(
192192
# Scale output back to 8 bit
193193
# pyre-ignore
194194
tqutils.insert_rescale_op_to_int8(
195-
tosa_graph, abs_output, scale_back, node, self.tosa_specs
195+
tosa_graph, abs_output, scale_back, node, self.tosa_spec
196196
) # type: ignore[possibly-undefined]
197197

198198

backends/arm/operators/op_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def define_node(
174174
scale_back = 1.0
175175
if inputs[0].dtype == ts.DType.INT8:
176176
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
177-
tosa_graph, inputs, node, self.tosa_specs
177+
tosa_graph, inputs, node, self.tosa_spec
178178
)
179179
else:
180180
# input[0].dtype == ts.DType.INT32
@@ -202,7 +202,7 @@ def define_node(
202202
# Scale output back to 8 bit
203203
# pyre-ignore
204204
tqutils.insert_rescale_op_to_int8(
205-
tosa_graph, add_output, scale_back, node, self.tosa_specs
205+
tosa_graph, add_output, scale_back, node, self.tosa_spec
206206
) # type: ignore[possibly-undefined]
207207

208208

backends/arm/operators/op_eq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def define_node(
9898
if inputs[0].dtype == ts.DType.INT8:
9999
# Rescale inputs to 32 bit
100100
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
101-
tosa_graph, inputs, node, self.tosa_specs
101+
tosa_graph, inputs, node, self.tosa_spec
102102
)
103103

104104
# Update IO

backends/arm/operators/op_ge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def define_node(
9797
if inputs[0].dtype == ts.DType.INT8:
9898
# Rescale inputs to 32 bit
9999
rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32(
100-
tosa_graph, inputs, node, self.tosa_specs
100+
tosa_graph, inputs, node, self.tosa_spec
101101
)
102102

103103
# Update IO

0 commit comments

Comments
 (0)