Skip to content

Commit 81a3acc

Browse files
Metal backend: track tensors (#15342)
This pull request refactors and improves memory management and reference counting for tensors in the Metal backend. The main change replaces the previous boolean ownership tracking (`is_tensor_own_memory`) with a more robust reference-counting map (`memory_to_n_tensor`), which tracks how many tensors share each memory address and whether that memory is owned at all. Additional improvements include safer tensor deletion, proper Metal buffer deallocation, and consistent handling of tensor views and resource cleanup.
1 parent b38028d commit 81a3acc

File tree

5 files changed

+186
-81
lines changed

5 files changed

+186
-81
lines changed

backends/apple/metal/runtime/shims/et_metal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ extern "C" {
354354

355355
// Memory management functions for Metal
356356
void* metal_allocate_buffer(long bytes);
357+
void metal_deallocate_buffer(void* ptr);
357358
bool metal_is_device_pointer(void* ptr);
358359
int metal_copy_memory(
359360
void* dst,

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,21 @@ void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
8686
}
8787
}
8888

89+
void metal_deallocate_buffer(void* ptr) {
90+
@autoreleasepool {
91+
auto it = ptr_to_mtl_buffer.find(ptr);
92+
if (it != ptr_to_mtl_buffer.end()) {
93+
id<MTLBuffer> buffer = it->second;
94+
[buffer release];
95+
ptr_to_mtl_buffer.erase(it);
96+
ET_LOG(Debug, "Deallocated Metal buffer for pointer %p", ptr);
97+
ptr = nullptr;
98+
} else {
99+
ET_LOG(Error, "Failed to find Metal buffer for pointer %p", ptr);
100+
}
101+
}
102+
}
103+
89104
void metal_cleanup_resources() {
90105
if (!ptr_to_mtl_buffer.empty()) {
91106
@autoreleasepool {

backends/apple/metal/runtime/shims/et_metal_ops.mm

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -736,9 +736,12 @@ AOTITorchError aoti_torch_mps_convolution(
736736
throw std::runtime_error("Tensor size mismatch");
737737
}
738738

739-
// Store the tensor handle - mark that we own the memory since we manually allocated it with malloc
739+
// Store the tensor handle - mark that we own the memory since we manually allocated it
740740
*ret0 = output_tensor_handle;
741-
is_tensor_own_memory[et_tensor] = true; // We allocated the GPU memory
741+
// Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
742+
// The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
743+
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
744+
memory_to_n_tensor[tensor_data] = 1;
742745

743746
ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);
744747

@@ -1327,10 +1330,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
13271330
}
13281331

13291332
// Mark that we own the memory for these tensors
1330-
auto* out_et_tensor = reinterpret_cast<Tensor*>(out_tensor_handle);
1331-
auto* attn_et_tensor = reinterpret_cast<Tensor*>(attn_tensor_handle);
1332-
is_tensor_own_memory[out_et_tensor] = true;
1333-
is_tensor_own_memory[attn_et_tensor] = true;
1333+
// Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
1334+
// The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
1335+
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
1336+
memory_to_n_tensor[out_contents_ptr] = 1;
1337+
memory_to_n_tensor[attn_contents_ptr] = 1;
13341338

13351339
// Set output tensor handles
13361340
*ret0 = out_tensor_handle;

backends/apple/metal/runtime/shims/memory.cpp

Lines changed: 159 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ using namespace executorch::backends::aoti;
3131

3232
// Global storage for tensors and their metadata
3333
std::unordered_set<std::shared_ptr<Tensor>> tensors;
34-
std::unordered_map<Tensor*, bool> is_tensor_own_memory;
34+
35+
// Reference counting for memory addresses
36+
// Maps memory address to number of tensors using it
37+
// Special value: NOT_OWN (-1) means tensor never owns the memory
38+
constexpr int32_t NOT_OWN = -1;
39+
std::unordered_map<void*, int32_t> memory_to_n_tensor;
3540

3641
extern "C" {
3742

@@ -110,7 +115,18 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
110115
// Store the tensor so it doesn't get destroyed
111116
tensors.insert(tensor);
112117
*ret_new_tensor = tensor.get();
113-
is_tensor_own_memory[tensor.get()] = false;
118+
119+
// Check if this memory address is already being tracked
120+
auto memory_it = memory_to_n_tensor.find(adjusted_data);
121+
ET_CHECK_OR_RETURN_ERROR(
122+
memory_it == memory_to_n_tensor.end(),
123+
InvalidArgument,
124+
"Memory address %p is already being tracked by another tensor",
125+
adjusted_data);
126+
127+
// Mark this memory as NOT_OWN since tensor created from blob never owns
128+
// memory
129+
memory_to_n_tensor[adjusted_data] = NOT_OWN;
114130

115131
ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: successfull");
116132
return Error::Ok;
@@ -192,59 +208,98 @@ AOTITorchError aoti_torch_empty_strided(
192208
// Store the tensor so it doesn't get destroyed
193209
tensors.insert(tensor);
194210
*ret_new_tensor = tensor.get();
195-
is_tensor_own_memory[tensor.get()] = true;
211+
212+
// This tensor owns the memory it allocated, set reference count to 1
213+
memory_to_n_tensor[ptr] = 1;
196214

197215
ET_LOG(Debug, "aoti_torch_empty_strided: successfull");
198216
return Error::Ok;
199217
}
200218

201219
AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
202220
ET_LOG(Debug, "aoti_torch_delete_tensor_object: entered");
203-
// Find tensor in the set
221+
222+
// Handle null tensor pointer
223+
if (tensor == nullptr) {
224+
ET_LOG(Debug, "aoti_torch_delete_tensor_object: null tensor");
225+
return Error::Ok;
226+
}
227+
228+
// Check if tensor exists in our tracking
229+
bool found_in_tensors = false;
204230
for (auto it = tensors.begin(); it != tensors.end(); ++it) {
205231
if (it->get() == tensor) {
206-
auto tensor_ptr = *it;
232+
found_in_tensors = true;
233+
break;
234+
}
235+
}
207236

208-
// Check ownership before cleaning up
209-
auto ownership_it = is_tensor_own_memory.find(tensor);
210-
bool owns_memory = (ownership_it != is_tensor_own_memory.end())
211-
? ownership_it->second
212-
: false;
237+
// If tensor not found in our tracking, it's invalid
238+
ET_CHECK_OR_RETURN_ERROR(
239+
found_in_tensors, InvalidArgument, "Didn't find tensor %p", tensor);
213240

214-
// Clean up ownership metadata
215-
is_tensor_own_memory.erase(tensor);
241+
// Find and delete the tensor
242+
for (auto it = tensors.begin(); it != tensors.end(); ++it) {
243+
if (it->get() == tensor) {
244+
// Get the tensor before erasing
245+
auto tensor_ptr = *it;
246+
void* data_ptr = tensor_ptr->mutable_data_ptr();
216247

217-
if (owns_memory) {
218-
// et tensor owns the memory; need to free it manually
219-
void* data_ptr = tensor_ptr->mutable_data_ptr();
248+
// Find the reference count for this memory address
249+
auto memory_it = memory_to_n_tensor.find(data_ptr);
250+
if (memory_it != memory_to_n_tensor.end()) {
251+
int32_t ref_count = memory_it->second;
220252

221-
// Check if it's Metal GPU memory
222-
if (metal_is_device_pointer(data_ptr)) {
223-
// This is Metal GPU memory - the Metal helper will handle cleanup
224-
// Metal buffers are automatically managed by ARC when the buffer is
225-
// released
253+
if (ref_count == NOT_OWN) {
254+
// Tensor never owned the memory, skip freeing
255+
// Just remove tensor from tracking
226256
tensors.erase(it);
227257
ET_LOG(
228258
Debug,
229-
"aoti_torch_delete_tensor_object: successfull (Metal GPU memory)");
259+
"aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free");
230260
return Error::Ok;
261+
} else if (ref_count == 1) {
262+
// Only current tensor using this memory, free it
263+
// Check if it's Metal GPU memory
264+
if (metal_is_device_pointer(data_ptr)) {
265+
metal_deallocate_buffer(data_ptr);
266+
} else {
267+
// This is CPU memory - free immediately
268+
free(data_ptr);
269+
data_ptr = nullptr;
270+
ET_LOG(
271+
Debug, "aoti_torch_delete_tensor_object: freeing CPU memory");
272+
}
273+
274+
// Remove from memory tracking
275+
memory_to_n_tensor.erase(memory_it);
276+
} else if (ref_count > 1) {
277+
// Other tensors still using this memory, just decrement count
278+
memory_to_n_tensor[data_ptr] = ref_count - 1;
279+
ET_LOG(
280+
Debug,
281+
"aoti_torch_delete_tensor_object: decremented ref count from %d to %d",
282+
ref_count,
283+
ref_count - 1);
231284
}
232-
233-
// This is CPU memory - free immediately
234-
free(data_ptr);
285+
} else {
286+
ET_CHECK_OR_RETURN_ERROR(
287+
false,
288+
Internal,
289+
"Internal error: memory not found during deletion");
235290
}
236-
// else: Don't free memory since the tensor doesn't own it
237291

238-
// Remove from set (this will call the destructor if it's the last
292+
// Remove tensor from set (this will call the destructor if it's the last
239293
// reference)
240294
tensors.erase(it);
241-
ET_LOG(
242-
Debug, "aoti_torch_delete_tensor_object: successfull (CPU memory)");
295+
ET_LOG(Debug, "aoti_torch_delete_tensor_object: successfull");
243296
return Error::Ok;
244297
}
245298
}
246-
ET_LOG(Error, "Didn't find tensor %p", tensor);
247-
return Error::InvalidArgument;
299+
300+
// This should never be reached since we found it above
301+
ET_CHECK_OR_RETURN_ERROR(
302+
false, Internal, "Internal error: tensor not found after validation");
248303
}
249304

250305
AOTITorchError aoti_torch_copy_(
@@ -375,75 +430,105 @@ AOTITorchError aoti_torch__reinterpret_tensor(
375430
InvalidArgument,
376431
"aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
377432

433+
// Check if storage_offset is not 0 - return error if not
434+
ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset));
435+
436+
// Get the device info from the source tensor to perform device_index
437+
// validation
438+
int32_t device_type = 0;
439+
int32_t device_index = 0;
440+
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
441+
442+
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
443+
444+
// Ensure device_index is always 0
445+
ET_CHECK_OR_RETURN_ERROR(
446+
device_index == 0,
447+
InvalidArgument,
448+
"device_index must be 0, got: %d",
449+
device_index);
450+
378451
// Get the dtype from the source tensor
379452
int32_t dtype = 0;
380453
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype));
381454

382455
// Validate dtype using SupportedDTypes
383456
ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
384457

385-
int32_t device_type = 0;
386-
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
458+
// Get the original data pointer from the source tensor
459+
void* data_ptr = self->mutable_data_ptr();
460+
ET_CHECK_OR_RETURN_ERROR(
461+
data_ptr != nullptr,
462+
InvalidArgument,
463+
"Source tensor has null data pointer");
387464

388-
int32_t device_index = 0;
389-
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
465+
// Check if the given memory is in the map, if not return error
466+
auto memory_it = memory_to_n_tensor.find(data_ptr);
467+
ET_CHECK_OR_RETURN_ERROR(
468+
memory_it != memory_to_n_tensor.end(),
469+
InvalidArgument,
470+
"Memory address %p is not being tracked by reference counting system",
471+
data_ptr);
472+
473+
// Convert sizes using utility function from utils.h
474+
std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
475+
476+
// Convert strides using utility function from utils.h
477+
std::vector<aten::StridesType> strides =
478+
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
479+
480+
// Create new tensor view that reinterprets the same memory with different
481+
// shape/strides This creates a view, not a copy - the data pointer is shared
482+
std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
483+
data_ptr, // Reuse the same memory from source tensor
484+
sizes, // New sizes with explicit SizesType
485+
strides, // New strides with explicit StridesType
486+
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
487+
);
390488

