
Commit f7f97f7

introduce shim layers for cudaguard and cudastreamguard (pytorch#14925)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: pytorch#14902 by @Gasoonjia
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/47/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/47/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/47/orig
Differential Revision: [D84126634](https://our.internmc.facebook.com/intern/diff/D84126634/)
@diff-train-skip-merge

Co-authored-by: gasoonjia <[email protected]>
Co-authored-by: Gasoonjia <[email protected]>
1 parent 29b4db8 commit f7f97f7

20 files changed: +591 / -193 lines
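The new runtime/shims/cuda_guard.h and runtime/shims/cuda_guard.cpp files are wired into the build files below, but their contents are not part of this excerpt. As a rough illustration only (all names and signatures here are assumptions, not the actual ExecuTorch API), a C-style shim layer for a CUDA device guard and stream guard typically exposes RAII-style guard objects through opaque handles, so AOTInductor-generated code can use them without including CUDA or ATen headers:

// Hypothetical sketch of a guard shim interface; the real declarations live
// in backends/cuda/runtime/shims/cuda_guard.h, which is not shown in this diff.
#include <cstdint>

using AOTITorchError = int32_t;              // assumed error/status type
struct CUDAGuardOpaque;                      // opaque handle types
using CUDAGuardHandle = CUDAGuardOpaque*;
struct CUDAStreamGuardOpaque;
using CUDAStreamGuardHandle = CUDAStreamGuardOpaque*;

extern "C" {
// Switch the current CUDA device to device_index until the guard is deleted.
AOTITorchError aoti_torch_create_cuda_guard(
    int32_t device_index,
    CUDAGuardHandle* ret_guard);
AOTITorchError aoti_torch_delete_cuda_guard(CUDAGuardHandle guard);

// Make `stream` (a cudaStream_t passed as void*) current on device_index
// until the guard is deleted.
AOTITorchError aoti_torch_create_cuda_stream_guard(
    void* stream,
    int32_t device_index,
    CUDAStreamGuardHandle* ret_guard);
AOTITorchError aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);
}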

backends/aoti/CMakeLists.txt

Lines changed: 2 additions & 7 deletions
@@ -40,13 +40,8 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
 # Ensure symbols are exported properly
 target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
 
-# Link against PyTorch libraries and standard libraries
-target_link_libraries(
-  aoti_common
-  PUBLIC extension_tensor ${CMAKE_DL_LIBS}
-  # Link PyTorch libraries for AOTI functions
-  ${TORCH_LIBRARIES}
-)
+# Link against ExecuTorch libraries and standard libraries
+target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
 executorch_target_link_options_shared_lib(aoti_common)
 
 install(

backends/aoti/aoti_model_container.h

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,8 @@ struct AOTIDelegateHandle {
   void* so_handle;
   std::string so_path;
   AOTInductorModelContainerHandle container_handle;
+  void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
+                     // dependency
 };
 
 } // namespace aoti
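Keeping the stream as a void* in this header avoids pulling CUDA headers into the common AOTI layer; CUDA-aware translation units cast it back to cudaStream_t before use. A minimal sketch of that cast (the helper below is illustrative and not part of the change):

#include <cuda_runtime.h>

// Illustrative only: recover the type-erased stream inside a translation unit
// that does include the CUDA runtime headers.
inline cudaError_t sync_stored_stream(void* cuda_stream) {
  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);
  return cudaStreamSynchronize(stream);  // e.g. wait for work queued on it
}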

backends/aoti/common_shims.cpp

Lines changed: 8 additions & 1 deletion
@@ -127,11 +127,18 @@ int32_t aoti_torch_layout_strided() {
 }
 
 // Dtype constants - these return the PyTorch dtype codes
-// Currently only float32 is supported, but using robust enum-based approach
 int32_t aoti_torch_dtype_float32() {
   return 6; // PyTorch's float32 dtype code
 }
 
+int32_t aoti_torch_dtype_bfloat16() {
+  return 15; // PyTorch's bfloat16 dtype code
+}
+
+int32_t aoti_torch_dtype_int64() {
+  return 4; // PyTorch's int64 dtype code
+}
+
 // Cleanup functions
 void cleanup_tensor_metadata() {
   internal::tensor_to_sizes.clear();
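The integer codes follow PyTorch's c10::ScalarType numbering (4 = int64, 6 = float32, 15 = bfloat16). As a small illustration of how a caller of these shims might use the codes (this helper is not part of the change):

#include <cstddef>
#include <cstdint>

// Illustrative helper: map the AOTI dtype codes returned above to element
// sizes in bytes; returns 0 for codes this sketch does not cover.
inline size_t element_size_for_dtype_code(int32_t dtype_code) {
  switch (dtype_code) {
    case 4:  return 8;  // int64
    case 6:  return 4;  // float32
    case 15: return 2;  // bfloat16
    default: return 0;
  }
}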

backends/aoti/common_shims.h

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,8 @@ AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim);
 int32_t aoti_torch_device_type_cpu();
 int32_t aoti_torch_layout_strided();
 int32_t aoti_torch_dtype_float32();
+int32_t aoti_torch_dtype_bfloat16();
+int32_t aoti_torch_dtype_int64();
 
 // Autograd mode functions
 int32_t aoti_torch_grad_mode_is_enabled();

backends/aoti/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def define_common_targets():
         link_whole = True,
         supports_python_dlopen = True,
         visibility = ["@EXECUTORCH_CLIENTS"],
-        deps = [
+        exported_deps = [
             ":common_shims",
             ":model_container",
         ],

backends/cuda/CMakeLists.txt

Lines changed: 5 additions & 6 deletions
@@ -35,8 +35,10 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 find_package_torch()
 
 # CUDA-specific AOTI functionality
-set(_aoti_cuda_sources runtime/cuda_backend.cpp runtime/shims/memory.cpp
-    runtime/shims/tensor_attribute.cpp runtime/guard.cpp
+set(_aoti_cuda_sources
+    runtime/cuda_backend.cpp runtime/shims/memory.cpp
+    runtime/shims/tensor_attribute.cpp runtime/guard.cpp
+    runtime/shims/cuda_guard.cpp
 )
 add_library(aoti_cuda STATIC ${_aoti_cuda_sources})
 target_include_directories(
@@ -53,10 +55,7 @@ target_link_options(aoti_cuda PUBLIC -Wl,--export-dynamic)
 
 # Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
 target_link_libraries(
-  aoti_cuda
-  PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
-  # Link PyTorch libraries for AOTI CUDA functions
-  ${TORCH_LIBRARIES}
+  aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
 )
 # If you need other CUDA libraries, link them similarly:
 # target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)

backends/cuda/runtime/TARGETS

Lines changed: 24 additions & 0 deletions
@@ -6,11 +6,13 @@ runtime.cxx_library(
     name = "runtime_shims",
     srcs = [
         "guard.cpp",
+        "shims/cuda_guard.cpp",
         "shims/memory.cpp",
         "shims/tensor_attribute.cpp",
     ],
     headers = [
         "guard.h",
+        "shims/cuda_guard.h",
         "shims/memory.h",
         "shims/tensor_attribute.h",
         "utils.h",
@@ -32,3 +34,25 @@ runtime.cxx_library(
         ("cuda", None, "cuda-lazy"),
     ],
 )
+
+runtime.cxx_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.cpp",
+    ],
+    # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
+    link_whole = True,
+    supports_python_dlopen = True,
+    # Constructor needed for backend registration.
+    compiler_flags = ["-Wno-global-constructors"],
+    visibility = ["@EXECUTORCH_CLIENTS"],
+    deps = [
+        ":runtime_shims",
+        "//executorch/backends/aoti:aoti_common",
+        "//executorch/runtime/backend:interface",
+        "//executorch/runtime/core/exec_aten/util:tensor_util",
+    ],
+    external_deps = [
+        ("cuda", None, "cuda-lazy"),
+    ],
+)
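The compiler_flags and link_whole settings above exist because the backend registers itself from a global constructor in cuda_backend.cpp (hence the dep on //executorch/runtime/backend:interface); if the linker dropped the otherwise-unreferenced object file, that registration would never run. A self-contained sketch of the pattern with a stand-in registry (none of these names are the real ExecuTorch API):

#include <map>
#include <string>

// Stand-in for the real backend interface and registry, for illustration only.
struct FakeBackendInterface {
  virtual ~FakeBackendInterface() = default;
};

static std::map<std::string, FakeBackendInterface*>& backend_registry() {
  static std::map<std::string, FakeBackendInterface*> registry;
  return registry;
}

struct RegisterOnLoad {
  RegisterOnLoad(const char* name, FakeBackendInterface* backend) {
    backend_registry()[name] = backend;  // runs before main()
  }
};

struct CudaBackendSketch final : FakeBackendInterface {};

namespace {
CudaBackendSketch backend_instance;
// This namespace-scope object is the "global constructor": it is why the
// target needs -Wno-global-constructors and link_whole = True, so the
// object file (and its registration side effect) is always kept.
RegisterOnLoad register_cuda_backend{"CudaBackend", &backend_instance};
} // namespace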
