Skip to content

Commit 6c1a9ea

Browse files
committed
Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/) [ghstack-poisoned]
2 parents 0ba7e85 + 64349e1 commit 6c1a9ea

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

runtime/core/portable_type/c10/targets.bzl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def define_common_targets():
6767
# fbsource//third-party/sleef:sleef currently fails to
6868
# link with missing symbols, hence the fbcode-specific dep below.
6969
}),
70-
fbcode_exported_deps = [
70+
fbcode_exported_deps = ([
7171
"//caffe2:aten-headers-cpu",
7272
"//caffe2:generated-config-header",
7373
"//caffe2/c10/core:base_headers",
@@ -76,7 +76,7 @@ def define_common_targets():
7676
"ovr_config//cpu:x86_64": [
7777
"third-party//sleef:sleef",
7878
]
79-
}),
79+
})) if not runtime.is_oss else [],
8080
fbcode_exported_preprocessor_flags = [
8181
# We don't -DCPU_CAPABILITY=AVX2 because that trips
8282
# -Wmacro-redefined, and we only care about getting

shim/xplat/executorch/build/env_interface.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _remove_platform_specific_args(kwargs):
119119
keys = []
120120
for key in kwargs:
121121
if (key.endswith("_platform_preprocessor_flags") or key.endswith("_platform_deps") or
122-
key.startswith("fbobjc") or key.endswith("_platform_compiler_flags")):
122+
key.startswith("fbobjc") or key.endswith("_platform_compiler_flags") or key == "fbcode_exported_preprocessor_flags"):
123123
keys.append(key)
124124
for key in keys:
125125
kwargs.pop(key)

0 commit comments

Comments
 (0)