Skip to content

Commit 00ba522

Browse files
Update
[ghstack-poisoned]
1 parent d8e6d13 commit 00ba522

File tree

5 files changed

+181
-83
lines changed

5 files changed

+181
-83
lines changed

backends/apple/metal/runtime/shims/et_metal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ extern "C" {
354354

355355
// Memory management functions for Metal
356356
void* metal_allocate_buffer(long bytes);
357+
void metal_deallocate_buffer(void* ptr);
357358
bool metal_is_device_pointer(void* ptr);
358359
int metal_copy_memory(
359360
void* dst,

backends/apple/metal/runtime/shims/et_metal.mm

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,21 @@ void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
8686
}
8787
}
8888

89+
// Releases the Metal buffer registered for `ptr` and removes it from the
// ptr_to_mtl_buffer tracking map.
//
// ptr: the address used as the key in ptr_to_mtl_buffer. It is passed by
//      value, so the caller's copy of the pointer is not modified; callers
//      must stop using it after this call.
//
// If `ptr` is not present in the map, an error is logged and nothing is
// released.
void metal_deallocate_buffer(void* ptr) {
  @autoreleasepool {
    auto it = ptr_to_mtl_buffer.find(ptr);
    if (it == ptr_to_mtl_buffer.end()) {
      ET_LOG(Error, "Failed to find Metal buffer for pointer %p", ptr);
      return;
    }

    // Release the buffer first, then drop the map entry that referred to it.
    id<MTLBuffer> buffer = it->second;
    [buffer release];
    ptr_to_mtl_buffer.erase(it);
    ET_LOG(Debug, "Deallocated Metal buffer for pointer %p", ptr);
  }
}
103+
89104
void metal_cleanup_resources() {
90105
if (!ptr_to_mtl_buffer.empty()) {
91106
@autoreleasepool {

backends/apple/metal/runtime/shims/et_metal_ops.mm

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -736,9 +736,12 @@ AOTITorchError aoti_torch_mps_convolution(
736736
throw std::runtime_error("Tensor size mismatch");
737737
}
738738

739-
// Store the tensor handle - mark that we own the memory since we manually allocated it with malloc
739+
// Store the tensor handle - mark that we own the memory since we manually allocated it
740740
*ret0 = output_tensor_handle;
741-
is_tensor_own_memory[et_tensor] = true; // We allocated the GPU memory
741+
// Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
742+
// The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
743+
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
744+
memory_to_n_tensor[tensor_data] = 1;
742745

743746
ET_LOG(Debug, "aoti_torch_mps_convolution: Created output tensor with %zu elements using MPSGraph", actual_numel);
744747

@@ -1327,10 +1330,11 @@ AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(
13271330
}
13281331

13291332
// Mark that we own the memory for these tensors
1330-
auto* out_et_tensor = reinterpret_cast<Tensor*>(out_tensor_handle);
1331-
auto* attn_et_tensor = reinterpret_cast<Tensor*>(attn_tensor_handle);
1332-
is_tensor_own_memory[out_et_tensor] = true;
1333-
is_tensor_own_memory[attn_et_tensor] = true;
1333+
// Note: memory_to_n_tensor is managed automatically in aoti_torch_create_tensor_from_blob_v2
1334+
// The function sets it to NOT_OWN, but we need to change it to 1 since we allocated it
1335+
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
1336+
memory_to_n_tensor[out_contents_ptr] = 1;
1337+
memory_to_n_tensor[attn_contents_ptr] = 1;
13341338

13351339
// Set output tensor handles
13361340
*ret0 = out_tensor_handle;

backends/apple/metal/runtime/shims/memory.cpp

Lines changed: 154 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@ using namespace executorch::backends::aoti;
3131

3232
// Global storage for tensors and their metadata
3333
std::unordered_set<std::shared_ptr<Tensor>> tensors;
34-
std::unordered_map<Tensor*, bool> is_tensor_own_memory;
34+
35+
// Reference counting for memory addresses
36+
// Maps memory address to number of tensors using it
37+
// Special value: NOT_OWN (-1) means tensor never owns the memory
38+
constexpr int32_t NOT_OWN = -1;
39+
std::unordered_map<void*, int32_t> memory_to_n_tensor;
3540

3641
extern "C" {
3742

@@ -110,7 +115,18 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
110115
// Store the tensor so it doesn't get destroyed
111116
tensors.insert(tensor);
112117
*ret_new_tensor = tensor.get();
113-
is_tensor_own_memory[tensor.get()] = false;
118+
119+
// Check if this memory address is already being tracked
120+
auto memory_it = memory_to_n_tensor.find(adjusted_data);
121+
ET_CHECK_OR_RETURN_ERROR(
122+
memory_it == memory_to_n_tensor.end(),
123+
InvalidArgument,
124+
"Memory address %p is already being tracked by another tensor",
125+
adjusted_data);
126+
127+
// Mark this memory as NOT_OWN since tensor created from blob never owns
128+
// memory
129+
memory_to_n_tensor[adjusted_data] = NOT_OWN;
114130

115131
ET_LOG(Debug, "aoti_torch_create_tensor_from_blob_v2: successfull");
116132
return Error::Ok;
@@ -192,59 +208,91 @@ AOTITorchError aoti_torch_empty_strided(
192208
// Store the tensor so it doesn't get destroyed
193209
tensors.insert(tensor);
194210
*ret_new_tensor = tensor.get();
195-
is_tensor_own_memory[tensor.get()] = true;
211+
212+
// This tensor owns the memory it allocated, set reference count to 1
213+
memory_to_n_tensor[ptr] = 1;
196214

197215
ET_LOG(Debug, "aoti_torch_empty_strided: successfull");
198216
return Error::Ok;
199217
}
200218

201219
// Deletes a tensor tracked in the global `tensors` set and, based on the
// reference count stored in `memory_to_n_tensor` for its data pointer,
// decides whether the backing memory is released.
//
// tensor: handle of the tensor to delete; nullptr is accepted as a no-op.
//
// Returns:
//   Error::Ok              - tensor was null, or was removed from tracking.
//   Error::InvalidArgument - tensor is not present in `tensors`.
//   Error::Internal        - the tensor's data pointer has no entry in
//                            `memory_to_n_tensor`; the tensor is left in
//                            `tensors` in that case.
//
// Reference count handling:
//   NOT_OWN - memory is never freed here; the map entry is left in place.
//   1       - last user: memory is released (Metal buffer or free()) and the
//             map entry is erased.
//   > 1     - other tensors still use the memory; the count is decremented.
AOTITorchError aoti_torch_delete_tensor_object(AOTITensorHandle tensor) {
  ET_LOG(Debug, "aoti_torch_delete_tensor_object: entered");

  // Handle null tensor pointer
  if (tensor == nullptr) {
    ET_LOG(Debug, "aoti_torch_delete_tensor_object: null tensor");
    return Error::Ok;
  }

  // Locate the owning shared_ptr with a single scan; the set is keyed by
  // shared_ptr, so a raw pointer cannot be looked up directly.
  auto it = tensors.begin();
  while (it != tensors.end() && it->get() != tensor) {
    ++it;
  }

  // If tensor not found in our tracking, it's invalid
  ET_CHECK_OR_RETURN_ERROR(
      it != tensors.end(), InvalidArgument, "Didn't find tensor %p", tensor);

  void* data_ptr = (*it)->mutable_data_ptr();

  // Find the reference count for this memory address
  auto memory_it = memory_to_n_tensor.find(data_ptr);
  ET_CHECK_OR_RETURN_ERROR(
      memory_it != memory_to_n_tensor.end(),
      Internal,
      "Internal error: memory not found during deletion");

  const int32_t ref_count = memory_it->second;

  if (ref_count == NOT_OWN) {
    // Tensor never owned the memory, skip freeing.
    // Just remove tensor from tracking.
    tensors.erase(it);
    ET_LOG(Debug, "aoti_torch_delete_tensor_object: tensor doesn't own memory, skipping free");
    return Error::Ok;
  }

  if (ref_count == 1) {
    // Only the current tensor is using this memory, free it
    if (metal_is_device_pointer(data_ptr)) {
      metal_deallocate_buffer(data_ptr);
    } else {
      // This is CPU memory - free immediately
      free(data_ptr);
      ET_LOG(Debug, "aoti_torch_delete_tensor_object: freeing CPU memory");
    }

    // Remove from memory tracking
    memory_to_n_tensor.erase(memory_it);
  } else if (ref_count > 1) {
    // Other tensors still using this memory, just decrement count
    memory_it->second = ref_count - 1;
    ET_LOG(Debug, "aoti_torch_delete_tensor_object: decremented ref count from %d to %d", ref_count, ref_count - 1);
  }

  // Remove tensor from set (this will call the destructor if it's the last
  // reference)
  tensors.erase(it);
  ET_LOG(Debug, "aoti_torch_delete_tensor_object: successfull");
  return Error::Ok;
}
249297

250298
AOTITorchError aoti_torch_copy_(
@@ -375,75 +423,105 @@ AOTITorchError aoti_torch__reinterpret_tensor(
375423
InvalidArgument,
376424
"aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
377425

426+
// Check if storage_offset is not 0 - return error if not
427+
ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset));
428+
429+
// Get the device info from the source tensor to perform device_index
430+
// validation
431+
int32_t device_type = 0;
432+
int32_t device_index = 0;
433+
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
434+
435+
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
436+
437+
// Ensure device_index is always 0
438+
ET_CHECK_OR_RETURN_ERROR(
439+
device_index == 0,
440+
InvalidArgument,
441+
"device_index must be 0, got: %d",
442+
device_index);
443+
378444
// Get the dtype from the source tensor
379445
int32_t dtype = 0;
380446
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(self, &dtype));
381447

382448
// Validate dtype using SupportedDTypes
383449
ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
384450

385-
int32_t device_type = 0;
386-
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_type(self, &device_type));
451+
// Get the original data pointer from the source tensor
452+
void* data_ptr = self->mutable_data_ptr();
453+
ET_CHECK_OR_RETURN_ERROR(
454+
data_ptr != nullptr,
455+
InvalidArgument,
456+
"Source tensor has null data pointer");
387457

388-
int32_t device_index = 0;
389-
ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_device_index(self, &device_index));
458+
// Check if the given memory is in the map, if not return error
459+
auto memory_it = memory_to_n_tensor.find(data_ptr);
460+
ET_CHECK_OR_RETURN_ERROR(
461+
memory_it != memory_to_n_tensor.end(),
462+
InvalidArgument,
463+
"Memory address %p is not being tracked by reference counting system",
464+
data_ptr);
465+
466+
// Convert sizes using utility function from utils.h
467+
std::vector<aten::SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
468+
469+
// Convert strides using utility function from utils.h
470+
std::vector<aten::StridesType> strides =
471+
convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
472+
473+
// Create new tensor view that reinterprets the same memory with different
474+
// shape/strides This creates a view, not a copy - the data pointer is shared
475+
std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
476+
data_ptr, // Reuse the same memory from source tensor
477+
sizes, // New sizes with explicit SizesType
478+
strides, // New strides with explicit StridesType
479+
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
480+
);
390481

