Skip to content

Commit 2bea318

Browse files
authored
[aoti-et] Store symbols from dlopen into AOTIDelegateHandle (#15172)
This pull request refactors how function pointers for AOTI model container operations are managed and loaded in the CUDA backend. Instead of relying on global/static function pointers, function pointers are now stored per-instance in the `AOTIDelegateHandle` structure. This change enables safe handling of multiple shared libraries within the same process and improves encapsulation and maintainability.

**Refactoring function pointer management:**

* Removed global function pointers for AOTI model container operations from `aoti_model_container.cpp` and `aoti_model_container.h`, and moved them into the `AOTIDelegateHandle` struct as per-instance members. [[1]](diffhunk://#diff-32ff58ae0581446607da6874fa62b366ba18bcff4d621b16987fda78312244a6L1-L39) [[2]](diffhunk://#diff-84caca41e72ad693665c930ab7d0c31e05f64b268f4d7ac37c17869149fad0c7L63-L92) [[3]](diffhunk://#diff-84caca41e72ad693665c930ab7d0c31e05f64b268f4d7ac37c17869149fad0c7R83-R89)

**CUDA backend updates:**

* Updated the CUDA backend (`cuda_backend.cpp`) to load function pointers into each `AOTIDelegateHandle` instance using a new `load_function_pointers_into_handle` method, replacing the previous global symbol registration logic. All calls to model container functions now use the handle's member function pointers. [[1]](diffhunk://#diff-a4b17eccf1aa933837671c5184e02bc815d934a362344bb2b17b789cdfaa5375L31-R35) [[2]](diffhunk://#diff-a4b17eccf1aa933837671c5184e02bc815d934a362344bb2b17b789cdfaa5375L60-R84) [[3]](diffhunk://#diff-a4b17eccf1aa933837671c5184e02bc815d934a362344bb2b17b789cdfaa5375L138-L150) [[4]](diffhunk://#diff-a4b17eccf1aa933837671c5184e02bc815d934a362344bb2b17b789cdfaa5375L168-R194) [[5]](diffhunk://#diff-a4b17eccf1aa933837671c5184e02bc815d934a362344bb2b17b789cdfaa5375L264-R277)

**Build system adjustments:**

* Removed `aoti_model_container.cpp` from the build targets and library sources, as global function pointer definitions are no longer needed. [[1]](diffhunk://#diff-c95a0b47f516c30f4b2e384b88c94c088d1031e6df7af66678a6fc9d3fb1a1a5L29-L31) [[2]](diffhunk://#diff-c3d5933d211acc568c9bdf8e08d0ca99b01e50bca113307fbab4cbc4018fdf55L29-R29)
1 parent 11c0b4f commit 2bea318

File tree

5 files changed

+56
-97
lines changed

5 files changed

+56
-97
lines changed

backends/aoti/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
2626
find_package_torch()
2727

2828
# Common AOTI functionality - combines all AOTI common components
29-
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
29+
set(_aoti_common_sources common_shims.cpp)
3030
add_library(aoti_common STATIC ${_aoti_common_sources})
3131
target_include_directories(
3232
aoti_common
backends/aoti/aoti_model_container.h → backends/aoti/aoti_delegate_handle.h

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -60,36 +60,17 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
6060
AOTInductorStreamHandle stream_handle,
6161
AOTIProxyExecutorHandle proxy_executor_handle);
6262

63-
// Global function pointers (will be loaded dynamically)
64-
extern AOTInductorModelContainerCreateWithDeviceFunc
65-
AOTInductorModelContainerCreateWithDevice;
66-
extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
67-
extern AOTInductorModelContainerGetNumInputsFunc
68-
AOTInductorModelContainerGetNumInputs;
69-
extern AOTInductorModelContainerGetNumOutputsFunc
70-
AOTInductorModelContainerGetNumOutputs;
71-
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;
72-
7363
// Retrieves the name of an input tensor by index from the AOTI model container.
74-
// Needed by Metal backend
7564
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
7665
AOTInductorModelContainerHandle container_handle,
7766
size_t input_idx,
7867
const char** input_name);
7968

8069
// Retrieves the number of constants from the AOTI model container.
81-
// Needed by Metal backend
8270
using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
8371
AOTInductorModelContainerHandle container_handle,
8472
size_t* num_constants);
8573

86-
// Global function pointers (will be loaded dynamically).
87-
// Needed by Metal backend
88-
extern AOTInductorModelContainerGetInputNameFunc
89-
AOTInductorModelContainerGetInputName;
90-
extern AOTInductorModelContainerGetNumConstantsFunc
91-
AOTInductorModelContainerGetNumConstants;
92-
9374
} // extern "C"
9475

9576
// AOTI Delegate Handle structure
@@ -99,6 +80,13 @@ struct AOTIDelegateHandle {
9980
AOTInductorModelContainerHandle container_handle;
10081
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
10182
// dependency
83+
84+
// Function pointers specific to this handle's shared library
85+
AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
86+
AOTInductorModelContainerDeleteFunc delete_container;
87+
AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
88+
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
89+
AOTInductorModelContainerRunFunc run;
10290
};
10391

10492
} // namespace aoti

backends/aoti/aoti_model_container.cpp

Lines changed: 0 additions & 39 deletions
This file was deleted.

backends/aoti/targets.bzl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,9 @@ def define_common_targets():
2525

2626
# AOTI model container functionality
2727
runtime.cxx_library(
28-
name = "model_container",
29-
srcs = [
30-
"aoti_model_container.cpp",
31-
],
28+
name = "delegate_handle",
3229
headers = [
33-
"aoti_model_container.h",
30+
"aoti_delegate_handle.h",
3431
],
3532
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
3633
link_whole = True,
@@ -44,7 +41,7 @@ def define_common_targets():
4441
],
4542
)
4643

