 #include <cstdint>
 #include <cstdlib> // For posix_memalign
 #include <memory>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 
-// CUDA error checking macro
-#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
-  do {                                      \
-    const cudaError_t err = EXPR;           \
-    if (err == cudaSuccess) {               \
-      break;                                \
-    }                                       \
-    ET_LOG(                                 \
-        Error,                              \
-        "%s:%d CUDA error: %s",             \
-        __FILE__,                           \
-        __LINE__,                           \
-        cudaGetErrorString(err));           \
-    return Error::Internal;                 \
-  } while (0)
-
-// Kernel launch check macro
-#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
-  ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
-
 namespace executorch {
 namespace backends {
 namespace cuda {
@@ -46,12 +27,122 @@ using executorch::aten::SizesType;
 using executorch::aten::StridesType;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
+using executorch::backends::aoti::validate_storage_offset;
 
 // Global storage for tensors and their metadata
 std::unordered_set<std::shared_ptr<Tensor>> tensors;
 
+// Reference counting for memory addresses
+// Maps memory address to number of tensors using it
+// Special value: NOT_OWN (-1) means tensor never owns the memory
+constexpr int32_t NOT_OWN = -1;
+std::unordered_map<void*, int32_t> memory_to_n_tensor;
+
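A minimal sketch of the convention this map encodes (an editorial illustration, not part of the change):

// Interpretation of memory_to_n_tensor[ptr] at deletion time:
//   NOT_OWN (-1) -> ptr is caller-owned (e.g. wrapped via from_blob); the backend never frees it
//   1            -> exactly one tensor owns ptr; deleting that tensor frees the allocation
//   > 1          -> ptr is shared; deleting a tensor only decrements the count
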
 extern "C" {
 
+AOTITorchError aoti_torch_create_tensor_from_blob_v2(
+    void* data,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    int32_t dtype,
+    int32_t device_type,
+    int32_t device_index,
+    Tensor** ret_new_tensor,
+    int32_t layout,
+    const uint8_t* opaque_metadata,
+    int64_t opaque_metadata_size) {
+  // TODO(gasoonjia): verify given data is on the target device
+  (void)device_type;
+  (void)opaque_metadata;
+  (void)layout;
+  (void)opaque_metadata_size;
+
+  // Validate input parameters first
+  if (data == nullptr) {
+    ET_LOG(
+        Error,
+        "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null");
+    return Error::InvalidArgument;
+  }
+
+  if (sizes_ptr == nullptr && ndim > 0) {
+    ET_LOG(
+        Error,
+        "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null");
+    return Error::InvalidArgument;
+  }
+
+  if (ret_new_tensor == nullptr) {
+    ET_LOG(
+        Error,
+        "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Check that device_index is always 0
+  if (device_index != 0) {
+    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+    return Error::InvalidArgument;
+  }
+
+  // Validate dtype using SupportedDTypes from utils.h
+  AOTITorchError dtype_error = validate_dtype(dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Storage offset must be 0 since from_blob cannot handle different offsets
+  AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+  if (storage_offset_error != Error::Ok) {
+    return storage_offset_error;
+  }
+
+  // Convert sizes to the format expected by ExecutorTorch using SizesType
+  std::vector<executorch::aten::SizesType> sizes =
+      convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using the common helper function with StridesType
+  std::vector<executorch::aten::StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create ExecutorTorch tensor that wraps the existing memory
+  // Note: We're NOT copying the data, just wrapping it
+  auto tensor = executorch::extension::from_blob(
+      data, // existing memory (don't copy!)
+      sizes, // tensor dimensions
+      strides, // tensor strides (allows different strides)
+      dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
+  );
+
+  if (!tensor) {
+    ET_LOG(Error, "Failed to create tensor from blob");
+    return Error::InvalidArgument;
+  }
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_new_tensor = tensor.get();
+
+  // Check if this memory address is already being tracked
+  auto memory_it = memory_to_n_tensor.find(data);
+  if (memory_it != memory_to_n_tensor.end()) {
+    ET_LOG(
+        Error,
+        "Memory address %p is already being tracked by another tensor",
+        data);
+    return Error::InvalidArgument;
+  }
+
+  // Mark this memory as NOT_OWN since tensor created from blob never owns
+  // memory
+  memory_to_n_tensor[data] = NOT_OWN;
+
+  return Error::Ok;
+}
+
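For orientation, here is a hedged caller-side sketch of the new entry point. It is illustrative only: the dtype code 6 (float32), layout 0 (strided), and device_type 1 (CUDA) follow the usual PyTorch/AOTI conventions and are assumptions, not values taken from this commit.

// Illustrative sketch: wrap a caller-owned managed buffer in a Tensor.
float* blob = nullptr;
cudaMallocManaged(&blob, 2 * 3 * sizeof(float)); // caller owns this allocation

int64_t sizes[] = {2, 3};
int64_t strides[] = {3, 1};
Tensor* t = nullptr;

AOTITorchError err = aoti_torch_create_tensor_from_blob_v2(
    blob,
    /*ndim=*/2,
    sizes,
    strides,
    /*storage_offset=*/0,
    /*dtype=*/6, // assumed float32 code
    /*device_type=*/1, // assumed CUDA; currently unused by the shim
    /*device_index=*/0,
    &t,
    /*layout=*/0, // assumed strided
    /*opaque_metadata=*/nullptr,
    /*opaque_metadata_size=*/0);

// Because the blob is registered as NOT_OWN, deleting the tensor only
// untracks it; the caller still frees the buffer itself.
if (err == Error::Ok) {
  aoti_torch_delete_tensor_object(t);
}
cudaFree(blob);
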
 AOTITorchError aoti_torch_empty_strided(
     int64_t ndim,
     const int64_t* sizes_ptr,
@@ -120,6 +211,9 @@ AOTITorchError aoti_torch_empty_strided(
   tensors.insert(tensor);
   *ret_new_tensor = tensor.get();
 
+  // This tensor owns the memory it allocated, set reference count to 1
+  memory_to_n_tensor[ptr] = 1;
+
   return Error::Ok;
 }
 
@@ -164,26 +258,47 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
     if (it->get() == tensor) {
       // Get the tensor before erasing
       auto tensor_ptr = *it;
-
       void* data_ptr = tensor_ptr->mutable_data_ptr();
 
-      // Determine if it's GPU memory
-      cudaPointerAttributes attributes{};
-      ET_CUDA_CHECK_OR_RETURN_ERROR(
-          cudaPointerGetAttributes(&attributes, data_ptr));
-
-      // et tensor does not own data; need to free them manually.
-      if (attributes.type == cudaMemoryTypeManaged) {
-        // This is CUDA managed memory - free with proper synchronization
-        ET_CUDA_CHECK_OR_RETURN_ERROR(
-            cudaDeviceSynchronize()); // Wait for all operations to complete
-                                      // BEFORE freeing
-        ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
+      // Find the reference count for this memory address
+      auto memory_it = memory_to_n_tensor.find(data_ptr);
+      if (memory_it != memory_to_n_tensor.end()) {
+        int32_t ref_count = memory_it->second;
+
+        if (ref_count == NOT_OWN) {
+          // Tensor never owned the memory, skip freeing
+          // Just remove tensor from tracking
+          tensors.erase(it);
+          return Error::Ok;
+        } else if (ref_count == 1) {
+          // Only current tensor using this memory, free it
+          // Determine if it's GPU memory
+          cudaPointerAttributes attributes{};
+          ET_CUDA_CHECK_OR_RETURN_ERROR(
+              cudaPointerGetAttributes(&attributes, data_ptr));
+
+          if (attributes.type == cudaMemoryTypeManaged) {
+            // This is CUDA managed memory - free with proper synchronization
+            ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize());
+            ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
+          } else {
+            // This is CPU memory - free immediately
+            free(data_ptr);
+            data_ptr = nullptr;
+          }
+
+          // Remove from memory tracking
+          memory_to_n_tensor.erase(memory_it);
+        } else if (ref_count > 1) {
+          // Other tensors still using this memory, just decrement count
+          memory_to_n_tensor[data_ptr] = ref_count - 1;
+        }
       } else {
-        // This is CPU memory - free immediately
-        free(data_ptr);
+        ET_LOG(Error, "Internal error: memory not found during deletion");
+        return Error::Internal;
       }
-      // Remove from set (this will call the destructor if it's the last
+
+      // Remove tensor from set (this will call the destructor if it's the last
       // reference)
       tensors.erase(it);
       return Error::Ok;
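The free path above relies on cudaPointerGetAttributes to tell CUDA managed allocations apart from plain host memory. A small standalone sketch of that check (illustrative only; the helper name is made up and is not part of the shim):

#include <cuda_runtime.h>

// Hypothetical helper mirroring the check in aoti_torch_delete_tensor_object:
// returns true only for CUDA managed allocations.
static bool is_managed_memory(const void* ptr) {
  cudaPointerAttributes attr{};
  if (cudaPointerGetAttributes(&attr, ptr) != cudaSuccess) {
    // On older CUDA versions the call fails for unregistered host pointers;
    // treat that as "not managed" so the caller falls back to free().
    return false;
  }
  return attr.type == cudaMemoryTypeManaged;
}

// Usage: pick the matching deallocator, as the deletion path does above.
// if (is_managed_memory(p)) { cudaDeviceSynchronize(); cudaFree(p); } else { free(p); }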