Skip to content

Commit 986f10c

Browse files
authored
BETA CUDA interface: Fix CUDA context initialization (#946)
1 parent e0fe5b6 commit 986f10c

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device)
203203
TORCH_CHECK(
204204
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
205205

206-
// Initialize CUDA context with a dummy tensor
207-
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
208-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
209-
206+
initializeCudaContextWithPytorch(device_);
210207
nppCtx_ = getNppStreamContext(device_);
211208
}
212209

src/torchcodec/_core/CUDACommon.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ PerGpuCache<NppStreamContext> g_cached_npp_ctxs(
2323

2424
} // namespace
2525

26+
void initializeCudaContextWithPytorch(const torch::Device& device) {
27+
// It is important for pytorch itself to create the cuda context. If ffmpeg
28+
// creates the context it may not be compatible with pytorch.
29+
// This is a dummy tensor to initialize the cuda context.
30+
torch::Tensor dummyTensorForCudaInitialization = torch::zeros(
31+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
32+
}
33+
2634
/* clang-format off */
2735
// Note: [YUV -> RGB Color Conversion, color space and color range]
2836
//

src/torchcodec/_core/CUDACommon.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ extern "C" {
2222

2323
namespace facebook::torchcodec {
2424

25+
void initializeCudaContextWithPytorch(const torch::Device& device);
26+
2527
// Unique pointer type for NPP stream context
2628
using UniqueNppContext = std::unique_ptr<NppStreamContext>;
2729

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,10 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
129129
TORCH_CHECK(
130130
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
131131

132-
// It is important for pytorch itself to create the cuda context. If ffmpeg
133-
// creates the context it may not be compatible with pytorch.
134-
// This is a dummy tensor to initialize the cuda context.
135-
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
136-
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_));
132+
initializeCudaContextWithPytorch(device_);
133+
134+
// TODO rename this, this is a hardware device context, not a CUDA context!
135+
// See https://github.com/meta-pytorch/torchcodec/issues/924
137136
ctx_ = getCudaContext(device_);
138137
nppCtx_ = getNppStreamContext(device_);
139138
}

0 commit comments

Comments
 (0)