Skip to content

Commit 243cede

Browse files
committed
prototype e2e works on latest ET
1 parent fe438f9 commit 243cede

File tree

6 files changed

+35
-13
lines changed

6 files changed

+35
-13
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
# https://github.com/google/XNNPACK/commit/c690daa67f883e1b627aadf7684c06797e9a0684
5050
cmake_minimum_required(VERSION 3.29)
5151
project(executorch)
52+
# project(executorch LANGUAGES CXX CUDA)
53+
5254

5355
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
5456
include(${PROJECT_SOURCE_DIR}/tools/cmake/Utils.cmake)

backends/aoti/CMakeLists.txt

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,36 @@ if(NOT EXECUTORCH_ROOT)
2121
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
2222
endif()
2323

24-
include(${EXECUTORCH_ROOT}/build/Utils.cmake)
24+
# include(${EXECUTORCH_ROOT}/build/Utils.cmake)
2525

26-
find_package(CUDA)
27-
28-
set(_common_include_directories ${EXECUTORCH_ROOT}/..)
26+
find_package(CUDAToolkit REQUIRED)
2927

3028
set(_aoti_sources runtime/AotiBackend.cpp)
31-
3229
add_library(aoti_backend STATIC ${_aoti_sources})
3330
target_include_directories(
34-
aoti_backend PUBLIC ${_common_include_directories} ${CUDA_INCLUDE_DIRS}
31+
aoti_backend
32+
PUBLIC
33+
${CUDAToolkit_INCLUDE_DIRS}
34+
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}>
35+
$<INSTALL_INTERFACE:include>
3536
)
36-
3737
target_compile_options(aoti_backend PUBLIC -fexceptions -frtti -fPIC)
38-
target_link_libraries(aoti_backend PUBLIC extension_tensor ${CUDA_LIBRARIES})
38+
# Ensure symbols are exported properly
39+
target_link_options(aoti_backend PUBLIC -Wl,--export-dynamic)
40+
# Link against CUDA::cudart (the CUDA runtime library)
41+
target_link_libraries(
42+
aoti_backend
43+
PUBLIC
44+
extension_tensor
45+
CUDA::cudart
46+
${CMAKE_DL_LIBS}
47+
)
48+
# If you need other CUDA libraries, link them similarly:
49+
# target_link_libraries(aoti_backend PUBLIC CUDA::cublas CUDA::cufft ...)
50+
# If you have a custom function, keep it
3951
executorch_target_link_options_shared_lib(aoti_backend)
40-
4152
install(
4253
TARGETS aoti_backend
4354
EXPORT ExecuTorchTargets
4455
DESTINATION lib
45-
INCLUDES
46-
DESTINATION ${_common_include_directories}
4756
)

backends/aoti/aoti_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def preprocess(
3434
so_path = torch._inductor.aot_compile(graph_module, args, kwargs, options={}) # type: ignore[arg-type]
3535
print(so_path)
3636
check_call(
37-
f"patchelf --remove-needed libtorch.so --remove-needed libtorch_cuda.so --remove-needed libc10_cuda.so --remove-needed libtorch_cpu.so --add-needed libcudart.so {so_path}",
37+
f"patchelf --remove-needed libtorch.so --remove-needed libc10.so --remove-needed libtorch_cuda.so --remove-needed libc10_cuda.so --remove-needed libtorch_cpu.so --add-needed libcudart.so {so_path}",
3838
shell=True,
3939
)
4040

backends/aoti/runtime/AotiBackend.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/backends/arm/runtime/VelaBinStream.h>
109
#include <executorch/extension/tensor/tensor.h>
1110
#include <executorch/runtime/backend/interface.h>
1211
#include <executorch/runtime/core/error.h>

export_and_run_aoti.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
./install_executorch.sh
2+
python export_add.py
3+
./install_executorch.sh --clean
4+
mkdir -p cmake-out
5+
cd cmake-out
6+
cmake -DEXECUTORCH_BUILD_AOTI=ON \
7+
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
8+
..
9+
cd ..
10+
cmake --build cmake-out -j9
11+
./cmake-out/executor_runner --model_path add.pte

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ zstd # Imported by resolve_buck.py.
1010
certifi # Imported by resolve_buck.py.
1111
lintrunner==0.12.7
1212
lintrunner-adapters==0.12.6
13+
patchelf

0 commit comments

Comments
 (0)