@@ -31,7 +31,12 @@ using namespace executorch::backends::aoti;
3131
3232// Global storage for tensors and their metadata
3333std::unordered_set<std::shared_ptr<Tensor>> tensors;
34- std::unordered_map<Tensor*, bool > is_tensor_own_memory;
34+
35+ // Reference counting for memory addresses
36+ // Maps memory address to number of tensors using it
37+ // Special value: NOT_OWN (-1) means tensor never owns the memory
38+ constexpr int32_t NOT_OWN = -1 ;
39+ std::unordered_map<void *, int32_t > memory_to_n_tensor;
3540
3641extern " C" {
3742
@@ -110,7 +115,18 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
110115 // Store the tensor so it doesn't get destroyed
111116 tensors.insert (tensor);
112117 *ret_new_tensor = tensor.get ();
113- is_tensor_own_memory[tensor.get ()] = false ;
118+
119+ // Check if this memory address is already being tracked
120+ auto memory_it = memory_to_n_tensor.find (adjusted_data);
121+ ET_CHECK_OR_RETURN_ERROR (
122+ memory_it == memory_to_n_tensor.end (),
123+ InvalidArgument,
124+ " Memory address %p is already being tracked by another tensor" ,
125+ adjusted_data);
126+
127+ // Mark this memory as NOT_OWN since tensor created from blob never owns
128+ // memory
129+ memory_to_n_tensor[adjusted_data] = NOT_OWN;
114130
115131 ET_LOG (Debug, " aoti_torch_create_tensor_from_blob_v2: successfull" );
116132 return Error::Ok;
@@ -192,59 +208,91 @@ AOTITorchError aoti_torch_empty_strided(
192208 // Store the tensor so it doesn't get destroyed
193209 tensors.insert (tensor);
194210 *ret_new_tensor = tensor.get ();
195- is_tensor_own_memory[tensor.get ()] = true ;
211+
212+ // This tensor owns the memory it allocated, set reference count to 1
213+ memory_to_n_tensor[ptr] = 1 ;
196214
197215 ET_LOG (Debug, " aoti_torch_empty_strided: successfull" );
198216 return Error::Ok;
199217}
200218
201219AOTITorchError aoti_torch_delete_tensor_object (AOTITensorHandle tensor) {
202220 ET_LOG (Debug, " aoti_torch_delete_tensor_object: entered" );
203- // Find tensor in the set
221+
222+ // Handle null tensor pointer
223+ if (tensor == nullptr ) {
224+ ET_LOG (Debug, " aoti_torch_delete_tensor_object: null tensor" );
225+ return Error::Ok;
226+ }
227+
228+ // Check if tensor exists in our tracking
229+ bool found_in_tensors = false ;
204230 for (auto it = tensors.begin (); it != tensors.end (); ++it) {
205231 if (it->get () == tensor) {
206- auto tensor_ptr = *it;
232+ found_in_tensors = true ;
233+ break ;
234+ }
235+ }
207236
208- // Check ownership before cleaning up
209- auto ownership_it = is_tensor_own_memory.find (tensor);
210- bool owns_memory = (ownership_it != is_tensor_own_memory.end ())
211- ? ownership_it->second
212- : false ;
237+ // If tensor not found in our tracking, it's invalid
238+ ET_CHECK_OR_RETURN_ERROR (
239+ found_in_tensors, InvalidArgument, " Didn't find tensor %p" , tensor);
213240
214- // Clean up ownership metadata
215- is_tensor_own_memory.erase (tensor);
241+ // Find and delete the tensor
242+ for (auto it = tensors.begin (); it != tensors.end (); ++it) {
243+ if (it->get () == tensor) {
244+ // Get the tensor before erasing
245+ auto tensor_ptr = *it;
246+ void * data_ptr = tensor_ptr->mutable_data_ptr ();
216247
217- if (owns_memory) {
218- // et tensor owns the memory; need to free it manually
219- void * data_ptr = tensor_ptr->mutable_data_ptr ();
248+ // Find the reference count for this memory address
249+ auto memory_it = memory_to_n_tensor.find (data_ptr);
250+ if (memory_it != memory_to_n_tensor.end ()) {
251+ int32_t ref_count = memory_it->second ;
220252
221- // Check if it's Metal GPU memory
222- if (metal_is_device_pointer (data_ptr)) {
223- // This is Metal GPU memory - the Metal helper will handle cleanup
224- // Metal buffers are automatically managed by ARC when the buffer is
225- // released
253+ if (ref_count == NOT_OWN) {
254+ // Tensor never owned the memory, skip freeing
255+ // Just remove tensor from tracking
226256 tensors.erase (it);
227- ET_LOG (
228- Debug,
229- " aoti_torch_delete_tensor_object: successfull (Metal GPU memory)" );
257+ ET_LOG (Debug, " aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free" );
230258 return Error::Ok;
259+ } else if (ref_count == 1 ) {
260+ // Only current tensor using this memory, free it
261+ // Check if it's Metal GPU memory
262+ if (metal_is_device_pointer (data_ptr)) {
263+ metal_deallocate_buffer (data_ptr);
264+ } else {
265+ // This is CPU memory - free immediately
266+ free (data_ptr);
267+ data_ptr = nullptr ;
268+ ET_LOG (Debug, " aoti_torch_delete_tensor_object: freeing CPU memory" );
269+ }
270+
271+ // Remove from memory tracking
272+ memory_to_n_tensor.erase (memory_it);
273+ } else if (ref_count > 1 ) {
274+ // Other tensors still using this memory, just decrement count
275+ memory_to_n_tensor[data_ptr] = ref_count - 1 ;
276+ ET_LOG (Debug, " aoti_torch_delete_tensor_object: decremented ref count from %d to %d" , ref_count, ref_count - 1 );
231277 }
232-
233- // This is CPU memory - free immediately
234- free (data_ptr);
278+ } else {
279+ ET_CHECK_OR_RETURN_ERROR (
280+ false ,
281+ Internal,
282+ " Internal error: memory not found during deletion" );
235283 }
236- // else: Don't free memory since the tensor doesn't own it
237284
238- // Remove from set (this will call the destructor if it's the last
285+ // Remove tensor from set (this will call the destructor if it's the last
239286 // reference)
240287 tensors.erase (it);
241- ET_LOG (
242- Debug, " aoti_torch_delete_tensor_object: successfull (CPU memory)" );
288+ ET_LOG (Debug, " aoti_torch_delete_tensor_object: successfull" );
243289 return Error::Ok;
244290 }
245291 }
246- ET_LOG (Error, " Didn't find tensor %p" , tensor);
247- return Error::InvalidArgument;
292+
293+ // This should never be reached since we found it above
294+ ET_CHECK_OR_RETURN_ERROR (
295+ false , Internal, " Internal error: tensor not found after validation" );
248296}
249297
250298AOTITorchError aoti_torch_copy_ (
@@ -375,75 +423,105 @@ AOTITorchError aoti_torch__reinterpret_tensor(
375423 InvalidArgument,
376424 " aoti_torch__reinterpret_tensor failed: ret_new_tensor is null" );
377425
426+ // Check if storage_offset is not 0 - return error if not
427+ ET_CHECK_OK_OR_RETURN_ERROR (validate_storage_offset (storage_offset));
428+
429+ // Get the device info from the source tensor to perform device_index
430+ // validation
431+ int32_t device_type = 0 ;
432+ int32_t device_index = 0 ;
433+ ET_CHECK_OK_OR_RETURN_ERROR (aoti_torch_get_device_type (self, &device_type));
434+
435+ ET_CHECK_OK_OR_RETURN_ERROR (aoti_torch_get_device_index (self, &device_index));
436+
437+ // Ensure device_index is always 0
438+ ET_CHECK_OR_RETURN_ERROR (
439+ device_index == 0 ,
440+ InvalidArgument,
441+ " device_index must be 0, got: %d" ,
442+ device_index);
443+
378444 // Get the dtype from the source tensor
379445 int32_t dtype = 0 ;
380446 ET_CHECK_OK_OR_RETURN_ERROR (aoti_torch_get_dtype (self, &dtype));
381447
382448 // Validate dtype using SupportedDTypes
383449 ET_CHECK_OK_OR_RETURN_ERROR (validate_dtype (dtype));
384450
385- int32_t device_type = 0 ;
386- ET_CHECK_OK_OR_RETURN_ERROR (aoti_torch_get_device_type (self, &device_type));
451+ // Get the original data pointer from the source tensor
452+ void * data_ptr = self->mutable_data_ptr ();
453+ ET_CHECK_OR_RETURN_ERROR (
454+ data_ptr != nullptr ,
455+ InvalidArgument,
456+ " Source tensor has null data pointer" );
387457
388- int32_t device_index = 0 ;
389- ET_CHECK_OK_OR_RETURN_ERROR (aoti_torch_get_device_index (self, &device_index));
458+ // Check if the given memory is in the map, if not return error
459+ auto memory_it = memory_to_n_tensor.find (data_ptr);
460+ ET_CHECK_OR_RETURN_ERROR (
461+ memory_it != memory_to_n_tensor.end (),
462+ InvalidArgument,
463+ " Memory address %p is not being tracked by reference counting system" ,
464+ data_ptr);
465+
466+ // Convert sizes using utility function from utils.h
467+ std::vector<aten::SizesType> sizes = convert_sizes_to_vector (ndim, sizes_ptr);
468+
469+ // Convert strides using utility function from utils.h
470+ std::vector<aten::StridesType> strides =
471+ convert_strides_to_vector (ndim, sizes_ptr, strides_ptr);
472+
473+ // Create new tensor view that reinterprets the same memory with different
474+ // shape/strides This creates a view, not a copy - the data pointer is shared
475+ std::shared_ptr<Tensor> tensor = executorch::extension::from_blob (
476+ data_ptr, // Reuse the same memory from source tensor
477+ sizes, // New sizes with explicit SizesType
478+ strides, // New strides with explicit StridesType
479+ dtype_to_scalar_type (dtype) // Convert dtype with explicit type casting
480+ );
390481
391- // Get the base data pointer from the source tensor
392- void * base_data_ptr = self->mutable_data_ptr ();
393482 ET_CHECK_OR_RETURN_ERROR (
394- base_data_ptr != nullptr ,
483+ tensor != nullptr ,
395484 InvalidArgument,
396- " Source tensor has null data pointer " );
485+ " Failed to create reinterpreted tensor view " );
397486
398- // Calculate new tensor size in elements for logging
399- int64_t new_numel = 1 ;
400- for (int64_t i = 0 ; i < ndim; i++) {
401- new_numel *= sizes_ptr[i];
402- }
487+ // Store the tensor so it doesn't get destroyed
488+ tensors.insert (tensor);
403489
404- ET_LOG (
405- Debug,
406- " aoti_torch__reinterpret_tensor: base_data_ptr=%p, new_numel=%lld, storage_offset=%lld" ,
407- base_data_ptr,
408- new_numel,
409- storage_offset);
410-
411- // Create a new tensor view that shares the same underlying storage
412- // This is the correct way to implement reinterpret_tensor - as a view, not a
413- // copy
414- AOTITorchError create_err = aoti_torch_create_tensor_from_blob_v2 (
415- base_data_ptr, // Same underlying data pointer
416- ndim, // New dimensions
417- sizes_ptr, // New sizes
418- strides_ptr, // New strides
419- storage_offset, // Storage offset (will be handled properly now)
420- dtype,
421- device_type,
422- device_index,
423- ret_new_tensor,
424- 0 , // layout (default)
425- nullptr , // opaque_metadata
426- 0 // opaque_metadata_size
427- );
490+ *ret_new_tensor = tensor.get ();
428491
429- if (create_err != Error::Ok) {
430- ET_LOG (Error, " failed to create reinterpreted tensor view" );
431- return create_err;
432- }
492+ // Increment the reference count for this memory address only if it is owned
493+ // by tensor
494+ memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
495+ ? NOT_OWN
496+ : memory_to_n_tensor[data_ptr] + 1 ;
433497
434498 ET_LOG (Debug, " aoti_torch__reinterpret_tensor: successfull" );
435499 return Error::Ok;
436500}
437501
438502// Cleanup function for clearing global state
439503void cleanup_memory () {
440- is_tensor_own_memory.clear ();
441- if (!tensors.empty ()) {
442- ET_LOG (Error, " Warning: tensors not empty during cleanup" );
504+ // Use aoti_torch_delete_tensor_object to properly delete each tensor
505+ // Note: We need to collect tensor pointers first since deletion modifies the
506+ // set
507+ std::vector<Tensor*> tensor_ptrs;
508+ tensor_ptrs.reserve (tensors.size ());
509+ for (const auto & tensor_shared : tensors) {
510+ tensor_ptrs.push_back (tensor_shared.get ());
511+ }
512+
513+ // Now delete each tensor - this will modify the global tensors set
514+ for (Tensor* tensor_ptr : tensor_ptrs) {
515+ aoti_torch_delete_tensor_object (tensor_ptr);
443516 }
444517
518+ // tensors set should now be empty, but ensure it's cleared
519+ tensors.clear ();
520+
445521 // Clean up Metal resources
446522 metal_cleanup_resources ();
523+
524+ ET_LOG (Info, " Cleared all tensors and Metal resources" );
447525}
448526
449527} // extern "C"
0 commit comments