Skip to content

Commit 79defe2

Browse files
author
Github Executorch
committed
Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/) [ghstack-poisoned]
2 parents 157df30 + d32c542 commit 79defe2

File tree

5 files changed

+95
-29
lines changed

5 files changed

+95
-29
lines changed

.ci/scripts/test_llava.sh

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ fi
3030
NPROC=8
3131
if hash nproc &> /dev/null; then NPROC=$(nproc); fi
3232

33+
python_lib=$($PYTHON_EXECUTABLE -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
34+
CMAKE_PREFIX_PATH="$(python3 -c 'import torch as _; print(_.__path__[0])')"
3335
EXECUTORCH_COMMON_CMAKE_ARGS=" \
3436
-DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \
35-
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
37+
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
3638
-DEXECUTORCH_ENABLE_LOGGING=ON \
3739
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
3840
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
@@ -46,6 +48,7 @@ EXECUTORCH_COMMON_CMAKE_ARGS=" \
4648
cmake_install_executorch_libraries() {
4749
cmake \
4850
${EXECUTORCH_COMMON_CMAKE_ARGS} \
51+
"-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" \
4952
-B${BUILD_DIR} .
5053

5154
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
@@ -56,6 +59,7 @@ cmake_install_executorch_libraries_for_android() {
5659
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
5760
-DANDROID_ABI=arm64-v8a \
5861
${EXECUTORCH_COMMON_CMAKE_ARGS} \
62+
"-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" \
5963
-B${BUILD_DIR} .
6064

6165
cmake --build ${BUILD_DIR} -j${NPROC} --target install --config ${CMAKE_BUILD_TYPE}
@@ -76,7 +80,7 @@ cmake_build_llava_runner() {
7680

7781
cmake \
7882
${LLAVA_COMMON_CMAKE_ARGS} \
79-
-DCMAKE_PREFIX_PATH="$python_lib" \
83+
-DCMAKE_PREFIX_PATH="$python_lib;${CMAKE_PREFIX_PATH}" \
8084
-B${BUILD_DIR}/${dir} \
8185
${dir}
8286

@@ -92,7 +96,7 @@ cmake_build_llava_runner_for_android() {
9296
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
9397
-DANDROID_ABI=arm64-v8a \
9498
${LLAVA_COMMON_CMAKE_ARGS} \
95-
-DCMAKE_PREFIX_PATH="$python_lib" \
99+
-DCMAKE_PREFIX_PATH="$python_lib;${CMAKE_PREFIX_PATH}" \
96100
-DLLAVA_RUNNER_NO_TORCH_DUMMY_IMAGE=ON \
97101
-B${BUILD_DIR}/${dir} \
98102
${dir}

backends/arm/test/misc/test_tosa_spec.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@
2020
"TOSA-0.80+MI+8k",
2121
"TOSA-0.80+BI+u55",
2222
]
23-
test_valid_1_00_strings = [
24-
"TOSA-1.00.0+INT+FP+fft",
25-
"TOSA-1.00.0+FP+bf16+fft",
26-
"TOSA-1.00.0+INT+int4+cf",
27-
"TOSA-1.00.0+FP+cf+bf16+8k",
28-
"TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
29-
"TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
23+
test_valid_1_0_strings = [
24+
"TOSA-1.0.0+INT+FP+fft",
25+
"TOSA-1.0.0+FP+bf16+fft",
26+
"TOSA-1.0.0+INT+int4+cf",
27+
"TOSA-1.0.0+FP+cf+bf16+8k",
28+
"TOSA-1.0.0+FP+INT+bf16+fft+int4+cf",
29+
"TOSA-1.0.0+FP+INT+fft+int4+cf+8k",
30+
"TOSA-1.0+INT+FP+fft",
31+
"TOSA-1.0+FP+bf16+fft",
32+
"TOSA-1.0+INT+int4+cf",
33+
"TOSA-1.0+FP+cf+bf16+8k",
34+
"TOSA-1.0+FP+INT+bf16+fft+int4+cf",
35+
"TOSA-1.0+FP+INT+fft+int4+cf+8k",
3036
]
3137

32-
test_valid_1_00_extensions = {
38+
test_valid_1_0_extensions = {
3339
"INT": ["int16", "int4", "var", "cf"],
3440
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
3541
}
@@ -40,19 +46,19 @@
4046
"TOSA-0.80+8k",
4147
"TOSA-0.80+BI+MI",
4248
"TOSA-0.80+BI+U55",
43-
"TOSA-1.00.0+fft",
44-
"TOSA-1.00.0+fp+bf16+fft",
45-
"TOSA-1.00.0+INT+INT4+cf",
46-
"TOSA-1.00.0+BI",
47-
"TOSA-1.00.0+FP+FP+INT",
48-
"TOSA-1.00.0+FP+CF+bf16",
49-
"TOSA-1.00.0+BF16+fft+int4+cf+INT",
49+
"TOSA-1.0.0+fft",
50+
"TOSA-1.0.0+fp+bf16+fft",
51+
"TOSA-1.0.0+INT+INT4+cf",
52+
"TOSA-1.0.0+BI",
53+
"TOSA-1.0.0+FP+FP+INT",
54+
"TOSA-1.0.0+FP+CF+bf16",
55+
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
5056
]
5157

5258
test_compile_specs = [
5359
([CompileSpec("tosa_version", "TOSA-0.80+BI".encode())],),
5460
([CompileSpec("tosa_version", "TOSA-0.80+BI+u55".encode())],),
55-
([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
61+
([CompileSpec("tosa_version", "TOSA-1.0.0+INT".encode())],),
5662
]
5763

5864
test_compile_specs_no_version = [
@@ -70,8 +76,8 @@ def test_version_string_0_80(self, version_string: str):
7076
assert isinstance(tosa_spec, Tosa_0_80)
7177
assert tosa_spec.profile in ["BI", "MI"]
7278

73-
@parameterized.expand(test_valid_1_00_strings) # type: ignore[misc]
74-
def test_version_string_1_00(self, version_string: str):
79+
@parameterized.expand(test_valid_1_0_strings) # type: ignore[misc]
80+
def test_version_string_1_0(self, version_string: str):
7581
tosa_spec = TosaSpecification.create_from_string(version_string)
7682
assert isinstance(tosa_spec, Tosa_1_00)
7783
assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count(
@@ -80,7 +86,7 @@ def test_version_string_1_00(self, version_string: str):
8086

8187
for profile in tosa_spec.profiles:
8288
assert [
83-
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
89+
e in test_valid_1_0_extensions[profile] for e in tosa_spec.extensions
8490
]
8591

8692
@parameterized.expand(test_invalid_strings) # type: ignore[misc]
@@ -103,3 +109,15 @@ def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec])
103109
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
104110

105111
assert tosa_spec is None
112+
113+
@parameterized.expand(test_valid_0_80_strings)
114+
def test_correct_string_representation_0_80(self, version_string: str):
115+
tosa_spec = TosaSpecification.create_from_string(version_string)
116+
assert isinstance(tosa_spec, Tosa_0_80)
117+
assert f"{tosa_spec}" == version_string
118+
119+
@parameterized.expand(test_valid_1_0_strings)
120+
def test_correct_string_representation_1_0(self, version_string: str):
121+
tosa_spec = TosaSpecification.create_from_string(version_string)
122+
assert isinstance(tosa_spec, Tosa_1_00)
123+
assert f"{tosa_spec}" == version_string

backends/arm/tosa_specification.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -14,7 +14,9 @@
1414
import re
1515
from typing import List
1616

17-
from executorch.exir.backend.compile_spec_schema import CompileSpec
17+
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-untyped]
18+
CompileSpec,
19+
)
1820
from packaging.version import Version
1921

2022

@@ -131,7 +133,7 @@ def __init__(self, version: Version, extras: List[str]):
131133
def __repr__(self):
132134
extensions = ""
133135
if self.level_8k:
134-
extensions += "+8K"
136+
extensions += "+8k"
135137
if self.is_U55_subset:
136138
extensions += "+u55"
137139
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
@@ -207,7 +209,10 @@ def _get_extensions_string(self) -> str:
207209
return "".join(["+" + e for e in self.extensions])
208210

209211
def __repr__(self):
210-
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
212+
extensions = self._get_extensions_string()
213+
if self.level_8k:
214+
extensions += "+8k"
215+
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"
211216

212217
def __hash__(self) -> int:
213218
return hash(str(self.version) + self._get_profiles_string())

codegen/tools/gen_all_oplist.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,31 @@ def resolve_model_file_path_to_buck_target(model_file_path: str) -> str:
4747
return real_path
4848

4949

50+
def _raise_if_check_prim_ops_fail(options):
51+
52+
# Error out if we have more than one targets registering prim ops.
53+
if options.DEBUG_ONLY_check_prim_ops and len(options.DEBUG_ONLY_check_prim_ops) > 1:
54+
assert (
55+
options.DEBUG_ONLY_check_prim_ops[0] == "@"
56+
), "DEBUG_ONLY_check_prim_ops is not a valid file path, or it doesn't start with '@'. This is likely a BUCK issue."
57+
58+
prim_ops_targets_file = options.DEBUG_ONLY_check_prim_ops[1:]
59+
with open(prim_ops_targets_file, "r") as file:
60+
prim_ops_targets = file.read().split()
61+
if len(prim_ops_targets) > 1:
62+
# Yellow bold: \033[33;1m
63+
# Red bold: \033[31;1m
64+
# Green bold: \033[32;1m
65+
error = (
66+
"It seems this target is depending on more than 1 `prim_ops_registry` targets: "
67+
+ f'\033[33;1m\n{", ".join(prim_ops_targets)}\033[0m. \nThis will likely cause errors such as: '
68+
+ "\n \033[31;1mRe-registering aten::sym_size.int...\033[0m"
69+
+ "\nTo find out the dependency chain, run the following command: "
70+
+ f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {prim_ops_targets[0]})"\033[0m'
71+
)
72+
raise Exception(error)
73+
74+
5075
def main(argv: List[Any]) -> None:
5176
"""This binary generates 3 files:
5277
@@ -95,8 +120,18 @@ def main(argv: List[Any]) -> None:
95120
default=False,
96121
required=False,
97122
)
123+
parser.add_argument(
124+
"--DEBUG-ONLY-check-prim-ops",
125+
"--DEBUG_ONLY_check_prim_ops",
126+
help=(
127+
"Useful argument to take BUCK targets that registers prim ops and error out if we have more than 1."
128+
),
129+
required=False,
130+
)
98131
options = parser.parse_args(argv)
99132

133+
_raise_if_check_prim_ops_fail(options)
134+
100135
# Check if the build has any dependency on any selective build target. If we have a target, BUCK shold give us either:
101136
# 1. a yaml file containing selected ops (could be empty), or
102137
# 2. a non-empty list of yaml files in the `model_file_list_path` or
@@ -153,14 +188,17 @@ def main(argv: List[Any]) -> None:
153188
debug_info_2 = ",".join(
154189
model_dict["operators"][op_name]["debug_info"]
155190
)
156-
error = f"Operator {op_name} is used in 2 models: {debug_info_1} and {debug_info_2}"
191+
# Yellow bold: \033[33;1m
192+
# Red bold: \033[31;1m
193+
# Green bold: \033[32;1m
194+
error = f"\033[31;1mOperator {op_name} is used in 2 models: \033[33;1m{debug_info_1} and {debug_info_2}\033[0m"
157195
if "//" not in debug_info_1 and "//" not in debug_info_2:
158196
error += "\nWe can't determine what BUCK targets these model files belong to."
159197
tail = "."
160198
else:
161199
error += "\nPlease run the following commands to find out where is the BUCK target being added as a dependency to your target:\n"
162-
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_1})"'
163-
error += f'\n buck2 cquery <mode> "allpaths(<target>, {debug_info_2})"'
200+
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_1})"\033[0m'
201+
error += f'\n \033[32;1mbuck2 cquery <mode> "allpaths(<target>, {debug_info_2})"\033[0m'
164202
tail = "as well as results from BUCK commands listed above."
165203

166204
error += (

shim/xplat/executorch/codegen/codegen.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def executorch_ops_check(
706706
"--model_file_list_path $(@query_outputs \"filter('.*_et_oplist', deps(set({deps})))\") " +
707707
"--allow_include_all_overloads " +
708708
"--check_ops_not_overlapping " +
709+
"--DEBUG_ONLY_check_prim_ops $(@query_targets \"filter('prim_ops_registry(?:_static|_aten)?$', deps(set({deps})))\") " +
709710
"--output_dir $OUT ").format(deps = " ".join(["\'{}\'".format(d) for d in deps])),
710711
define_static_target = False,
711712
platforms = kwargs.pop("platforms", get_default_executorch_platforms()),

0 commit comments

Comments
 (0)