 #include <cstdint>
 #include <cstdlib> // For posix_memalign
 #include <memory>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 
-// CUDA error checking macro
-#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
-  do {                                      \
-    const cudaError_t err = EXPR;           \
-    if (err == cudaSuccess) {               \
-      break;                                \
-    }                                       \
-    ET_LOG(                                 \
-        Error,                              \
-        "%s:%d CUDA error: %s",             \
-        __FILE__,                           \
-        __LINE__,                           \
-        cudaGetErrorString(err));           \
-    return Error::Internal;                 \
-  } while (0)
-
-// Kernel launch check macro
-#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
-  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
-
 namespace executorch {
 namespace backends {
 namespace cuda {
@@ -46,12 +27,99 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
+using executorch::backends::aoti::validate_storage_offset;
 
 // Global storage for tensors and their metadata
 std::unordered_set<std::shared_ptr<Tensor>> tensors;
+// Global storage for tensor ownership information
+std::unordered_map<Tensor*, bool> is_tensor_own_memory;
 
 extern "C" {
 
+AOTITorchError aoti_torch_create_tensor_from_blob_v2(
+    void* data,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    Tensor** ret_new_tensor,
+    int32_t layout,
+    const uint8_t* opaque_metadata,
+    int64_t opaque_metadata_size) {
+  // Validate input parameters first
+  if (data == nullptr) {
+    ET_LOG(
+        Error,
+        "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null");
+    return Error::InvalidArgument;
+  }
+
+  if (sizes_ptr == nullptr && ndim > 0) {
+    ET_LOG(
+        Error,
+        "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null");
+    return Error::InvalidArgument;
+  }
+
+  if (ret_new_tensor == nullptr) {
+    ET_LOG(
+        Error,
+        "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Check that device_index is always 0
+  if (device_index != 0) {
+    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+    return Error::InvalidArgument;
+  }
+
+  // Validate dtype using SupportedDTypes from utils.h
+  AOTITorchError dtype_error = validate_dtype(dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Storage offset must be 0, since from_blob cannot handle nonzero offsets
+  AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+  if (storage_offset_error != Error::Ok) {
+    return storage_offset_error;
+  }
+
+  // Convert sizes to the format expected by ExecuTorch using SizesType
+  std::vector<executorch::aten::SizesType> sizes =
+      convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using the common helper function with StridesType
+  std::vector<executorch::aten::StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create an ExecuTorch tensor that wraps the existing memory.
+  // Note: we are NOT copying the data, just wrapping it.
+  auto tensor = executorch::extension::from_blob(
+      data, // existing memory (don't copy!)
+      sizes, // tensor dimensions
+      strides, // tensor strides (allows different strides)
+      dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
+  );
+
+  if (!tensor) {
+    ET_LOG(Error, "Failed to create tensor from blob");
+    return Error::InvalidArgument;
+  }
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_new_tensor = tensor.get();
+  is_tensor_own_memory[tensor.get()] = false; // caller owns the memory
+
+  return Error::Ok;
+}
+
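For illustration, a minimal hypothetical caller-side sketch of the new entry point (not part of this diff). The constants are assumptions: dtype code 6 is taken as float32 and device_type 1 as CUDA per the usual at::ScalarType / c10::DeviceType numbering, and a real caller on this backend would pass device memory rather than a stack buffer:

    float buffer[6] = {0.f}; // 2x3 caller-owned storage (stand-in for device memory)
    int64_t sizes[] = {2, 3};
    int64_t strides[] = {3, 1}; // contiguous row-major
    Tensor* wrapped = nullptr;
    AOTITorchError err = aoti_torch_create_tensor_from_blob_v2(
        buffer,
        /*ndim=*/2,
        sizes,
        strides,
        /*storage_offset=*/0, // nonzero offsets are rejected above
        /*dtype=*/6, // assumed float32
        /*device_type=*/1, // assumed CUDA
        /*device_index=*/0, // only index 0 is accepted
        &wrapped,
        /*layout=*/0,
        /*opaque_metadata=*/nullptr,
        /*opaque_metadata_size=*/0);
    // On success, `wrapped` aliases `buffer` and is recorded with
    // is_tensor_own_memory[wrapped] == false, so deleting it later will
    // not free the caller's buffer.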
 AOTITorchError aoti_torch_empty_strided(
     int64_t ndim,
     const int64_t* sizes_ptr,
@@ -119,6 +187,7 @@ AOTITorchError aoti_torch_empty_strided(
   // Store the tensor so it doesn't get destroyed
   tensors.insert(tensor);
   *ret_new_tensor = tensor.get();
+  is_tensor_own_memory[tensor.get()] = true; // backend owns this allocation
 
   return Error::Ok;
 }
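In contrast to blob wrapping, tensors created here are recorded as owning their storage, so the deletion path in the next hunk frees the underlying allocation. A rough sketch of the two lifetimes, assuming the trailing aoti_torch_empty_strided parameters mirror the standard AOTI shim (strides, dtype, device triple, out-pointer) and reusing the hypothetical `wrapped` tensor from the sketch above:

    Tensor* owned = nullptr;
    aoti_torch_empty_strided(
        /*ndim=*/2, sizes, strides,
        /*dtype=*/6, /*device_type=*/1, /*device_index=*/0, &owned);
    // is_tensor_own_memory[owned] == true: deleting frees the allocation.
    aoti_torch_delete_tensor_object(owned);
    // is_tensor_own_memory[wrapped] == false: deleting only drops it from
    // `tensors`; the caller's buffer survives.
    aoti_torch_delete_tensor_object(wrapped);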
@@ -156,9 +225,32 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   // If tensor not found in our tracking, it's invalid
   if (!found_in_tensors) {
     ET_LOG(Error, "Didn't find tensor %p", tensor);
+    // Clean up any stale ownership info
+    is_tensor_own_memory.erase(tensor);
     return Error::InvalidArgument;
   }
 
+  // Check ownership before cleaning up metadata
+  auto ownership_it = is_tensor_own_memory.find(tensor);
+  bool owns_memory = (ownership_it != is_tensor_own_memory.end())
+      ? ownership_it->second
+      : false;
+
+  // Clean up local metadata maps immediately to prevent use-after-free
+  is_tensor_own_memory.erase(tensor);
+
+  if (!owns_memory) {
+    // Don't free memory since the tensor doesn't own it, but still remove
+    // it from tracking
+    for (auto it = tensors.begin(); it != tensors.end(); ++it) {
+      if (it->get() == tensor) {
+        tensors.erase(it);
+        break;
+      }
+    }
+    return Error::Ok;
+  }
+
   // Find and delete the tensor
   for (auto it = tensors.begin(); it != tensors.end(); ++it) {
     if (it->get() == tensor) {