Skip to content

Commit 81e1502

Browse files
committed
[Executorch] enable sleef consistently
Earlier only android platofrms had support for sleef Differential Revision: [D64571782](https://our.internmc.facebook.com/intern/diff/D64571782/) [ghstack-poisoned]
1 parent 37c079c commit 81e1502

File tree

3 files changed

+41
-38
lines changed

3 files changed

+41
-38
lines changed

extension/llm/custom_ops/targets.bzl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load(
3+
"@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
4+
"get_vec_preprocessor_flags",
5+
"get_vec_deps",
6+
)
27

38
def define_common_targets():
49
"""Defines targets that should be shared between fbcode and xplat.
@@ -21,6 +26,7 @@ def define_common_targets():
2126
"op_sdpa.h",
2227
"op_update_quantized_cache.h",
2328
],
29+
preprocessor_flags = get_vec_preprocessor_flags(),
2430
exported_deps = [
2531
"//executorch/runtime/kernel:kernel_includes",
2632
"//executorch/kernels/portable/cpu:scalar_utils",
@@ -33,7 +39,7 @@ def define_common_targets():
3339
deps = [
3440
"//executorch/kernels/portable/cpu/util:reduce_util",
3541
"//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
36-
],
42+
] + get_vec_deps(),
3743
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors", "-O2"],
3844
visibility = [
3945
"//executorch/...",

kernels/optimized/lib_defs.bzl

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,37 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
1111
# functions in order to declare the required compiler flags needed in order to
1212
# access CPU vector intrinsics.
1313

14-
def get_vec_android_preprocessor_flags():
15-
preprocessor_flags = [
16-
(
17-
"^android-arm64.*$",
18-
[
19-
"-DET_BUILD_ARM_VEC256_WITH_SLEEF",
20-
],
21-
),
22-
]
14+
def get_vec_preprocessor_flags():
15+
preprocessor_flags = select({
16+
"ovr_config//os:iphoneos": [
17+
"-DET_BUILD_ARM_VEC256_WITH_SLEEF",
18+
] if not runtime.is_oss else [],
19+
"ovr_config//os:macos-arm64": [
20+
"-DET_BUILD_ARM_VEC256_WITH_SLEEF",
21+
] if not runtime.is_oss else [],
22+
"ovr_config//os:android-arm64": [
23+
"-DET_BUILD_ARM_VEC256_WITH_SLEEF",
24+
] if not runtime.is_oss else [],
25+
"DEFAULT": [],
26+
})
27+
return preprocessor_flags
28+
29+
def get_vec_deps():
30+
preprocessor_flags = select({
31+
"ovr_config//os:linux-x86_64": [
32+
"fbsource//third-party/sleef:sleef",
33+
] if not runtime.is_oss else [],
34+
"ovr_config//os:iphoneos": [
35+
"fbsource//third-party/sleef:sleef_arm",
36+
] if not runtime.is_oss else [],
37+
"ovr_config//os:macos-arm64": [
38+
"fbsource//third-party/sleef:sleef_arm",
39+
] if not runtime.is_oss else [],
40+
"ovr_config//os:android-arm64": [
41+
"fbsource//third-party/sleef:sleef_arm",
42+
] if not runtime.is_oss else [],
43+
"DEFAULT": [],
44+
})
2345
return preprocessor_flags
2446

2547
def get_vec_cxx_preprocessor_flags():
@@ -56,32 +78,7 @@ def define_libs():
5678
"//executorch/...",
5779
"@EXECUTORCH_CLIENTS",
5880
],
59-
cxx_platform_deps = select({
60-
"DEFAULT": [
61-
(
62-
DEVSERVER_PLATFORM_REGEX,
63-
[
64-
"fbsource//third-party/sleef:sleef",
65-
],
66-
),
67-
],
68-
"ovr_config//cpu:arm64": [
69-
(
70-
DEVSERVER_PLATFORM_REGEX,
71-
[
72-
"fbsource//third-party/sleef:sleef_arm",
73-
],
74-
),
75-
],
76-
}),
77-
fbandroid_platform_deps = [
78-
(
79-
"^android-arm64.*$",
80-
[
81-
"fbsource//third-party/sleef:sleef_arm",
82-
],
83-
),
84-
],
81+
deps = get_vec_deps(),
8582
)
8683

8784
runtime.cxx_library(

kernels/optimized/op_registration_util.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
22
load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
33
load(
44
"@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
5-
"get_vec_android_preprocessor_flags",
5+
"get_vec_preprocessor_flags",
66
)
77

88
def op_target(name, deps = []):
@@ -91,7 +91,7 @@ def define_op_library(name, deps):
9191
deps = [
9292
"//executorch/runtime/kernel:kernel_includes",
9393
] + augmented_deps,
94-
fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(),
94+
preprocessor_flags = get_vec_preprocessor_flags(),
9595
# sleef needs to be added as a direct dependency of the operator target when building for Android,
9696
# or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of
9797
# dependencies are not transitive

0 commit comments

Comments
 (0)