#include <cstdint>
#include <cstdlib> // For posix_memalign
#include <memory>
+ #include <unordered_map>
#include <unordered_set>
#include <vector>

- // CUDA error checking macro
- #define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \
-   do {                                      \
-     const cudaError_t err = EXPR;           \
-     if (err == cudaSuccess) {               \
-       break;                                \
-     }                                       \
-     ET_LOG(                                 \
-         Error,                              \
-         "%s:%d CUDA error: %s",             \
-         __FILE__,                           \
-         __LINE__,                           \
-         cudaGetErrorString(err));           \
-     return Error::Internal;                 \
-   } while (0)
-
- // Kernel launch check macro
- #define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \
-   ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError())
-
namespace executorch {
namespace backends {
namespace cuda {
@@ -46,12 +27,122 @@ using executorch::aten::SizesType;
using executorch::aten::StridesType;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
+ using executorch::backends::aoti::validate_storage_offset;

// Global storage for tensors and their metadata
std::unordered_set<std::shared_ptr<Tensor>> tensors;

+ // Reference counting for memory addresses
+ // Maps memory address to number of tensors using it
+ // Special value: NOT_OWN (-1) means tensor never owns the memory
+ constexpr int32_t NOT_OWN = -1;
+ std::unordered_map<void*, int32_t> memory_to_n_tensor;
+
extern "C" {

+ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
+     void* data,
+     int64_t ndim,
+     const int64_t* sizes_ptr,
+     const int64_t* strides_ptr,
+     int64_t storage_offset,
+     int32_t dtype,
+     int32_t device_type,
+     int32_t device_index,
+     Tensor** ret_new_tensor,
+     int32_t layout,
+     const uint8_t* opaque_metadata,
+     int64_t opaque_metadata_size) {
+   // TODO(gasoonjia): verify given data is on the target device
+   (void)device_type;
+   (void)opaque_metadata;
+   (void)layout;
+   (void)opaque_metadata_size;
+
+   // Validate input parameters first
+   if (data == nullptr) {
+     ET_LOG(
+         Error,
+         "aoti_torch_create_tensor_from_blob_v2 failed: data pointer is null");
+     return Error::InvalidArgument;
+   }
+
+   if (sizes_ptr == nullptr && ndim > 0) {
+     ET_LOG(
+         Error,
+         "aoti_torch_create_tensor_from_blob_v2 failed: sizes_ptr is null");
+     return Error::InvalidArgument;
+   }
+
+   if (ret_new_tensor == nullptr) {
+     ET_LOG(
+         Error,
+         "aoti_torch_create_tensor_from_blob_v2 failed: ret_new_tensor is null");
+     return Error::InvalidArgument;
+   }
+
+   // Check that device_index is always 0
+   if (device_index != 0) {
+     ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+     return Error::InvalidArgument;
+   }
+
+   // Validate dtype using SupportedDTypes from utils.h
+   AOTITorchError dtype_error = validate_dtype(dtype);
+   if (dtype_error != Error::Ok) {
+     return dtype_error;
+   }
+
+   // Storage offset must be 0 since from_blob cannot handle different offsets
+   AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+   if (storage_offset_error != Error::Ok) {
+     return storage_offset_error;
+   }
+
+   // Convert sizes to the format expected by ExecutorTorch using SizesType
+   std::vector<executorch::aten::SizesType> sizes =
+       convert_sizes_to_vector(ndim, sizes_ptr);
+
+   // Convert strides using the common helper function with StridesType
+   std::vector<executorch::aten::StridesType> strides =
+       convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+   // Create ExecutorTorch tensor that wraps the existing memory
+   // Note: We're NOT copying the data, just wrapping it
+   auto tensor = executorch::extension::from_blob(
+       data, // existing memory (don't copy!)
+       sizes, // tensor dimensions
+       strides, // tensor strides (allows different strides)
+       dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
+   );
+
+   if (!tensor) {
+     ET_LOG(Error, "Failed to create tensor from blob");
+     return Error::InvalidArgument;
+   }
+
+   // Store the tensor so it doesn't get destroyed
+   tensors.insert(tensor);
+
+   *ret_new_tensor = tensor.get();
+
+   // Check if this memory address is already being tracked
+   auto memory_it = memory_to_n_tensor.find(data);
+   if (memory_it != memory_to_n_tensor.end()) {
+     ET_LOG(
+         Error,
+         "Memory address %p is already being tracked by another tensor",
+         data);
+     return Error::InvalidArgument;
+   }
+
+   // Mark this memory as NOT_OWN since tensor created from blob never owns
+   // memory
+   memory_to_n_tensor[data] = NOT_OWN;
+
+   return Error::Ok;
+ }
+
AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
@@ -120,6 +211,9 @@ AOTITorchError aoti_torch_empty_strided(
  tensors.insert(tensor);
  *ret_new_tensor = tensor.get();

+   // This tensor owns the memory it allocated, set reference count to 1
+   memory_to_n_tensor[ptr] = 1;
+
  return Error::Ok;
}

@@ -164,26 +258,47 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
    if (it->get() == tensor) {
      // Get the tensor before erasing
      auto tensor_ptr = *it;
-
      void* data_ptr = tensor_ptr->mutable_data_ptr();

-       // Determine if it's GPU memory
-       cudaPointerAttributes attributes{};
-       ET_CUDA_CHECK_OR_RETURN_ERROR(
-           cudaPointerGetAttributes(&attributes, data_ptr));
-
-       // et tensor does not own data; need to free them manually.
-       if (attributes.type == cudaMemoryTypeManaged) {
-         // This is CUDA managed memory - free with proper synchronization
-         ET_CUDA_CHECK_OR_RETURN_ERROR(
-             cudaDeviceSynchronize()); // Wait for all operations to complete
-                                       // BEFORE freeing
-         ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
+       // Find the reference count for this memory address
+       auto memory_it = memory_to_n_tensor.find(data_ptr);
+       if (memory_it != memory_to_n_tensor.end()) {
+         int32_t ref_count = memory_it->second;
+
+         if (ref_count == NOT_OWN) {
+           // Tensor never owned the memory, skip freeing
+           // Just remove tensor from tracking
+           tensors.erase(it);
+           return Error::Ok;
+         } else if (ref_count == 1) {
+           // Only current tensor using this memory, free it
+           // Determine if it's GPU memory
+           cudaPointerAttributes attributes{};
+           ET_CUDA_CHECK_OR_RETURN_ERROR(
+               cudaPointerGetAttributes(&attributes, data_ptr));
+
+           if (attributes.type == cudaMemoryTypeManaged) {
+             // This is CUDA managed memory - free with proper synchronization
+             ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize());
+             ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
+           } else {
+             // This is CPU memory - free immediately
+             free(data_ptr);
+             data_ptr = nullptr;
+           }
+
+           // Remove from memory tracking
+           memory_to_n_tensor.erase(memory_it);
+         } else if (ref_count > 1) {
+           // Other tensors still using this memory, just decrement count
+           memory_to_n_tensor[data_ptr] = ref_count - 1;
+         }
      } else {
-         // This is CPU memory - free immediately
-         free(data_ptr);
+         ET_LOG(Error, "Internal error: memory not found during deletion");
+         return Error::Internal;
      }
-       // Remove from set (this will call the destructor if it's the last
+
+       // Remove tensor from set (this will call the destructor if it's the last
      // reference)
      tensors.erase(it);
      return Error::Ok;
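
Note on the ownership scheme introduced in this diff: blob-backed tensors record NOT_OWN in memory_to_n_tensor and their buffers are never freed by the backend, tensors that allocate their own buffer start with a count of 1, and deletion frees a buffer only when the last owning tensor is removed. The standalone sketch below summarizes those rules under stated assumptions; the helper names track_blob, track_owned, and release are illustrative and not part of the patch, and the CUDA branch is omitted.

// Minimal sketch (assumed helper names) of the reference-counting rules
// that memory_to_n_tensor encodes in this patch. CPU-only; CUDA omitted.
#include <cstdint>
#include <cstdlib>
#include <unordered_map>

constexpr int32_t NOT_OWN = -1;
std::unordered_map<void*, int32_t> memory_to_n_tensor;

// Blob-backed tensors never own their buffer.
void track_blob(void* p) { memory_to_n_tensor[p] = NOT_OWN; }

// Tensors that allocated their own buffer start with a count of 1.
void track_owned(void* p) { memory_to_n_tensor[p] = 1; }

// Called when a tensor is deleted: free only if this was the last owner.
void release(void* p) {
  auto it = memory_to_n_tensor.find(p);
  if (it == memory_to_n_tensor.end()) {
    return; // untracked address; the real code reports Error::Internal here
  }
  if (it->second == NOT_OWN) {
    return; // the caller owns the buffer; never free it
  }
  if (it->second == 1) {
    std::free(p); // last owner: release the buffer (CPU path only)
    memory_to_n_tensor.erase(it);
  } else {
    --it->second; // other tensors still share this buffer
  }
}

int main() {
  void* owned = std::malloc(16);
  track_owned(owned);
  release(owned); // count was 1: buffer freed here

  static int blob[4] = {0, 1, 2, 3};
  track_blob(blob);
  release(blob); // NOT_OWN: left untouched for the caller
  return 0;
}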