#include <cstdint>
#include <cstdlib> // For posix_memalign
#include <memory>
+ #include <unordered_map>
#include <unordered_set>
#include <vector>

- // CUDA error checking macro
- #define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
-   do {                                      \
-     const cudaError_t err = EXPR;           \
-     if (err == cudaSuccess) {               \
-       break;                                \
-     }                                       \
-     ET_LOG(                                 \
-         Error,                              \
-         "%s:%d CUDA error: %s",             \
-         __FILE__,                           \
-         __LINE__,                           \
-         cudaGetErrorString(err));           \
-     return Error::Internal;                 \
-   } while (0)
-
- // Kernel launch check macro
- #define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
-   ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
-
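For reference, the two macros deleted above wrapped CUDA runtime calls so that any failure logged file/line context and returned Error::Internal; this change drops them from this file (presumably in favor of a shared header). A hypothetical call site, where nbytes, grid, block, and my_kernel are placeholders, would have looked like:

    // Check a runtime call, launch a kernel, then surface any launch error.
    void* ptr = nullptr;
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaMalloc(&ptr, nbytes));
    my_kernel<<<grid, block>>>(static_cast<float*>(ptr));
    ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR();
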
namespace executorch {
namespace backends {
namespace cuda {
@@ -46,12 +27,105 @@ using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
+ using executorch::backends::aoti::validate_storage_offset;

// Global storage for tensors and their metadata
std::unordered_set<std::shared_ptr<Tensor>> tensors;
+ // Global storage for tensor ownership information: true means the tensor
+ // owns its memory (allocated by this backend); false means it merely wraps
+ // caller-provided memory.
+ std::unordered_map<Tensor*, bool> is_tensor_own_memory;

extern "C" {

+ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
+     void* data,
+     int64_t ndim,
+     const int64_t* sizes_ptr,
+     const int64_t* strides_ptr,
+     int64_t storage_offset,
+     int32_t dtype,
+     int32_t device_type,
+     int32_t device_index,
+     Tensor** ret_new_tensor,
+     int32_t layout,
+     const uint8_t* opaque_metadata,
+     int64_t opaque_metadata_size) {
+   // TODO(gasoonjia): verify given data is on the target device
+   (void)device_type;
+   (void)opaque_metadata;
+   (void)layout;
+   (void)opaque_metadata_size;
+
+   // Validate input parameters first
+   if (data == nullptr) {
+     ET_LOG(
+         Error,
+         "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null");
+     return Error::InvalidArgument;
+   }
+
+   if (sizes_ptr == nullptr && ndim > 0) {
+     ET_LOG(
+         Error,
+         "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null");
+     return Error::InvalidArgument;
+   }
+
+   if (ret_new_tensor == nullptr) {
+     ET_LOG(
+         Error,
+         "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null");
+     return Error::InvalidArgument;
+   }
+
+   // Check that device_index is always 0
+   if (device_index != 0) {
+     ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+     return Error::InvalidArgument;
+   }
+
+   // Validate dtype using SupportedDTypes from utils.h
+   AOTITorchError dtype_error = validate_dtype(dtype);
+   if (dtype_error != Error::Ok) {
+     return dtype_error;
+   }
+
+   // Storage offset must be 0 since from_blob cannot handle different offsets
+   AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+   if (storage_offset_error != Error::Ok) {
+     return storage_offset_error;
+   }
+
+   // Convert sizes to the format expected by ExecuTorch using SizesType
+   std::vector<executorch::aten::SizesType> sizes =
+       convert_sizes_to_vector(ndim, sizes_ptr);
+
+   // Convert strides using the common helper function with StridesType
+   std::vector<executorch::aten::StridesType> strides =
+       convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+   // Create ExecuTorch tensor that wraps the existing memory
+   // Note: we're NOT copying the data, just wrapping it
+   auto tensor = executorch::extension::from_blob(
+       data, // existing memory (don't copy!)
+       sizes, // tensor dimensions
+       strides, // tensor strides (allows different strides)
+       dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
+   );
+
+   if (!tensor) {
+     ET_LOG(Error, "Failed to create tensor from blob");
+     return Error::InvalidArgument;
+   }
+
+   // Store the tensor so it doesn't get destroyed
+   tensors.insert(tensor);
+
+   *ret_new_tensor = tensor.get();
+   is_tensor_own_memory[tensor.get()] = false;
+
+   return Error::Ok;
+ }
+
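A minimal usage sketch of the new shim (not part of the diff): it wraps caller-owned memory without copying, so deleting the handle must leave the buffer alone. The dtype code 6 (float32) and device_type 1 (CUDA) follow the usual PyTorch/AOTI encoding and are assumptions here; device_type is currently unchecked per the TODO above.

    float buffer[6] = {0.f};     // caller-owned storage
    int64_t sizes[2] = {2, 3};
    int64_t strides[2] = {3, 1}; // contiguous row-major strides
    Tensor* t = nullptr;
    AOTITorchError err = aoti_torch_create_tensor_from_blob_v2(
        buffer, /*ndim=*/2, sizes, strides, /*storage_offset=*/0,
        /*dtype=*/6, /*device_type=*/1, /*device_index=*/0, &t,
        /*layout=*/0, /*opaque_metadata=*/nullptr, /*opaque_metadata_size=*/0);
    if (err == Error::Ok) {
      // is_tensor_own_memory[t] == false: this only drops tracking and
      // leaves `buffer` for the caller to manage.
      aoti_torch_delete_tensor_object(t);
    }
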
AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
@@ -119,6 +193,7 @@ AOTITorchError aoti_torch_empty_strided(
  // Store the tensor so it doesn't get destroyed
  tensors.insert(tensor);
  *ret_new_tensor = tensor.get();
+   is_tensor_own_memory[tensor.get()] = true;

  return Error::Ok;
}
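By contrast, a tensor from aoti_torch_empty_strided is recorded as owning its storage, so deleting it should release the allocation as well. A hedged sketch, assuming the trailing parameters mirror the standard AOTI shim signature (dtype, device_type, device_index, out-pointer) and that a null strides_ptr means contiguous:

    int64_t sizes[2] = {4, 4};
    Tensor* t = nullptr;
    AOTITorchError err = aoti_torch_empty_strided(
        /*ndim=*/2, sizes, /*strides_ptr=*/nullptr,
        /*dtype=*/6, /*device_type=*/1, /*device_index=*/0, &t);
    if (err == Error::Ok) {
      // is_tensor_own_memory[t] == true: delete frees the backing memory too.
      aoti_torch_delete_tensor_object(t);
    }
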
@@ -156,9 +231,32 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
  // If tensor not found in our tracking, it's invalid
  if (!found_in_tensors) {
    ET_LOG(Error, "Didn't find tensor %p", tensor);
+     // Clean up any stale ownership info
+     is_tensor_own_memory.erase(tensor);
    return Error::InvalidArgument;
  }

+   // Check ownership before cleaning up metadata
+   auto ownership_it = is_tensor_own_memory.find(tensor);
+   bool owns_memory = (ownership_it != is_tensor_own_memory.end())
+       ? ownership_it->second
+       : false;
+
+   // Clean up local metadata maps immediately to prevent use-after-free
+   is_tensor_own_memory.erase(tensor);
+
+   if (!owns_memory) {
+     // Don't free memory since the tensor doesn't own it, but still remove
+     // it from tracking
+     for (auto it = tensors.begin(); it != tensors.end(); ++it) {
+       if (it->get() == tensor) {
+         tensors.erase(it);
+         break;
+       }
+     }
+     return Error::Ok;
+   }
+
  // Find and delete the tensor
  for (auto it = tensors.begin(); it != tensors.end(); ++it) {
    if (it->get() == tensor) {