|
5 | 5 | #include <mscclpp/gpu.hpp> |
6 | 6 | #include <mscclpp/gpu_utils.hpp> |
7 | 7 |
|
| 8 | +static inline bool isCudaTeardownError(cudaError_t err) { |
| 9 | +#if defined(__HIP_PLATFORM_AMD__) |
| 10 | + return err == cudaErrorContextIsDestroyed || err == cudaErrorInvalidDevice; |
| 11 | +#else // !defined(__HIP_PLATFORM_AMD__) |
| 12 | + return err == cudaErrorCudartUnloading || err == cudaErrorContextIsDestroyed || err == cudaErrorInitializationError || |
| 13 | + err == cudaErrorInvalidDevice; |
| 14 | +#endif // !defined(__HIP_PLATFORM_AMD__) |
| 15 | +} |
| 16 | + |
| 17 | +static inline bool isCuTeardownError(CUresult r) { |
| 18 | + return r == CUDA_ERROR_DEINITIALIZED || r == CUDA_ERROR_CONTEXT_IS_DESTROYED; |
| 19 | +} |
| 20 | + |
// Like MSCCLPP_CUDATHROW, but silently swallows runtime-API errors that only
// occur during process/runtime teardown (see isCudaTeardownError). The sticky
// last-error is cleared in that case so later runtime calls are not poisoned.
//
// Fixes vs. previous revision: `__e` is a reserved identifier in C++
// (any name containing a double underscore belongs to the implementation),
// and `cmd` must be parenthesized when expanded so comma expressions and
// low-precedence operators in the argument bind correctly.
#define MSCCLPP_CUDATHROW_IGNORE_TEARDOWN(cmd) \
  do {                                         \
    cudaError_t err_ = (cmd);                  \
    if (isCudaTeardownError(err_)) {           \
      (void)cudaGetLastError();                \
    } else {                                   \
      MSCCLPP_CUDATHROW(err_);                 \
    }                                          \
  } while (false)
| 30 | + |
// Driver-API analogue of MSCCLPP_CUDATHROW_IGNORE_TEARDOWN: throws via
// MSCCLPP_CUTHROW unless the result is a teardown-only error code
// (see isCuTeardownError), in which case it is ignored.
//
// Fixes vs. previous revision: `__e` is a reserved identifier in C++
// (double underscore), and `cmd` is parenthesized on expansion.
#define MSCCLPP_CUTHROW_IGNORE_TEARDOWN(cmd) \
  do {                                       \
    CUresult res_ = (cmd);                   \
    if (!isCuTeardownError(res_)) {          \
      MSCCLPP_CUTHROW(res_);                 \
    }                                        \
  } while (false)
| 38 | + |
8 | 39 | namespace mscclpp { |
9 | 40 |
|
10 | | -AvoidCudaGraphCaptureGuard::AvoidCudaGraphCaptureGuard() : mode_(cudaStreamCaptureModeRelaxed) { |
11 | | - MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode_)); |
| 41 | +AvoidCudaGraphCaptureGuard::AvoidCudaGraphCaptureGuard() : mode_(cudaStreamCaptureModeRelaxed), active_(true) { |
| 42 | + cudaError_t res = cudaThreadExchangeStreamCaptureMode(&mode_); |
| 43 | + if (isCudaTeardownError(res)) { |
| 44 | + // Runtime is going away; just mark inactive so destructor skips restoring. |
| 45 | + active_ = false; |
| 46 | + (void)cudaGetLastError(); |
| 47 | + } else { |
| 48 | + MSCCLPP_CUDATHROW(res); |
| 49 | + } |
12 | 50 | } |
13 | 51 |
|
14 | | -AvoidCudaGraphCaptureGuard::~AvoidCudaGraphCaptureGuard() { (void)cudaThreadExchangeStreamCaptureMode(&mode_); } |
| 52 | +AvoidCudaGraphCaptureGuard::~AvoidCudaGraphCaptureGuard() { |
| 53 | + if (!active_) return; |
| 54 | + (void)cudaThreadExchangeStreamCaptureMode(&mode_); |
| 55 | +} |
15 | 56 |
|
16 | 57 | CudaStreamWithFlags::CudaStreamWithFlags() : stream_(nullptr) { MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId_)); } |
17 | 58 |
|
@@ -185,25 +226,25 @@ void* gpuCallocPhysical(size_t bytes, size_t gran, size_t align) { |
185 | 226 |
|
186 | 227 | void gpuFree(void* ptr) { |
187 | 228 | AvoidCudaGraphCaptureGuard cgcGuard; |
188 | | - MSCCLPP_CUDATHROW(cudaFree(ptr)); |
| 229 | + MSCCLPP_CUDATHROW_IGNORE_TEARDOWN(cudaFree(ptr)); |
189 | 230 | } |
190 | 231 |
|
191 | 232 | void gpuFreeHost(void* ptr) { |
192 | 233 | AvoidCudaGraphCaptureGuard cgcGuard; |
193 | | - MSCCLPP_CUDATHROW(cudaFreeHost(ptr)); |
| 234 | + MSCCLPP_CUDATHROW_IGNORE_TEARDOWN(cudaFreeHost(ptr)); |
194 | 235 | } |
195 | 236 |
|
#if (CUDA_NVLS_API_AVAILABLE)
// Frees memory allocated through the driver's physical-memory (cuMem*) path:
// release the allocation handle, unmap the VA range, and free the reserved
// address range. Note cuMemRelease appears twice by design: the first release
// balances the reference added by cuMemRetainAllocationHandle; the second
// drops the allocation's original reference.
//
// Fix vs. previous revision: if cuMemRetainAllocationHandle or
// cuMemGetAddressRange failed with a swallowed teardown error, the code went
// on to use an uninitialized `handle` and a zero `size`. We now initialize
// `handle` and return early at the two points whose outputs the remaining
// calls depend on — during driver teardown there is nothing left to free.
void gpuFreePhysical(void* ptr) {
  AvoidCudaGraphCaptureGuard cgcGuard;
  CUmemGenericAllocationHandle handle = 0;
  size_t size = 0;
  CUresult res = cuMemRetainAllocationHandle(&handle, ptr);
  if (isCuTeardownError(res)) return;  // driver gone; `handle` never became valid
  MSCCLPP_CUTHROW(res);
  MSCCLPP_CUTHROW_IGNORE_TEARDOWN(cuMemRelease(handle));
  res = cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr);
  if (isCuTeardownError(res)) return;  // `size` unknown; cannot safely unmap/free
  MSCCLPP_CUTHROW(res);
  MSCCLPP_CUTHROW_IGNORE_TEARDOWN(cuMemUnmap((CUdeviceptr)ptr, size));
  MSCCLPP_CUTHROW_IGNORE_TEARDOWN(cuMemRelease(handle));
  MSCCLPP_CUTHROW_IGNORE_TEARDOWN(cuMemAddressFree((CUdeviceptr)ptr, size));
}
#endif  // CUDA_NVLS_API_AVAILABLE
209 | 250 |
|
|
0 commit comments