diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index aee1ecd07..2648bf48c 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -60,12 +60,10 @@ UniqueAVBufferRef getHardwareDeviceContext(const torch::Device& device) { // Create hardware device context c10::cuda::CUDAGuard deviceGuard(device); - // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1: - // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb - // So we ensure the deviceIndex is not negative. // We set the device because we may be called from a different thread than // the one that initialized the cuda context. - cudaSetDevice(deviceIndex); + TORCH_CHECK( + cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device"); AVBufferRef* hardwareDeviceCtxRaw = nullptr; std::string deviceOrdinal = std::to_string(deviceIndex);