Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/aoti/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# Common AOTI functionality - combines all AOTI common components
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
set(_aoti_common_sources common_shims.cpp)
add_library(aoti_common STATIC ${_aoti_common_sources})
target_include_directories(
aoti_common
Expand Down
39 changes: 0 additions & 39 deletions backends/aoti/aoti_model_container.cpp

This file was deleted.

26 changes: 7 additions & 19 deletions backends/aoti/aoti_model_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,36 +60,17 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle);

// Global function pointers (will be loaded dynamically)
extern AOTInductorModelContainerCreateWithDeviceFunc
AOTInductorModelContainerCreateWithDevice;
extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
extern AOTInductorModelContainerGetNumInputsFunc
AOTInductorModelContainerGetNumInputs;
extern AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;

// Retrieves the name of an input tensor by index from the AOTI model container.
// Needed by Metal backend
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
const char** input_name);

// Retrieves the number of constants from the AOTI model container.
// Needed by Metal backend
using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t* num_constants);

// Global function pointers (will be loaded dynamically).
// Needed by Metal backend
extern AOTInductorModelContainerGetInputNameFunc
AOTInductorModelContainerGetInputName;
extern AOTInductorModelContainerGetNumConstantsFunc
AOTInductorModelContainerGetNumConstants;

} // extern "C"

// AOTI Delegate Handle structure
Expand All @@ -99,6 +80,13 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency

// Function pointers specific to this handle's shared library
AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
AOTInductorModelContainerDeleteFunc delete_container;
AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
AOTInductorModelContainerRunFunc run;
};

} // namespace aoti
Expand Down
3 changes: 0 additions & 3 deletions backends/aoti/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ def define_common_targets():
# AOTI model container functionality
runtime.cxx_library(
name = "model_container",
srcs = [
"aoti_model_container.cpp",
],
headers = [
"aoti_model_container.h",
],
Expand Down
73 changes: 43 additions & 30 deletions backends/cuda/runtime/cuda_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@

namespace executorch::backends::cuda {

#define LOAD_SYMBOL(name, handle) \
do { \
name = reinterpret_cast<name##Func>(dlsym(handle, #name)); \
ET_CHECK_OR_RETURN_ERROR( \
name != nullptr, AccessFailed, "Failed to load " #name); \
#define LOAD_SYMBOL(handle, member, name, so_handle) \
do { \
handle->member = reinterpret_cast<name##Func>(dlsym(so_handle, #name)); \
ET_CHECK_OR_RETURN_ERROR( \
handle->member != nullptr, AccessFailed, "Failed to load " #name); \
} while (0)

using namespace std;
Expand All @@ -57,12 +57,31 @@ using executorch::runtime::etensor::Tensor;
class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:
Error register_shared_library_functions(void* so_handle) const {
LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
Error load_function_pointers_into_handle(
void* so_handle,
AOTIDelegateHandle* handle) const {
LOAD_SYMBOL(
handle,
create_with_device,
AOTInductorModelContainerCreateWithDevice,
so_handle);

LOAD_SYMBOL(
handle, delete_container, AOTInductorModelContainerDelete, so_handle);

LOAD_SYMBOL(
handle,
get_num_inputs,
AOTInductorModelContainerGetNumInputs,
so_handle);

LOAD_SYMBOL(
handle,
get_num_outputs,
AOTInductorModelContainerGetNumOutputs,
so_handle);

LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle);

return Error::Ok;
}
Expand Down Expand Up @@ -135,19 +154,22 @@ class ET_EXPERIMENTAL CudaBackend final

processed->Free();

// Register all shared library functions
ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle));
// Create handle and load function pointers into it
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
handle->so_handle = so_handle;
handle->so_path = so_path.string();

// Load function pointers specific to this handle's shared library
ET_CHECK_OK_OR_RETURN_ERROR(
load_function_pointers_into_handle(so_handle, handle));

AOTInductorModelContainerHandle container_handle = nullptr;

ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice(
&container_handle, 1, "cuda", nullptr));
ET_CHECK_OK_OR_RETURN_ERROR(
handle->create_with_device(&container_handle, 1, "cuda", nullptr));

ET_LOG(Info, "container_handle = %p", container_handle);

AOTIDelegateHandle* handle = new AOTIDelegateHandle();
handle->so_handle = so_handle;
handle->so_path = so_path.string();
handle->container_handle = container_handle;

// Create a CUDA stream for asynchronous execution
Expand All @@ -165,20 +187,11 @@ class ET_EXPERIMENTAL CudaBackend final
Span<EValue*> args) const override {
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Need to re-register all the symbols from the so_handle hosted by this
// CudaBackend instance. The reason is that these symbols are
// static/singleton across the whole process. When we share multiple methods
// (meaning multiple so_handle) in the same process, we need to re-register
// the symbols from the so_handle that is being used in this execution.
ET_CHECK_OK_OR_RETURN_ERROR(
register_shared_library_functions(handle->so_handle));

size_t n_inputs;
AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs);
handle->get_num_inputs(handle->container_handle, &n_inputs);

size_t n_outputs;
AOTInductorModelContainerGetNumOutputs(
handle->container_handle, &n_outputs);
handle->get_num_outputs(handle->container_handle, &n_outputs);

ET_CHECK_OR_RETURN_ERROR(
n_inputs + n_outputs == args.size(),
Expand Down Expand Up @@ -261,7 +274,7 @@ class ET_EXPERIMENTAL CudaBackend final
gpu_outputs[i] = gpu_output_handle;
}
// Run AOTI container with GPU tensors
AOTIRuntimeError error = AOTInductorModelContainerRun(
AOTIRuntimeError error = handle->run(
handle->container_handle,
gpu_inputs.data(), // Use GPU input tensors
n_inputs,
Expand Down
Loading