391-
// Get the base data pointer from the source tensor
392-
void* base_data_ptr = self->mutable_data_ptr();
393482
ET_CHECK_OR_RETURN_ERROR(
394-
base_data_ptr != nullptr,
483+
tensor != nullptr,
395484
InvalidArgument,
396-
"Source tensor has null data pointer");
485+
"Failed to create reinterpreted tensor view");
397486

398-
// Calculate new tensor size in elements for logging
399-
int64_t new_numel = 1;
400-
for (int64_t i = 0; i < ndim; i++) {
401-
new_numel *= sizes_ptr[i];
402-
}
487+
// Store the tensor so it doesn't get destroyed
488+
tensors.insert(tensor);
403489

404-
ET_LOG(
405-
Debug,
406-
"aoti_torch__reinterpret_tensor: base_data_ptr=%p, new_numel=%lld, storage_offset=%lld",
407-
base_data_ptr,
408-
new_numel,
409-
storage_offset);
410-
411-
// Create a new tensor view that shares the same underlying storage
412-
// This is the correct way to implement reinterpret_tensor - as a view, not a
413-
// copy
414-
AOTITorchError create_err = aoti_torch_create_tensor_from_blob_v2(
415-
base_data_ptr, // Same underlying data pointer
416-
ndim, // New dimensions
417-
sizes_ptr, // New sizes
418-
strides_ptr, // New strides
419-
storage_offset, // Storage offset (will be handled properly now)
420-
dtype,
421-
device_type,
422-
device_index,
423-
ret_new_tensor,
424-
0, // layout (default)
425-
nullptr, // opaque_metadata
426-
0 // opaque_metadata_size
427-
);
490+
*ret_new_tensor = tensor.get();
428491