391-
// Get the base data pointer from the source tensor
392-
void* base_data_ptr = self->mutable_data_ptr();
393489
ET_CHECK_OR_RETURN_ERROR(
394-
base_data_ptr != nullptr,
490+
tensor != nullptr,
395491
InvalidArgument,
396-
"Source tensor has null data pointer");
492+
"Failed to create reinterpreted tensor view");
397493

398-
// Calculate new tensor size in elements for logging
399-
int64_t new_numel = 1;
400-
for (int64_t i = 0; i < ndim; i++) {
401-
new_numel *= sizes_ptr[i];
402-
}
494+
// Store the tensor so it doesn't get destroyed
495+
tensors.insert(tensor);
403496

404-
ET_LOG(
405-
Debug,
406-
"aoti_torch__reinterpret_tensor: base_data_ptr=%p, new_numel=%lld, storage_offset=%lld",
407-
base_data_ptr,
408-
new_numel,
409-
storage_offset);
410-
411-
// Create a new tensor view that shares the same underlying storage
412-
// This is the correct way to implement reinterpret_tensor - as a view, not a
413-
// copy
414-
AOTITorchError create_err = aoti_torch_create_tensor_from_blob_v2(
415-
base_data_ptr, // Same underlying data pointer
416-
ndim, // New dimensions
417-
sizes_ptr, // New sizes
418-
strides_ptr, // New strides
419-
storage_offset, // Storage offset (will be handled properly now)
420-
dtype,
421-
device_type,
422-
device_index,
423-
ret_new_tensor,
424-
0, // layout (default)
425-
nullptr, // opaque_metadata
426-
0 // opaque_metadata_size
427-
);
497+
*ret_new_tensor = tensor.get();
428498

