Skip to content

Commit 03a87cc

Browse files
committed
Update on "Reuse GELU implementation from PyTorch core"
kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch. Note that, because we will pick up Sleef internally and ignore it externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS. Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break. Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/) [ghstack-poisoned]
2 parents 0151665 + 917ea4c commit 03a87cc

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

build/Utils.cmake

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,16 @@ function(resolve_python_executable)
321321
)
322322
endif()
323323
endfunction()
324+
325+
# find_package(Torch CONFIG REQUIRED) replacement for targets that
# have a header-only Torch dependency. Because find_package sets
# variables in the caller's scope, this is a macro rather than a
# function: that preserves those variables for the caller without
# maintaining our own list of them.
#
# Takes no arguments. Example:
#   find_package_torch_headers()
macro(find_package_torch_headers)
  # We cannot simply pass CMAKE_FIND_ROOT_PATH_BOTH to find_package,
  # because that does not propagate into TorchConfig.cmake. Instead,
  # override CMAKE_FIND_ROOT_PATH_MODE_PACKAGE for the duration of the
  # call and put it back afterwards.
  #
  # Record whether the mode variable was defined at all, so that it can
  # be restored to exactly its previous state (value or undefined).
  if(DEFINED CMAKE_FIND_ROOT_PATH_MODE_PACKAGE)
    set(_find_torch_headers_saved_mode "${CMAKE_FIND_ROOT_PATH_MODE_PACKAGE}")
  else()
    unset(_find_torch_headers_saved_mode)
  endif()

  set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE BOTH)
  find_package(Torch CONFIG REQUIRED)

  if(DEFINED _find_torch_headers_saved_mode)
    set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE "${_find_torch_headers_saved_mode}")
  else()
    unset(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE)
  endif()

  # A macro body runs in the caller's scope, so remove the scratch
  # variable rather than leaving it behind for the caller.
  unset(_find_torch_headers_saved_mode)
endmacro()

kernels/optimized/CMakeLists.txt

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,7 @@ message("Generated files ${gen_command_sources}")
6363

6464
list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
6565
add_library(optimized_kernels ${_optimized_kernels__srcs})
66-
# We require Torch headers, which setup.py puts in CMAKE_PREFIX_PATH
67-
# for us. Toolchains that we might be using for cross-compiling could
68-
# set CMAKE_FIND_ROOT_PATH, which prevents find_package from finding
69-
# headers not rooted under CMAKE_FIND_ROOT_PATH. This is reasonable
70-
# for binary dependencies because they probably aren't built for the
71-
# target platform, but for our header-only use case, we should just
72-
# ignore CMAKE_FIND_ROOT_PATH.
73-
find_package(Torch CONFIG REQUIRED NO_CMAKE_FIND_ROOT_PATH)
66+
find_package_torch_headers()
7467
target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
7568
target_link_libraries(
7669
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool

0 commit comments

Comments
 (0)