diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 653687fccb7..e10322ad40c 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -286,18 +286,6 @@ class ET_EXPERIMENTAL CudaBackend final i); } - // Clean up GPU tensors that we created (ExecuTorch tensors are always - // CPU, so all GPU tensors are our copies) - for (int i = 0; i < n_inputs; i++) { - // All GPU input tensors were created by us, delete them - aoti_torch_delete_tensor_object(gpu_inputs[i]); - } - - for (int i = 0; i < n_outputs; i++) { - // All GPU output tensors were created by us, delete them - aoti_torch_delete_tensor_object(gpu_outputs[i]); - } - return Error::Ok; } @@ -318,16 +306,13 @@ class ET_EXPERIMENTAL CudaBackend final handle->cuda_stream = nullptr; } - // Delete the container BEFORE closing the shared library - if (handle->container_handle != nullptr) { - AOTIRuntimeError delete_result = - AOTInductorModelContainerDelete(handle->container_handle); - ET_CHECK_OR_LOG_ERROR( - delete_result == Error::Ok, - "Failed to delete AOTInductorModelContainer with error code %d", - delete_result); - handle->container_handle = nullptr; - } + // NOTE: AOTInductorModelContainerDelete does not work correctly with + // multiple .so files. Deleting one container frees shared resources, + // which causes segmentation faults when attempting to delete other + // containers. As a workaround, we skip explicit container deletion + // and defer cleanup to the OS. + // TODO(gasoonjia): Find a proper solution for safe container deletion. + // AOTInductorModelContainerDelete(handle->container_handle); // Now close the shared library if (handle->so_handle != nullptr) { @@ -346,6 +331,7 @@ class ET_EXPERIMENTAL CudaBackend final } delete handle; + clear_all_tensors(); } }; diff --git a/backends/cuda/runtime/shims/memory.cpp b/backends/cuda/runtime/shims/memory.cpp index cbaca68576e..a054169330b 100644 --- a/backends/cuda/runtime/shims/memory.cpp +++ b/backends/cuda/runtime/shims/memory.cpp @@ -271,14 +271,21 @@ void clear_all_tensors() { // Use aoti_torch_delete_tensor_object to properly delete each tensor // Note: We need to collect tensor pointers first since deletion modifies the // set - auto old_tensors = - std::move(tensors); // tensors is now empty and no need to copy - for (const auto& tensor_shared : old_tensors) { - aoti_torch_delete_tensor_object(tensor_shared.get()); + std::vector tensor_ptrs; + tensor_ptrs.reserve(tensors.size()); + for (const auto& tensor_shared : tensors) { + tensor_ptrs.push_back(tensor_shared.get()); + } + + // Now delete each tensor - this will modify the global tensors set + for (Tensor* tensor_ptr : tensor_ptrs) { + aoti_torch_delete_tensor_object(tensor_ptr); } // tensors set should now be empty, but ensure it's cleared tensors.clear(); + + ET_LOG(Info, "Cleared all tensors"); } AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {