Skip to content

Commit c0d2451

Browse files
Merge branch 'main' into improve-tosa-supported-ops-check
2 parents 4e7cb0c + 2845fd3 commit c0d2451

File tree

184 files changed

+12024
-1240
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

184 files changed

+12024
-1240
lines changed

.github/workflows/pull.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,14 @@ jobs:
929929
CMAKE_ARGS="-DEXECUTORCH_BUILD_VULKAN=ON" \
930930
.ci/scripts/setup-linux.sh --build-tool "cmake"
931931
932+
# Custom operator tests
932933
PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add
934+
./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear
935+
./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d
936+
937+
# Run e2e testing for selected operators. More operators will be tested via this
938+
# route in the future.
939+
python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"
933940
934941
nxp-build-test:
935942
name: nxp-build-test

CMakeLists.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,9 +699,7 @@ if(EXECUTORCH_BUILD_KERNELS_TORCHAO)
699699
${EXECUTORCH_ROOT}/backends/xnnpack/third-party/pthreadpool/include
700700
${EXECUTORCH_ROOT}/backends/xnnpack/third-party/cpuinfo/include
701701
)
702-
add_subdirectory(
703-
${CMAKE_CURRENT_SOURCE_DIR}/third-party/ao/torchao/experimental
704-
)
702+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/ao/torchao/csrc/cpu)
705703
unset(EXECUTORCH_INCLUDE_DIRS)
706704

707705
executorch_target_link_options_shared_lib(torchao_ops_executorch)

backends/arm/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ For more information on TOSA see https://www.mlplatform.org/tosa/tosa_spec.html
3434
## Layout of key components
3535

