Skip to content

Commit be590d3

Browse files
committed
Update on "[Executorch] Handle broadcast semantics for last dim"
This diff add support to handle element wise mul op when broadcast is across last dim Differential Revision: [D64156863](https://our.internmc.facebook.com/intern/diff/D64156863/) [ghstack-poisoned]
2 parents 6e928d8 + 5a9981b commit be590d3

File tree

4 files changed

+35
-60
lines changed

4 files changed

+35
-60
lines changed

extension/llm/custom_ops/targets.bzl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load(
3+
"@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
4+
"get_compiler_optimization_flags",
5+
)
6+
27

38
def define_common_targets():
49
"""Defines targets that should be shared between fbcode and xplat.
@@ -34,21 +39,7 @@ def define_common_targets():
3439
"//executorch/kernels/portable/cpu/util:reduce_util",
3540
"//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
3641
],
37-
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + select({
38-
"DEFAULT": [],
39-
"ovr_config//os:android-arm64": [
40-
"-O2",
41-
] if not runtime.is_oss else [],
42-
"ovr_config//os:iphoneos": [
43-
"-O2",
44-
] if not runtime.is_oss else [],
45-
"ovr_config//os:macos-arm64": [
46-
"-O2",
47-
] if not runtime.is_oss else [],
48-
"ovr_config//os:macos-x86_64": [
49-
"-O2",
50-
] if not runtime.is_oss else [],
51-
}),
42+
compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
5243
visibility = [
5344
"//executorch/...",
5445
"//executorch/extension/llm/custom_ops/...",

kernels/optimized/lib_defs.bzl

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ load("@fbsource//tools/build_defs:default_platform_defs.bzl", "DEVSERVER_PLATFOR
22
load("@fbsource//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
33
load("@fbsource//xplat/executorch/backends/xnnpack/third-party:third_party_libs.bzl", "third_party_dep")
44
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
5+
load(
6+
"@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
7+
"get_compiler_optimization_flags",
8+
)
59

610
# Because vec exists as a collection of header files, compile and preprocessor
711
# flags applied to the vec target do not have any effect, since no compilation
@@ -121,21 +125,7 @@ def define_libs():
121125
exported_headers = native.glob([
122126
"blas/**/*.h",
123127
]),
124-
compiler_flags = select({
125-
"DEFAULT": [],
126-
"ovr_config//os:android-arm64": [
127-
"-O2",
128-
] if not runtime.is_oss else [],
129-
"ovr_config//os:iphoneos": [
130-
"-O2",
131-
] if not runtime.is_oss else [],
132-
"ovr_config//os:macos-arm64": [
133-
"-O2",
134-
] if not runtime.is_oss else [],
135-
"ovr_config//os:macos-x86_64": [
136-
"-O2",
137-
] if not runtime.is_oss else [],
138-
}),
128+
compiler_flags = get_compiler_optimization_flags(),
139129
header_namespace = "executorch/kernels/optimized",
140130
visibility = [
141131
"//executorch/...",

kernels/optimized/op_registration_util.bzl

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ load(
44
"@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
55
"get_vec_android_preprocessor_flags",
66
)
7+
load(
8+
"@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
9+
"get_compiler_optimization_flags",
10+
)
711

812
def op_target(name, deps = []):
913
"""Registers an optimized implementation for an operator overload group.
@@ -87,21 +91,7 @@ def define_op_library(name, deps):
8791
],
8892
# kernels often have helpers with no prototypes just disabling the warning here as the headers
8993
# are codegend and linked in later
90-
compiler_flags = ["-Wno-missing-prototypes"] + select({
91-
"DEFAULT": [],
92-
"ovr_config//os:android": [
93-
"-O2",
94-
] if not runtime.is_oss else [],
95-
"ovr_config//os:iphoneos": [
96-
"-O2",
97-
] if not runtime.is_oss else [],
98-
"ovr_config//os:macos-arm64": [
99-
"-O2",
100-
] if not runtime.is_oss else [],
101-
"ovr_config//os:macos-x86_64": [
102-
"-O2",
103-
] if not runtime.is_oss else [],
104-
}),
94+
compiler_flags = ["-Wno-missing-prototypes"] + get_compiler_optimization_flags(),
10595
deps = [
10696
"//executorch/runtime/kernel:kernel_includes",
10797
] + augmented_deps,

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "is_xplat", "runtime")
22
load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
33

4+
def get_compiler_optimization_flags():
5+
# various ovr_configs are not available in oss
6+
if not runtime.is_oss:
7+
compiler_flags = select({
8+
"DEFAULT": [],
9+
"ovr_config//os:android-arm64": [
10+
"-O2",
11+
],
12+
"ovr_config//os:iphoneos": [
13+
"-O2",
14+
],
15+
"ovr_config//os:macos-arm64": [
16+
"-O2",
17+
],
18+
})
19+
return compiler_flags
20+
return []
21+
422
def op_target(name, deps = [], android_deps = [], _allow_third_party_deps = False, _aten_mode_deps = []):
523
"""Registers an implementation of an operator overload group.
624
@@ -132,21 +150,7 @@ def define_op_library(name, deps, android_deps, aten_target, _allow_third_party_
132150
# library, and it blocks users like unit tests to use kernel
133151
# implementation directly. So we enable this for xplat only.
134152
["-fvisibility=hidden"] if is_xplat() else []
135-
) + select({
136-
"DEFAULT": [],
137-
"ovr_config//os:android-arm64": [
138-
"-O2",
139-
] if not runtime.is_oss else [],
140-
"ovr_config//os:iphoneos": [
141-
"-O2",
142-
] if not runtime.is_oss else [],
143-
"ovr_config//os:macos-arm64": [
144-
"-O2",
145-
] if not runtime.is_oss else [],
146-
"ovr_config//os:macos-x86_64": [
147-
"-O2",
148-
] if not runtime.is_oss else [],
149-
}),
153+
) + get_compiler_optimization_flags(),
150154
deps = [
151155
"//executorch/runtime/kernel:kernel_includes" + aten_suffix,
152156
] + deps,

0 commit comments

Comments
 (0)