429-
if (create_err != Error::Ok) {
430-
ET_LOG(Error, "failed to create reinterpreted tensor view");
431-
return create_err;
432-
}
492+
// Increment the reference count for this memory address only if it is owned
493+
// by tensor
494+
memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
495+
? NOT_OWN
496+
: memory_to_n_tensor[data_ptr] + 1;
433497

434498
ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull");
435499
return Error::Ok;
436500
}
437501

438502
// Cleanup function for clearing global state
439503
void cleanup_memory() {
440-
is_tensor_own_memory.clear();
441-
if (!tensors.empty()) {
442-
ET_LOG(Error, "Warning: tensors not empty during cleanup");
504+
// Use aoti_torch_delete_tensor_object to properly delete each tensor
505+
// Note: We need to collect tensor pointers first since deletion modifies the
506+
// set
507+
std::vector<Tensor*> tensor_ptrs;
508+
tensor_ptrs.reserve(tensors.size());
509+
for (const auto& tensor_shared : tensors) {
510+
tensor_ptrs.push_back(tensor_shared.get());
511+
}
512+
513+
// Now delete each tensor - this will modify the global tensors set
514+
for (Tensor* tensor_ptr : tensor_ptrs) {
515+
aoti_torch_delete_tensor_object(tensor_ptr);
443516
}
444517

518+
// tensors set should now be empty, but ensure it's cleared
519+
tensors.clear();
520+
445521
// Clean up Metal resources
446522
metal_cleanup_resources();
523+
524+
ET_LOG(Info, "Cleared all tensors and Metal resources");
447525
}
448526

449527
} // extern "C"

backends/apple/metal/runtime/shims/memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace metal {
2222
extern "C" {
2323

2424
// Global storage declarations
25-
extern std::unordered_map<Tensor*, bool> is_tensor_own_memory;
25+
extern std::unordered_map<void*, int32_t> memory_to_n_tensor;
2626
extern std::unordered_set<std::shared_ptr<Tensor>> tensors;
2727

2828
// Memory-related operations

0 commit comments

Comments
 (0)