Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion backends/cuda/runtime/shims/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,78 @@ AOTITorchError aoti_torch_empty_strided(
return Error::Ok;
}

// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors
void clear_all_tensors() {
// Use aoti_torch_delete_tensor_object to properly delete each tensor
// Note: We need to collect tensor pointers first since deletion modifies the
// set
auto old_tensors =
std::move(tensors); // tensors is now empty and no need to copy
for (const auto& tensor_shared : old_tensors) {
aoti_torch_delete_tensor_object(tensor_shared.get());
}

// tensors set should now be empty, but ensure it's cleared
tensors.clear();
}

/**
 * Deletes a tracked tensor and frees the data buffer it refers to.
 *
 * The tensor is looked up in `tensors`; its data pointer is released
 * according to the memory type reported by cudaPointerGetAttributes(), and
 * the entry is then erased from `tensors`.
 *
 * @param tensor Pointer to the tensor object to be deleted
 * @return Error::Ok on success; Error::InvalidArgument if `tensor` is null or
 * is not present in `tensors`. A failing CUDA call is handled by
 * ET_CUDA_CHECK_OR_RETURN_ERROR, in which case the entry is left in `tensors`.
 */
AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
  // Handle null tensor pointer
  if (tensor == nullptr) {
    ET_LOG(Error, "Cannot delete null tensor");
    return Error::InvalidArgument;
  }

  // Locate the tracking entry once; the same iterator is reused for the erase
  // below, so no second scan is needed.
  auto entry = tensors.begin();
  for (; entry != tensors.end(); ++entry) {
    if (entry->get() == tensor) {
      break;
    }
  }

  // If tensor not found in our tracking, it's invalid
  if (entry == tensors.end()) {
    ET_LOG(Error, "Didn't find tensor %p", static_cast<void*>(tensor));
    return Error::InvalidArgument;
  }

  void* data_ptr = tensor->mutable_data_ptr();

  // Determine what kind of memory the data pointer refers to
  cudaPointerAttributes attributes{};
  ET_CUDA_CHECK_OR_RETURN_ERROR(
      cudaPointerGetAttributes(&attributes, data_ptr));

  // et tensor does not own data; need to free them manually.
  // NOTE(review): every non-managed pointer is treated as plain CPU memory
  // and passed to free(). This presumably matches how the buffers are
  // allocated elsewhere in this file -- confirm that no cudaMalloc /
  // cudaMallocHost memory can reach this path.
  if (attributes.type == cudaMemoryTypeManaged) {
    // CUDA managed memory: wait for all outstanding work to complete BEFORE
    // freeing.
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaDeviceSynchronize());
    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaFree(data_ptr));
  } else {
    // CPU memory: free immediately
    free(data_ptr);
  }

  // Remove from the container (this destroys the tensor object if this was
  // the last reference to it).
  tensors.erase(entry);
  return Error::Ok;
}

} // extern "C"

} // namespace cuda
Expand Down
10 changes: 9 additions & 1 deletion backends/cuda/runtime/shims/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,16 @@ AOTITorchError aoti_torch_empty_strided(
int32_t device_index,
Tensor** ret_new_tensor);

/**
 * Deletes a tensor object and frees its associated memory.
 *
 * The tensor must be one that is currently tracked in this shim layer's
 * internal storage; null or untracked pointers are rejected.
 *
 * @param tensor Pointer to the tensor object to be deleted
 * @return AOTITorchError error code: Error::Ok on success,
 * Error::InvalidArgument if `tensor` is null or is not tracked, or another
 * error code on failure (e.g. when a CUDA runtime call fails)
 */
AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);

// Clears all tensors from internal storage.
// TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
void clear_all_tensors();

} // extern "C"
Expand Down
1 change: 1 addition & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ def define_common_targets():
TARGETS and BUCK files that call this function.
"""
cuda_shim_cpp_unittest("aoti_torch_empty_strided")
cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
Loading
Loading