Skip to content

Commit 23ab521

Browse files
committed
Update base for Update on "update cuda delegate resource free pipeline for safety and segfault-free"
This diff survives `clear_all_tensors()` function and enable it during backend destroy stage. Furthermore, we defer the container handle deletion to OS to avoid potential segfault if there's more than one .so files. Differential Revision: [D84135792](https://our.internmc.facebook.com/intern/diff/D84135792/) [ghstack-poisoned]
1 parent 77b7a23 commit 23ab521

File tree

5 files changed

+33
-16
lines changed

5 files changed

+33
-16
lines changed

backends/aoti/CMakeLists.txt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,7 @@ target_compile_options(aoti_common PUBLIC -fexceptions -frtti -fPIC)
4141
target_link_options(aoti_common PUBLIC -Wl,--export-dynamic)
4242

4343
# Link against PyTorch libraries and standard libraries
44-
target_link_libraries(
45-
aoti_common
46-
PUBLIC extension_tensor ${CMAKE_DL_LIBS}
47-
# Link PyTorch libraries for AOTI functions
48-
${TORCH_LIBRARIES}
49-
)
44+
target_link_libraries(aoti_common PUBLIC extension_tensor ${CMAKE_DL_LIBS})
5045
executorch_target_link_options_shared_lib(aoti_common)
5146

5247
install(

backends/aoti/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_common_targets():
5151
link_whole = True,
5252
supports_python_dlopen = True,
5353
visibility = ["@EXECUTORCH_CLIENTS"],
54-
deps = [
54+
exported_deps = [
5555
":common_shims",
5656
":model_container",
5757
],

backends/cuda/runtime/TARGETS

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,25 @@ runtime.cxx_library(
3434
("cuda", None, "cuda-lazy"),
3535
],
3636
)
37+
38+
runtime.cxx_library(
39+
name = "cuda_backend",
40+
srcs = [
41+
"cuda_backend.cpp",
42+
],
43+
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
44+
link_whole = True,
45+
supports_python_dlopen = True,
46+
# Constructor needed for backend registration.
47+
compiler_flags = ["-Wno-global-constructors"],
48+
visibility = ["@EXECUTORCH_CLIENTS"],
49+
deps = [
50+
":runtime_shims",
51+
"//executorch/backends/aoti:aoti_common",
52+
"//executorch/runtime/backend:interface",
53+
"//executorch/runtime/core/exec_aten/util:tensor_util",
54+
],
55+
external_deps = [
56+
("cuda", None, "cuda-lazy"),
57+
],
58+
)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,10 @@ class ET_EXPERIMENTAL CudaBackend final
9797
auto aoti_cuda_buffer = named_data_map->get_data(so_blob_key.c_str());
9898
ET_CHECK_OR_RETURN_ERROR(
9999
aoti_cuda_buffer.ok(),
100-
aoti_cuda_buffer.error(),
100+
Internal,
101101
"Failed to get data for key %s: 0x%x",
102102
so_blob_key.c_str(),
103-
aoti_cuda_buffer.error());
103+
static_cast<uint32_t>(aoti_cuda_buffer.error()));
104104

105105
// Generate dynamic temporary file path
106106
filesystem::path temp_dir = filesystem::temp_directory_path();
@@ -311,7 +311,7 @@ class ET_EXPERIMENTAL CudaBackend final
311311
if (handle->cuda_stream != nullptr) {
312312
cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
313313
cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
314-
ET_CHECK_OR_LOG(
314+
ET_CHECK_OR_LOG_ERROR(
315315
stream_err == cudaSuccess,
316316
"Failed to destroy CUDA stream: %s",
317317
cudaGetErrorString(stream_err));

runtime/platform/log.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,11 @@ using ::executorch::runtime::LogLevel;
188188
* @param[in] _condition The condition to check.
189189
* @param[in] _format Log message format string.
190190
*/
191-
#define ET_CHECK_OR_LOG(_condition, _format, ...) \
192-
do { \
193-
if (!(_condition)) { \
194-
ET_LOG(Error, _format, ##__VA_ARGS__); \
195-
} \
191+
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) \
192+
do { \
193+
if (!(_condition)) { \
194+
ET_LOG(Error, _format, ##__VA_ARGS__); \
195+
} \
196196
} while (0)
197197

198198
#else // ET_LOG_ENABLED
@@ -211,6 +211,6 @@ using ::executorch::runtime::LogLevel;
211211
* @param[in] _condition The condition to check.
212212
* @param[in] _format Log message format string.
213213
*/
214-
#define ET_CHECK_OR_LOG(_condition, _format, ...) ((void)0)
214+
#define ET_CHECK_OR_LOG_ERROR(_condition, _format, ...) ((void)0)
215215

216216
#endif // ET_LOG_ENABLED

0 commit comments

Comments
 (0)