Skip to content

Commit e45f680

Browse files
tensor destroy (#14698)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #14686 by @larryliu0820 ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/larryliu0820/76/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/larryliu0820/76/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/larryliu0820/76/orig @diff-train-skip-merge Co-authored-by: Mengwei Liu <[email protected]>
1 parent 5d29a7d commit e45f680

File tree

4 files changed

+532
-2
lines changed

4 files changed

+532
-2
lines changed

backends/cuda/runtime/shims/memory.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,78 @@ AOTITorchError aoti_torch_empty_strided(
123123
return Error::Ok;
124124
}
125125

126-
// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors
127126
void clear_all_tensors() {
127+
// Use aoti_torch_delete_tensor_object to properly delete each tensor
128+
// Note: We need to collect tensor pointers first since deletion modifies the
129+
// set
130+
auto old_tensors =
131+
std::move(tensors); // tensors is now empty and no need to copy
132+
for (const auto& tensor_shared : old_tensors) {
133+
aoti_torch_delete_tensor_object(tensor_shared.get());
134+
}
135+
136+
// tensors set should now be empty, but ensure it's cleared
128137
tensors.clear();
129138
}
130139

140+
/**
 * Deletes a tracked tensor and frees the data buffer it points at.
 *
 * The tensor must be one of the entries of the global `tensors` container;
 * on success its data buffer is freed and the entry is erased.
 *
 * @param tensor Pointer to the tensor object to delete.
 * @return Error::Ok on success; Error::InvalidArgument when `tensor` is null
 *         or is not tracked in `tensors`. A failing CUDA call is reported via
 *         ET_CUDA_CHECK_OR_RETURN_ERROR, in which case the tensor stays
 *         tracked.
 */
AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
  // Handle null tensor pointer
  if (tensor == nullptr) {
    ET_LOG(Error, "Cannot delete null tensor");
    return Error::InvalidArgument;
  }

  // Single pass: locate the tracking entry whose raw pointer matches.
  auto it = tensors.begin();
  while (it != tensors.end() && it->get() != tensor) {
    ++it;
  }

  // If tensor not found in our tracking, it's invalid
  if (it == tensors.end()) {
    // %p formally expects a void*, hence the cast.
    ET_LOG(Error, "Didn't find tensor %p", static_cast<void*>(tensor));
    return Error::InvalidArgument;
  }

  // Hold our own reference so the tensor stays alive until we are done.
  auto tensor_ptr = *it;
  void* data_ptr = tensor_ptr->mutable_data_ptr();

  // Ask the CUDA runtime what kind of memory the data pointer refers to.
  cudaPointerAttributes attributes{};
  ET_CUDA_CHECK_OR_RETURN_ERROR(
      cudaPointerGetAttributes(&attributes, data_ptr));

  // et tensor does not own data; need to free them manually.
  if (attributes.type == cudaMemoryTypeManaged) {
    // CUDA managed memory: wait for all outstanding device work to complete
    // BEFORE freeing.
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize());
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
  } else {
    // Everything else is treated as CPU memory and freed immediately.
    // NOTE(review): this assumes non-managed buffers always come from a
    // malloc-compatible host allocator; a plain device or pinned-host
    // allocation must not reach free() — confirm against the allocation
    // path.
    free(data_ptr);
  }

  // Remove from the tracking container (this will call the destructor once
  // the last reference goes away).
  tensors.erase(it);
  return Error::Ok;
}
197+
131198
} // extern "C"
132199

133200
} // namespace cuda

backends/cuda/runtime/shims/memory.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,16 @@ AOTITorchError aoti_torch_empty_strided(
4444
int32_t device_index,
4545
Tensor** ret_new_tensor);
4646

47+
/**
48+
* Deletes a tensor object and frees its associated memory.
49+
*
50+
* @param tensor Pointer to the tensor object to be deleted
51+
* @return AOTITorchError error code (Error::Ok on success, or an error code on
52+
* failure)
53+
*/
54+
AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);
55+
4756
// Function to clear all tensors from internal storage
48-
// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors
4957
void clear_all_tensors();
5058

5159
} // extern "C"

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ def define_common_targets():
2828
TARGETS and BUCK files that call this function.
2929
"""
3030
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
31+
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")

0 commit comments

Comments
 (0)