47-
# Common AOTI functionality (combining both common_shims and model_container)
44+
# Common AOTI functionality (combining both common_shims and delegate_handle)
4845
runtime.cxx_library(
4946
name = "aoti_common",
5047
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
@@ -53,6 +50,6 @@ def define_common_targets():
5350
visibility = ["@EXECUTORCH_CLIENTS"],
5451
exported_deps = [
5552
":common_shims",
56-
":model_container",
53+
":delegate_handle",
5754
],
5855
)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@
2121
#include <vector>
2222

2323
// Include our shim layer headers
24-
#include <executorch/backends/aoti/aoti_model_container.h>
24+
#include <executorch/backends/aoti/aoti_delegate_handle.h>
2525
#include <executorch/backends/aoti/common_shims.h>
2626
#include <executorch/backends/cuda/runtime/shims/memory.h>
2727
#include <executorch/backends/cuda/runtime/utils.h>
2828

2929
namespace executorch::backends::cuda {
3030

31-
#define LOAD_SYMBOL(name, handle) \
32-
do { \
33-
name = reinterpret_cast<name##Func>(dlsym(handle, #name)); \
34-
ET_CHECK_OR_RETURN_ERROR( \
35-
name != nullptr, AccessFailed, "Failed to load " #name); \
31+
#define LOAD_SYMBOL(handle, member, name, so_handle) \
32+
do { \
33+
handle->member = reinterpret_cast<name##Func>(dlsym(so_handle, #name)); \
34+
ET_CHECK_OR_RETURN_ERROR( \
35+
handle->member != nullptr, AccessFailed, "Failed to load " #name); \
3636
} while (0)
3737

3838
using namespace std;
@@ -57,12 +57,31 @@ using executorch::runtime::etensor::Tensor;
5757
class ET_EXPERIMENTAL CudaBackend final
5858
: public ::executorch::runtime::BackendInterface {
5959
private:
60-
Error register_shared_library_functions(void* so_handle) const {
61-
LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
62-
LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
63-
LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
64-
LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
65-
LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
60+
Error load_function_pointers_into_handle(
61+
void* so_handle,
62+
AOTIDelegateHandle* handle) const {
63+
LOAD_SYMBOL(
64+
handle,
65+
create_with_device,
66+
AOTInductorModelContainerCreateWithDevice,
67+
so_handle);
68+
69+
LOAD_SYMBOL(
70+
handle, delete_container, AOTInductorModelContainerDelete, so_handle);
71+
72+
LOAD_SYMBOL(
73+
handle,
74+
get_num_inputs,
75+
AOTInductorModelContainerGetNumInputs,
76+
so_handle);
77+
78+
LOAD_SYMBOL(
79+
handle,
80+
get_num_outputs,
81+
AOTInductorModelContainerGetNumOutputs,
82+
so_handle);
83+
84+
LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle);
6685

6786
return Error::Ok;
6887
}
@@ -135,19 +154,22 @@ class ET_EXPERIMENTAL CudaBackend final
135154

136155
processed->Free();
137156

138-
// Register all shared library functions
139-
ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle));
157+
// Create handle and load function pointers into it
158+
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
159+
handle->so_handle = so_handle;
160+
handle->so_path = so_path.string();
161+
162+
// Load function pointers specific to this handle's shared library
163+
ET_CHECK_OK_OR_RETURN_ERROR(
164+
load_function_pointers_into_handle(so_handle, handle));
140165

141166
AOTInductorModelContainerHandle container_handle = nullptr;
142167

143-
ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice(
144-
&container_handle, 1, "cuda", nullptr));
168+
ET_CHECK_OK_OR_RETURN_ERROR(
169+
handle->create_with_device(&container_handle, 1, "cuda", nullptr));
145170

146171
ET_LOG(Info, "container_handle = %p", container_handle);
147172

148-
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
149-
handle->so_handle = so_handle;
150-
handle->so_path = so_path.string();
151173
handle->container_handle = container_handle;
152174

153175
// Create a CUDA stream for asynchronous execution
@@ -165,20 +187,11 @@ class ET_EXPERIMENTAL CudaBackend final
165187
Span<EValue*> args) const override {
166188
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
167189

168-
// Need to re-register all the symbols from the so_handle hosted by this
169-
// CudaBackend instance. The reason is that these symbols are
170-
// static/singleton across the whole process. When we share multiple methods
171-
// (meaning multiple so_handle) in the same process, we need to re-register
172-
// the symbols from the so_handle that is being used in this execution.
173-
ET_CHECK_OK_OR_RETURN_ERROR(
174-
register_shared_library_functions(handle->so_handle));
175-
176190
size_t n_inputs;
177-
AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs);
191+
handle->get_num_inputs(handle->container_handle, &n_inputs);
178192

179193
size_t n_outputs;
180-
AOTInductorModelContainerGetNumOutputs(
181-
handle->container_handle, &n_outputs);
194+
handle->get_num_outputs(handle->container_handle, &n_outputs);
182195

183196
ET_CHECK_OR_RETURN_ERROR(
184197
n_inputs + n_outputs == args.size(),
@@ -261,7 +274,7 @@ class ET_EXPERIMENTAL CudaBackend final
261274
gpu_outputs[i] = gpu_output_handle;
262275
}
263276
// Run AOTI container with GPU tensors
264-
AOTIRuntimeError error = AOTInductorModelContainerRun(
277+
AOTIRuntimeError error = handle->run(
265278
handle->container_handle,
266279
gpu_inputs.data(), // Use GPU input tensors
267280
n_inputs,

0 commit comments

Comments
 (0)