429-
if (create_err != Error::Ok) {
430-
ET_LOG(Error, "failed to create reinterpreted tensor view");
431-
return create_err;
432-
}
499+
// Increment the reference count for this memory address only if it is owned
500+
// by tensor
501+
memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
502+
? NOT_OWN
503+
: memory_to_n_tensor[data_ptr] + 1;
433504

434505
ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull");
435506
return Error::Ok;
436507
}
437508

438509
// Cleanup function for clearing global state
439510
void cleanup_memory() {
440-
is_tensor_own_memory.clear();
441-
if (!tensors.empty()) {
442-
ET_LOG(Error, "Warning: tensors not empty during cleanup");
511+
// Use aoti_torch_delete_tensor_object to properly delete each tensor
512+
// Note: We need to collect tensor pointers first since deletion modifies the
513+
// set
514+
std::vector<Tensor*> tensor_ptrs;
515+
tensor_ptrs.reserve(tensors.size());
516+
for (const auto& tensor_shared : tensors) {
517+
tensor_ptrs.push_back(tensor_shared.get());
443518
}
444519

520+
// Now delete each tensor - this will modify the global tensors set
521+
for (Tensor* tensor_ptr : tensor_ptrs) {
522+
aoti_torch_delete_tensor_object(tensor_ptr);
523+
}
524+
525+
// tensors set should now be empty, but ensure it's cleared
526+
tensors.clear();
527+
445528
// Clean up Metal resources
446529
metal_cleanup_resources();
530+
531+
ET_LOG(Info, "Cleared all tensors and Metal resources");
447532
}
448533

449534
} // extern "C"

backends/apple/metal/runtime/shims/memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace metal {
2222
extern "C" {
2323

2424
// Global storage declarations
25-
extern std::unordered_map<Tensor*, bool> is_tensor_own_memory;
25+
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
2626
extern std::unordered_set<std::shared_ptr<Tensor>> tensors;
2727

2828
// Memory-related operations

0 commit comments

Comments (0)