Skip to content

Commit 4bdc851

Browse files
committed
Merge branch 'main' of https://github.com/pytorch/torchcodec into cuda8
2 parents 6be7b76 + c6a0a5a commit 4bdc851

File tree

3 files changed

+61
-12
lines changed

3 files changed

+61
-12
lines changed

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,40 @@ AVBufferRef* getFromCache(const torch::Device& device) {
7777
return nullptr;
7878
}
7979

80-
AVBufferRef* getCudaContext(const torch::Device& device) {
81-
enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
82-
TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
83-
torch::DeviceIndex deviceIndex = getFFMPEGCompatibleDeviceIndex(device);
84-
85-
AVBufferRef* hw_device_ctx = getFromCache(device);
86-
if (hw_device_ctx != nullptr) {
87-
return hw_device_ctx;
80+
// Creates an FFmpeg CUDA hardware device context by calling
// av_hwdevice_ctx_create with the AV_CUDA_USE_CURRENT_CONTEXT flag, i.e. asking
// FFmpeg to reuse the CUDA context that is current on the calling thread
// instead of creating a new one.
//
// device: the torch device, used to scope a CUDAGuard for the duration of the
//     call.
// nonNegativeDeviceIndex: the CUDA device ordinal; it must be >= 0 because it
//     is passed to cudaSetDevice and, as a string, to FFmpeg.
// type: the FFmpeg hardware device type to create.
//
// Returns the newly created AVBufferRef. Raises (via TORCH_CHECK) if the CUDA
// device cannot be selected or if FFmpeg fails to create the context.
AVBufferRef* getFFMPEGContextFromExistingCudaContext(
    const torch::Device& device,
    torch::DeviceIndex nonNegativeDeviceIndex,
    enum AVHWDeviceType type) {
  c10::cuda::CUDAGuard deviceGuard(device);
  std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);

  // Valid values for the argument to cudaSetDevice are 0 to maxDevices - 1:
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb
  // So we ensure the deviceIndex is not negative.
  // We set the device because we may be called from a different thread than
  // the one that initialized the cuda context.
  cudaError_t cudaErr = cudaSetDevice(nonNegativeDeviceIndex);
  // Do not continue on failure: FFmpeg would otherwise attach to whichever
  // context happens to be current, which may belong to a different device.
  TORCH_CHECK(
      cudaErr == cudaSuccess,
      "Failed to set CUDA device ",
      deviceOrdinal,
      ": ",
      cudaGetErrorString(cudaErr));

  AVBufferRef* hw_device_ctx = nullptr;
  int err = av_hwdevice_ctx_create(
      &hw_device_ctx,
      type,
      deviceOrdinal.c_str(),
      nullptr,
      AV_CUDA_USE_CURRENT_CONTEXT);
  TORCH_CHECK(
      err >= 0,
      "Failed to create specified HW device: ",
      getFFMPEGErrorStringFromErrorCode(err));
  return hw_device_ctx;
}
89107

90-
std::string deviceOrdinal = std::to_string(deviceIndex);
108+
AVBufferRef* getFFMPEGContextFromNewCudaContext(
109+
const torch::Device& device,
110+
torch::DeviceIndex nonNegativeDeviceIndex,
111+
enum AVHWDeviceType type) {
112+
AVBufferRef* hw_device_ctx = nullptr;
113+
std::string deviceOrdinal = std::to_string(nonNegativeDeviceIndex);
91114
int err = av_hwdevice_ctx_create(
92115
&hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
93116
if (err < 0) {
@@ -99,6 +122,32 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
99122
return hw_device_ctx;
100123
}
101124

125+
// Returns an FFmpeg hardware device context for `device`.
//
// The "cuda" hardware device type is looked up first and the call fails (via
// TORCH_CHECK) if FFmpeg does not know it. If getFromCache yields a non-null
// context for this device, that context is returned as-is; otherwise a context
// is created through one of the two helpers below, selected at compile time
// from the libavutil version.
AVBufferRef* getCudaContext(const torch::Device& device) {
  const enum AVHWDeviceType cudaType = av_hwdevice_find_type_by_name("cuda");
  TORCH_CHECK(cudaType != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
  const torch::DeviceIndex ffmpegDeviceIndex =
      getFFMPEGCompatibleDeviceIndex(device);

  // Prefer a previously cached context when one is available.
  if (AVBufferRef* cachedContext = getFromCache(device)) {
    return cachedContext;
  }

  // libavutil 58.26.100 added the ability to reuse the existing cuda context,
  // which is much faster and uses less memory than creating a new cuda
  // context, so that path is taken whenever the library is recent enough.
  // FFMPEG 6.1.2 appears to be the earliest release that ships libavutil
  // 58.26.100:
  // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
  return getFFMPEGContextFromExistingCudaContext(
      device, ffmpegDeviceIndex, cudaType);
#else
  return getFFMPEGContextFromNewCudaContext(
      device, ffmpegDeviceIndex, cudaType);
#endif
}
150+
102151
torch::Tensor allocateDeviceTensor(
103152
at::IntArrayRef shape,
104153
torch::Device device,

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1168,7 +1168,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
11681168
getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* frame) {
11691169
StreamInfo& activeStream = streams_[frameStreamIndex];
11701170
return frame->pts >=
1171-
activeStream.discardFramesBeforePts.value_or(INT64_MIN);
1171+
activeStream.discardFramesBeforePts;
11721172
});
11731173
return rawOutput;
11741174
}

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ class VideoDecoder {
307307
int64_t currentDuration = 0;
308308
// The desired position of the cursor in the stream. We send frames >=
309309
// this pts to the user when they request a frame.
310-
// We set this field if the user requested a seek.
311-
std::optional<int64_t> discardFramesBeforePts = 0;
310+
// We update this field if the user requested a seek.
311+
int64_t discardFramesBeforePts = INT64_MIN;
312312
VideoStreamDecoderOptions options;
313313
// The filter state associated with this stream (for video streams). The
314314
// actual graph will be nullptr for inactive streams.

0 commit comments

Comments
 (0)