3636
Export:
37-
* `tosa_backend.py` - The TOSA conversion flow all other backends rely on.
37+
* `tosa/backend.py` - The TOSA conversion flow all other backends rely on.
3838
* `ethosu/backend.py` - Main entrypoint for the EthosUBackend.
3939
* `vgf_backend.py` - Main entrypoint for VgfBackend.
4040
* For more information see the section on [Arm Backend Architecture](#arm-backend-architecture).

backends/arm/TARGETS

Lines changed: 22 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
# @noautodeps
27
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
38

@@ -12,6 +17,17 @@ python_library(
1217
":arm_partitioner",
1318
]
1419
)
20+
python_library(
21+
name = "vgf_partitioner",
22+
srcs = [
23+
"vgf/__init__.py",
24+
"vgf/backend.py",
25+
"vgf/partitioner.py"
26+
],
27+
deps = [
28+
":arm_partitioner",
29+
]
30+
)
1531
python_library(
1632
name = "constants",
1733
srcs = [
@@ -37,14 +53,13 @@ python_library(
3753
python_library(
3854
name = "arm_partitioner",
3955
srcs = [
40-
"tosa_backend.py",
41-
"tosa_partitioner.py",
42-
"vgf_backend.py",
43-
"vgf_partitioner.py",
56+
"tosa/backend.py",
57+
"tosa/partitioner.py",
4458
],
4559
deps = [
4660
":arm_backend",
4761
":constants",
62+
"//executorch/backends/arm/debug:schema",
4863
"//executorch/backends/arm/operator_support:operator_support",
4964
"//executorch/backends/arm/_passes:passes",
5065
"//executorch/exir:lib",
@@ -76,9 +91,9 @@ python_library(
7691
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa",
7792
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa",
7893
"//executorch/backends/arm/operators:node_visitor",
79-
"//executorch/backends/arm:tosa_mapping",
80-
"//executorch/backends/arm:tosa_quant_utils",
81-
"//executorch/backends/arm:tosa_utils",
94+
"//executorch/backends/arm/tosa:mapping",
95+
"//executorch/backends/arm/tosa:quant_utils",
96+
"//executorch/backends/arm/tosa:utils",
8297
"//executorch/exir:lib",
8398
],
8499
)
@@ -91,54 +106,6 @@ python_library(
91106
"fbsource//third-party/pypi/ethos-u-vela:ethos-u-vela",
92107
],
93108
)
94-
python_library(
95-
name = "tosa_mapping",
96-
srcs = [
97-
"tosa_mapping.py",
98-
],
99-
deps = [
100-
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer",
101-
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer",
102-
"//caffe2:torch",
103-
],
104-
)
105-
python_library(
106-
name = "tosa_quant_utils",
107-
srcs = [
108-
"tosa_quant_utils.py",
109-
],
110-
deps = [
111-
"fbsource//third-party/pypi/numpy:numpy",
112-
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer",
113-
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer",
114-
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa",
115-
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa",
116-
":constants",
117-
":tosa_mapping",
118-
"//executorch/exir/dialects:lib",
119-
],
120-
)
121-
python_library(
122-
name = "tosa_specification",
123-
srcs = [
124-
"tosa_specification.py",
125-
],
126-
deps = [
127-
"fbsource//third-party/pypi/packaging:packaging",
128-
"//executorch/exir/backend:compile_spec_schema",
129-
],
130-
)
131-
python_library(
132-
name = "tosa_utils",
133-
srcs = [
134-
"tosa_utils.py",
135-
],
136-
deps = [
137-
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/serializer:serializer",
138-
":tosa_quant_utils",
139-
"//executorch/backends/arm/operators:node_visitor",
140-
],
141-
)
142109
python_library(
143110
name = "arm_model_evaluator",
144111
srcs = [

backends/arm/_passes/TARGETS

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ python_library(
66
deps = [
77
"//executorch/backends/arm:common",
88
"//executorch/backends/arm:constants",
9-
"//executorch/backends/arm:tosa_quant_utils",
10-
"//executorch/backends/arm:tosa_utils",
9+
"//executorch/backends/arm/tosa:quant_utils",
10+
"//executorch/backends/arm/tosa:utils",
1111
"//executorch/backends/arm/tosa/dialect:lib",
1212
"//executorch/backends/transforms:fuse_view_copy",
1313
"//executorch/backends/transforms:remove_getitem_op",

backends/arm/debug/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# @noautodeps
2+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
3+
4+
python_library(
5+
name = "schema",
6+
srcs = [
7+
"__init__.py",
8+
"schema.py",
9+
],
10+
deps = [
11+
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/serializer:serializer",
12+
"//caffe2:torch",
13+
],
14+
)

backends/arm/operator_support/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ python_library(
66
deps = [
77
"//executorch/backends/arm:constants",
88
"//executorch/backends/arm/_passes:passes",
9-
"//executorch/backends/arm:tosa_specification",
9+
"//executorch/backends/arm/tosa:tosa",
1010
"//executorch/backends/transforms:remove_getitem_op",
1111
"//executorch/backends/xnnpack/_passes:xnnpack_passes",
1212
"//executorch/exir:lib",

backends/arm/operators/TARGETS

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ python_library(
55
name = "node_visitor",
66
srcs = ["node_visitor.py"],
77
deps = [
8-
"//executorch/backends/arm:tosa_mapping",
9-
"//executorch/backends/arm:tosa_specification",
8+
"//executorch/backends/arm/debug:schema",
9+
"//executorch/backends/arm/tosa:mapping",
10+
"//executorch/backends/arm/tosa:tosa",
1011
],
1112
)
1213

@@ -23,9 +24,9 @@ python_library(
2324
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa",
2425
":node_visitor",
2526
":operator_validation_utils",
26-
"//executorch/backends/arm:tosa_mapping",
27-
"//executorch/backends/arm:tosa_quant_utils",
28-
"//executorch/backends/arm:tosa_utils",
27+
"//executorch/backends/arm/tosa:mapping",
28+
"//executorch/backends/arm/tosa:quant_utils",
29+
"//executorch/backends/arm/tosa:utils",
2930
"//executorch/backends/arm/_passes:passes",
3031
"//executorch/exir:lib",
3132
],

backends/arm/operators/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,22 @@
1515
op_avg_pool2d,
1616
op_bmm,
1717
op_cat,
18+
op_ceil,
1819
op_clamp,
1920
op_constant_pad_nd,
2021
op_conv2d,
2122
op_cos,
2223
op_eq,
2324
op_erf,
2425
op_exp,
26+
op_floor,
2527
op_ge,
2628
op_gt,
2729
op_index_select,
2830
op_index_tensor,
2931
op_le,
3032
op_log,
33+
op_logical_not,
3134
op_lt,
3235
op_max_pool2d,
3336
op_maximum,
@@ -57,5 +60,4 @@
5760
op_where,
5861
ops_binary,
5962
ops_identity,
60-
ops_unary,
6163
)

backends/arm/operators/op_add.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,16 @@ def define_node(
4747

4848
validate_num_inputs(self.target, inputs, 2)
4949
validate_same_dtype(self.target, [*inputs, output], ts)
50+
valid_dtypes = []
51+
if self.tosa_spec.support_integer():
52+
valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
53+
if self.tosa_spec.support_float():
54+
valid_dtypes.extend([ts.DType.INT32])
55+
5056
validate_valid_dtype(
5157
self.target,
5258
[*inputs, output],
53-
[ts.DType.INT8, ts.DType.INT32],
59+
valid_dtypes,
5460
output.tosa_spec,
5561
)
5662
scale_back = 1.0
@@ -59,15 +65,15 @@ def define_node(
5965
tosa_graph, inputs, node, self.tosa_spec
6066
)
6167
else:
62-
# input[0].dtype == ts.DType.INT32
68+
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
6369
# Non quantized input, natively support by TOSA.ADD
6470
rescaled_inputs = inputs
6571

6672
if output.dtype == ts.DType.INT8:
6773
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
6874
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
6975
else:
70-
# output.dtype == ts.DType.INT32
76+
# output.dtype == ts.DType.INT16 or ts.DType.INT32
7177
add_output = output
7278

7379
input1, input2 = rescaled_inputs
@@ -117,7 +123,7 @@ def define_node(
117123
validate_num_inputs(self.target, inputs, 2)
118124
validate_same_dtype(self.target, [*inputs, output], ts)
119125

120-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
126+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
121127
# Call the inherited define_node for handling integers
122128
super().define_node(node, tosa_graph, inputs, output)
123129
else:

0 commit comments

Comments
 (0)