Commit 61a136a

cuda backend self-contained by decoupling it from backends/aoti

1 parent 6a96093 commit 61a136a

20 files changed: +662, -176 lines

backends/cuda/CMakeLists.txt

Lines changed: 3 additions & 3 deletions

@@ -70,6 +70,7 @@ install(
 # CUDA-specific AOTI functionality
 set(_aoti_cuda_sources
     runtime/cuda_backend.cpp
+    runtime/common_shims.cpp
     runtime/shims/memory.cpp
     runtime/shims/tensor_attribute.cpp
     runtime/guard.cpp
@@ -95,10 +96,9 @@ target_link_options(
   aoti_cuda PUBLIC $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
 )
 
-# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and PyTorch
-# CUDA libraries
+# Link against CUDA::cudart, cuda_tensor_maker, and PyTorch CUDA libraries
 target_link_libraries(
-  aoti_cuda PUBLIC aoti_common cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS}
+  aoti_cuda PUBLIC executorch cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS}
 )
 # If you need other CUDA libraries, link them similarly:
 # target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)
Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstddef> // size_t
+#include <cstdint> // uint8_t
+#include <string> // std::string (AOTIDelegateHandle::so_path)
+
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/evalue.h>
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+using executorch::runtime::Error;
+using executorch::runtime::etensor::Tensor;
+
+extern "C" {
+
+// Type definitions
+using AOTITensorHandle = Tensor*;
+using AOTIRuntimeError = Error;
+
+// Forward declarations for AOT Inductor model container
+struct AOTInductorModelContainerOpaque;
+using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque*;
+using AOTInductorStreamHandle = void*;
+using AOTIProxyExecutorHandle = void*;
+
+// Function pointer types for AOT Inductor model container operations
+using AOTInductorModelContainerCreateWithDeviceFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle* container_handle,
+    size_t num_models,
+    const char* device_str,
+    const char* cubin_dir);
+
+using AOTInductorModelContainerDeleteFunc =
+    AOTIRuntimeError (*)(AOTInductorModelContainerHandle container_handle);
+
+using AOTInductorModelContainerGetNumInputsFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t* num_inputs);
+
+using AOTInductorModelContainerGetNumOutputsFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t* num_outputs);
+
+using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    Tensor** input_handles, // array of input Tensor*; handles
+                            // are stolen; the array itself is borrowed
+    size_t num_inputs,
+    Tensor** output_handles, // array for writing output Tensor*; handles
+                             // will be stolen by the caller; the array itself
+                             // is borrowed
+    size_t n_outputs,
+    AOTInductorStreamHandle stream_handle,
+    AOTIProxyExecutorHandle proxy_executor_handle);
+
+// Retrieves the name of an input tensor by index from the AOTI model container.
+using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t input_idx,
+    const char** input_name);
+
+// Retrieves the number of constants from the AOTI model container.
+using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    size_t* num_constants);
+
+// Updates the model container with the constant tensors.
+using AOTInductorModelUpdateConstantsFromBlobFunc = AOTIRuntimeError (*)(
+    AOTInductorModelContainerHandle container_handle,
+    const uint8_t* weight_blob_ptr);
+
+} // extern "C"
+
+// AOTI Delegate Handle structure
+struct AOTIDelegateHandle {
+  void* so_handle;
+  std::string so_path;
+  AOTInductorModelContainerHandle container_handle;
+  void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
+                     // dependency
+
+  // Function pointers specific to this handle's shared library
+  AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
+  AOTInductorModelContainerDeleteFunc delete_container;
+  AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
+  AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
+  AOTInductorModelContainerRunFunc run;
+  AOTInductorModelUpdateConstantsFromBlobFunc update_constants_from_blob;
+};
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
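
Aside: because AOTIDelegateHandle stores function pointers per shared library, each loaded AOTI model resolves its own entry points. A minimal sketch of how those pointers might be populated, assuming the POSIX dlopen/dlsym API and the standard AOTInductor C symbol names; the helper load_aoti_entry_points and its error choices are hypothetical, not part of this commit:

#include <dlfcn.h>

namespace executorch::backends::cuda {

// Hypothetical helper (not part of this commit): open the AOTI-compiled .so
// recorded in handle.so_path and resolve its container entry points.
Error load_aoti_entry_points(AOTIDelegateHandle& handle) {
  handle.so_handle = dlopen(handle.so_path.c_str(), RTLD_NOW | RTLD_LOCAL);
  if (handle.so_handle == nullptr) {
    return Error::AccessFailed; // dlerror() would carry the details
  }
  // Resolve against this handle's own library, so several AOTI-compiled
  // models can coexist in one process without symbol clashes.
  handle.create_with_device =
      reinterpret_cast<AOTInductorModelContainerCreateWithDeviceFunc>(
          dlsym(handle.so_handle, "AOTInductorModelContainerCreateWithDevice"));
  handle.run = reinterpret_cast<AOTInductorModelContainerRunFunc>(
      dlsym(handle.so_handle, "AOTInductorModelContainerRun"));
  handle.delete_container =
      reinterpret_cast<AOTInductorModelContainerDeleteFunc>(
          dlsym(handle.so_handle, "AOTInductorModelContainerDelete"));
  if (handle.create_with_device == nullptr || handle.run == nullptr ||
      handle.delete_container == nullptr) {
    dlclose(handle.so_handle);
    return Error::Internal; // missing symbol in the loaded library
  }
  return Error::Ok;
}

} // namespace executorch::backends::cuda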
backends/cuda/runtime/common_shims.cpp

Lines changed: 211 additions & 0 deletions

@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/cuda/runtime/common_shims.h>
+#include <executorch/runtime/platform/log.h>
+
+#include <algorithm> // std::equal
+#include <cstdint>
+#include <stdexcept> // std::runtime_error
+#include <unordered_map>
+#include <vector>
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+namespace internal {
+// Global storage for tensor metadata
+std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_sizes;
+std::unordered_map<Tensor*, std::vector<int64_t>> tensor_to_strides;
+} // namespace internal
+
+extern "C" {
+
+// Autograd mode functions
+int32_t aoti_torch_grad_mode_is_enabled() {
+  // No autograd ever
+  return false;
+}
+
+void aoti_torch_grad_mode_set_enabled(bool enabled) {
+  if (enabled) {
+    throw std::runtime_error("Cannot enable autograd");
+  }
+}
+
+// Tensor attribute operations
+AOTITorchError aoti_torch_get_data_ptr(Tensor* tensor, void** ret_data_ptr) {
+  *ret_data_ptr = tensor->mutable_data_ptr();
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_storage_offset(
+    Tensor* tensor,
+    int64_t* ret_storage_offset) {
+  // Storage offset is always 0 in ET
+  *ret_storage_offset = 0;
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_strides(Tensor* tensor, int64_t** ret_strides) {
+  auto it = internal::tensor_to_strides.find(tensor);
+  bool needs_update = false;
+
+  if (it == internal::tensor_to_strides.end()) {
+    needs_update = true;
+  } else {
+    // CRITICAL: Multimodal models reuse tensors with different shapes across
+    // executions (e.g., variable-length audio). We MUST validate cached
+    // metadata matches current tensor state, or CUDA kernels will receive
+    // incorrect shapes leading to memory corruption and segfaults.
+    auto tensor_strides = tensor->strides();
+    needs_update = !std::equal(
+        it->second.begin(),
+        it->second.end(),
+        tensor_strides.begin(),
+        tensor_strides.end());
+  }
+
+  if (needs_update) {
+    std::vector<int64_t> strides(tensor->dim());
+    auto tensor_strides = tensor->strides();
+    for (int i = 0; i < tensor->dim(); i++) {
+      strides[i] = tensor_strides[i];
+    }
+    it =
+        internal::tensor_to_strides.insert_or_assign(tensor, std::move(strides))
+            .first;
+  }
+
+  // For 0D tensors, data() returns nullptr on empty vectors, but we need to
+  // return a valid pointer
+  if (it->second.empty()) {
+    static int64_t empty_strides_placeholder = 0;
+    *ret_strides = &empty_strides_placeholder;
+  } else {
+    *ret_strides = it->second.data();
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_dtype(Tensor* tensor, int32_t* ret_dtype) {
+  *ret_dtype = static_cast<int32_t>(tensor->scalar_type());
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_sizes(Tensor* tensor, int64_t** ret_sizes) {
+  auto it = internal::tensor_to_sizes.find(tensor);
+  bool needs_update = false;
+
+  if (it == internal::tensor_to_sizes.end()) {
+    needs_update = true;
+  } else {
+    // CRITICAL: Multimodal models reuse tensors with different shapes across
+    // executions (e.g., variable-length audio). We MUST validate cached
+    // metadata matches current tensor state, or CUDA kernels will receive
+    // incorrect shapes leading to memory corruption and segfaults.
+    auto tensor_sizes = tensor->sizes();
+    needs_update = !std::equal(
+        it->second.begin(),
+        it->second.end(),
+        tensor_sizes.begin(),
+        tensor_sizes.end());
+  }
+
+  if (needs_update) {
+    std::vector<int64_t> sizes(tensor->dim());
+    auto tensor_sizes = tensor->sizes();
+    for (int i = 0; i < tensor->dim(); i++) {
+      sizes[i] = tensor_sizes[i];
+    }
+    it = internal::tensor_to_sizes.insert_or_assign(tensor, std::move(sizes))
+             .first;
+  }
+
+  // For 0D tensors, data() returns nullptr on empty vectors, but we need to
+  // return a valid pointer
+  if (it->second.empty()) {
+    static int64_t empty_sizes_placeholder = 0;
+    *ret_sizes = &empty_sizes_placeholder;
+  } else {
+    *ret_sizes = it->second.data();
+  }
+
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_device_index(
+    Tensor* tensor,
+    int32_t* ret_device_index) {
+  // Assume all tensors used by AOTI are on CUDA:0
+  *ret_device_index = 0;
+  return Error::Ok;
+}
+
+AOTITorchError aoti_torch_get_dim(Tensor* tensor, int64_t* ret_dim) {
+  *ret_dim = static_cast<int64_t>(tensor->dim());
+  return Error::Ok;
+}
+
+// Device and layout utility functions
+int32_t aoti_torch_device_type_cpu() {
+  // CPU is device type 0 in ET as well
+  return 0;
+}
+
+int32_t aoti_torch_layout_strided() {
+  // ET only supports the strided layout, so the return value is always 0,
+  // i.e. at::Layout::Strided
+  return 0;
+}
+
+// Dtype constants - these return the PyTorch dtype codes
+int32_t aoti_torch_dtype_float32() {
+  return 6; // PyTorch's float32 dtype code
+}
+
+int32_t aoti_torch_dtype_bfloat16() {
+  return 15; // PyTorch's bfloat16 dtype code
+}
+
+int32_t aoti_torch_dtype_int8() {
+  return 1; // PyTorch's int8 dtype code
+}
+
+int32_t aoti_torch_dtype_int16() {
+  return 2; // PyTorch's int16 dtype code
+}
+
+int32_t aoti_torch_dtype_int32() {
+  return 3; // PyTorch's int32 dtype code
+}
+
+int32_t aoti_torch_dtype_bool() {
+  return 11; // PyTorch's bool dtype code
+}
+
+int32_t aoti_torch_dtype_int64() {
+  return 4; // PyTorch's int64 dtype code
+}
+
+// Dtype utility function needed by the Metal backend.
+// Returns the size of the dtype in bytes.
+size_t aoti_torch_dtype_element_size(int32_t dtype) {
+  return dtype_to_element_size(dtype);
+}
+
+// Cleanup functions
+void cleanup_tensor_metadata() {
+  internal::tensor_to_sizes.clear();
+  internal::tensor_to_strides.clear();
+}
+
+} // extern "C"
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
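
Aside: aoti_torch_dtype_element_size delegates to dtype_to_element_size, which this diff does not show (it presumably comes in via common_shims.h). A minimal sketch of what such a mapping could look like, keyed on the PyTorch dtype codes returned by the constants above; the switch below is an illustrative assumption, not the commit's actual implementation:

#include <cstddef>
#include <cstdint>

// Illustrative assumption (not the commit's implementation): map the PyTorch
// dtype codes used above to element sizes in bytes.
inline size_t dtype_to_element_size(int32_t dtype) {
  switch (dtype) {
    case 1: // int8
    case 11: // bool
      return 1;
    case 2: // int16
    case 15: // bfloat16
      return 2;
    case 3: // int32
    case 6: // float32
      return 4;
    case 4: // int64
      return 8;
    default:
      return 0; // unknown dtype; real code would report an error
  }
}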
