
Commit 0b1d162

Merge branch 'main' of github.com:pytorch/torchcodec into find-nvcuvid

2 parents: 0f0e612 + c2e202d

12 files changed (+153, -58 lines)

.github/workflows/linux_cuda_wheel.yaml

Lines changed: 3 additions & 2 deletions
@@ -67,9 +67,10 @@ jobs:
         # For the actual release we should add that label and change this to
         # include more python versions.
         python-version: ['3.10']
-        # We test against 12.6 and 12.9 to avoid having too big of a CI matrix,
+        # We test against 12.6 to avoid having too big of a CI matrix,
         # but for releases we should add 12.8.
-        cuda-version: ['12.6', '12.9']
+        # TODO add 13.0!
+        cuda-version: ['12.6']
         # TODO: put back ffmpeg 5 https://github.com/pytorch/torchcodec/issues/325
         ffmpeg-version-for-tests: ['4.4.2', '6', '7']

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 28 additions & 19 deletions
@@ -15,6 +15,18 @@ static bool g_cpu = registerDeviceInterface(

 } // namespace

+CpuDeviceInterface::SwsFrameContext::SwsFrameContext(
+    int inputWidth,
+    int inputHeight,
+    AVPixelFormat inputFormat,
+    int outputWidth,
+    int outputHeight)
+    : inputWidth(inputWidth),
+      inputHeight(inputHeight),
+      inputFormat(inputFormat),
+      outputWidth(outputWidth),
+      outputHeight(outputHeight) {}
+
 bool CpuDeviceInterface::SwsFrameContext::operator==(
     const CpuDeviceInterface::SwsFrameContext& other) const {
   return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&

@@ -97,13 +109,12 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   // And we sometimes re-create them because it's possible for frame
   // resolution to change mid-stream. Finally, we want to reuse the colorspace
   // conversion objects as much as possible for performance reasons.
-  SwsFrameContext swsFrameContext;
-
-  swsFrameContext.inputWidth = avFrame->width;
-  swsFrameContext.inputHeight = avFrame->height;
-  swsFrameContext.inputFormat = frameFormat;
-  swsFrameContext.outputWidth = expectedOutputWidth;
-  swsFrameContext.outputHeight = expectedOutputHeight;
+  SwsFrameContext swsFrameContext(
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight);

   outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
       expectedOutputHeight, expectedOutputWidth, torch::kCPU));

@@ -128,22 +139,20 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
     // See comment above in swscale branch about the filterGraphContext_
     // creation.
-    FiltersContext filtersContext;
-
-    filtersContext.inputWidth = avFrame->width;
-    filtersContext.inputHeight = avFrame->height;
-    filtersContext.inputFormat = frameFormat;
-    filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
-    filtersContext.outputWidth = expectedOutputWidth;
-    filtersContext.outputHeight = expectedOutputHeight;
-    filtersContext.outputFormat = AV_PIX_FMT_RGB24;
-    filtersContext.timeBase = timeBase;
-
     std::stringstream filters;
     filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
     filters << ":sws_flags=bilinear";

-    filtersContext.filtergraphStr = filters.str();
+    FiltersContext filtersContext(
+        avFrame->width,
+        avFrame->height,
+        frameFormat,
+        avFrame->sample_aspect_ratio,
+        expectedOutputWidth,
+        expectedOutputHeight,
+        AV_PIX_FMT_RGB24,
+        filters.str(),
+        timeBase);

     if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
       filterGraphContext_ =
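Note (illustration, not part of the diff): the SwsFrameContext / FiltersContext structs exist so that the expensive swscale and filtergraph objects are rebuilt only when the frame geometry changes mid-stream, and reused otherwise. A minimal, self-contained sketch of that caching pattern with hypothetical names (ConversionContext, Converter and FrameConverterCache are not TorchCodec types):

    #include <memory>

    // The parameters a conversion object depends on, with value equality so we
    // can tell when the cached converter is stale.
    struct ConversionContext {
      int inputWidth = 0;
      int inputHeight = 0;
      int outputWidth = 0;
      int outputHeight = 0;

      bool operator==(const ConversionContext& other) const {
        return inputWidth == other.inputWidth &&
            inputHeight == other.inputHeight &&
            outputWidth == other.outputWidth &&
            outputHeight == other.outputHeight;
      }
      bool operator!=(const ConversionContext& other) const {
        return !(*this == other);
      }
    };

    struct Converter {
      explicit Converter(const ConversionContext& /*ctx*/) {
        // Expensive setup (e.g. sws_getContext or filtergraph creation) goes here.
      }
    };

    class FrameConverterCache {
     public:
      Converter& get(const ConversionContext& ctx) {
        // Re-create the converter only if the geometry changed mid-stream.
        if (!converter_ || prevContext_ != ctx) {
          converter_ = std::make_unique<Converter>(ctx);
          prevContext_ = ctx;
        }
        return *converter_;
      }

     private:
      std::unique_ptr<Converter> converter_;
      ConversionContext prevContext_;
    };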

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 13 additions & 5 deletions
@@ -43,11 +43,19 @@ class CpuDeviceInterface : public DeviceInterface {
       const UniqueAVFrame& avFrame);

   struct SwsFrameContext {
-    int inputWidth;
-    int inputHeight;
-    AVPixelFormat inputFormat;
-    int outputWidth;
-    int outputHeight;
+    int inputWidth = 0;
+    int inputHeight = 0;
+    AVPixelFormat inputFormat = AV_PIX_FMT_NONE;
+    int outputWidth = 0;
+    int outputHeight = 0;
+
+    SwsFrameContext() = default;
+    SwsFrameContext(
+        int inputWidth,
+        int inputHeight,
+        AVPixelFormat inputFormat,
+        int outputWidth,
+        int outputHeight);

     bool operator==(const SwsFrameContext&) const;
     bool operator!=(const SwsFrameContext&) const;
   };

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 26 additions & 1 deletion
@@ -275,7 +275,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
   }

   torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_);
-  nppCtx_->hStream = at::cuda::getCurrentCUDAStream(deviceIndex).stream();
+
+  // Create a CUDA event and attach it to the AVFrame's CUDA stream. That's the
+  // NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
+  // We will be waiting for this event to complete before calling the NPP
+  // functions, to ensure NVDEC has finished decoding the frame before running
+  // the NPP color-conversion.
+  // Note that our code is generic and assumes that the NVDEC's stream can be
+  // arbitrary, but unfortunately we know it's hardcoded to be the default
+  // stream by FFmpeg:
+  // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
+  TORCH_CHECK(
+      hwFramesCtx->device_ctx != nullptr,
+      "The AVFrame's hw_frames_ctx does not have a device_ctx. ");
+  auto cudaDeviceCtx =
+      static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
+  at::cuda::CUDAEvent nvdecDoneEvent;
+  at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
+      c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
+  nvdecDoneEvent.record(nvdecStream);
+
+  // Don't start NPP work before NVDEC is done decoding the frame!
+  at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
+  nvdecDoneEvent.block(nppStream);
+
+  // Create the NPP context if we haven't yet.
+  nppCtx_->hStream = nppStream.stream();
   cudaError_t err =
       cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
   TORCH_CHECK(
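Note (illustration, not part of the diff): the record/block calls above are the standard way to order work across two CUDA streams without blocking the CPU. A minimal sketch of the same pattern using the plain CUDA runtime API instead of ATen's CUDAEvent/CUDAStream wrappers; waitForProducer and the stream arguments are placeholders for whatever the decode and color-conversion streams are in a real application:

    #include <cuda_runtime.h>

    void waitForProducer(cudaStream_t producerStream, cudaStream_t consumerStream) {
      cudaEvent_t producerDone;
      // Timing is not needed here; a disabled-timing event is cheaper to record.
      cudaEventCreateWithFlags(&producerDone, cudaEventDisableTiming);

      // Mark the point in the producer stream after which its queued work is done.
      cudaEventRecord(producerDone, producerStream);

      // Make the consumer stream wait on the GPU, without blocking the CPU,
      // until the producer stream has reached that point.
      cudaStreamWaitEvent(consumerStream, producerDone, 0);

      // Destruction is safe immediately; CUDA defers it until the event is consumed.
      cudaEventDestroy(producerDone);
    }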

src/torchcodec/_core/DeviceInterface.h

Lines changed: 0 additions & 8 deletions
@@ -17,14 +17,6 @@

 namespace facebook::torchcodec {

-// Note that all these device functions should only be called if the device is
-// not a CPU device. CPU device functions are already implemented in the
-// SingleStreamDecoder implementation.
-// These functions should only be called from within an if block like this:
-// if (device.type() != torch::kCPU) {
-//   deviceFunction(device, ...);
-// }
-
 class DeviceInterface {
  public:
  DeviceInterface(const torch::Device& device) : device_(device) {}

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 9 additions & 1 deletion
@@ -61,7 +61,15 @@ int getNumChannels(const UniqueAVFrame& avFrame) {
     (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
   return avFrame->ch_layout.nb_channels;
 #else
-  return av_get_channel_layout_nb_channels(avFrame->channel_layout);
+  int numChannels = av_get_channel_layout_nb_channels(avFrame->channel_layout);
+  // Handle FFmpeg 4 bug where channel_layout and numChannels are 0 or unset.
+  // Set values based on avFrame->channels, which appears to be correct,
+  // to allow successful initialization of SwrContext.
+  if (numChannels == 0 && avFrame->channels > 0) {
+    avFrame->channel_layout = av_get_default_channel_layout(avFrame->channels);
+    numChannels = avFrame->channels;
+  }
+  return numChannels;
 #endif
 }
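Note (illustration, not part of the diff): the fallback above uses two avutil calls that exist in FFmpeg 4 but are superseded by the AVChannelLayout API in newer releases. A self-contained sketch of the same fallback as a standalone helper, assuming an FFmpeg 4 build; getNumChannelsWithFallback is a hypothetical name:

    extern "C" {
    #include <libavutil/channel_layout.h>
    #include <libavutil/frame.h>
    }

    // If the decoder left channel_layout unset (a known FFmpeg 4 quirk), derive a
    // default layout from the channel count so that downstream resampler setup
    // still has something valid to work with.
    int getNumChannelsWithFallback(AVFrame* avFrame) {
      int numChannels = av_get_channel_layout_nb_channels(avFrame->channel_layout);
      if (numChannels == 0 && avFrame->channels > 0) {
        avFrame->channel_layout = av_get_default_channel_layout(avFrame->channels);
        numChannels = avFrame->channels;
      }
      return numChannels;
    }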

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 20 additions & 0 deletions
@@ -13,6 +13,26 @@ extern "C" {

 namespace facebook::torchcodec {

+FiltersContext::FiltersContext(
+    int inputWidth,
+    int inputHeight,
+    AVPixelFormat inputFormat,
+    AVRational inputAspectRatio,
+    int outputWidth,
+    int outputHeight,
+    AVPixelFormat outputFormat,
+    const std::string& filtergraphStr,
+    AVRational timeBase)
+    : inputWidth(inputWidth),
+      inputHeight(inputHeight),
+      inputFormat(inputFormat),
+      inputAspectRatio(inputAspectRatio),
+      outputWidth(outputWidth),
+      outputHeight(outputHeight),
+      outputFormat(outputFormat),
+      filtergraphStr(filtergraphStr),
+      timeBase(timeBase) {}
+
 bool operator==(const AVRational& lhs, const AVRational& rhs) {
   return lhs.num == rhs.num && lhs.den == rhs.den;
 }

src/torchcodec/_core/FilterGraph.h

Lines changed: 14 additions & 1 deletion
@@ -19,11 +19,24 @@ struct FiltersContext {
   int outputWidth = 0;
   int outputHeight = 0;
   AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
-
   std::string filtergraphStr;
   AVRational timeBase = {0, 0};
   UniqueAVBufferRef hwFramesCtx;

+  FiltersContext() = default;
+  FiltersContext(FiltersContext&&) = default;
+  FiltersContext& operator=(FiltersContext&&) = default;
+  FiltersContext(
+      int inputWidth,
+      int inputHeight,
+      AVPixelFormat inputFormat,
+      AVRational inputAspectRatio,
+      int outputWidth,
+      int outputHeight,
+      AVPixelFormat outputFormat,
+      const std::string& filtergraphStr,
+      AVRational timeBase);
+
   bool operator==(const FiltersContext&) const;
   bool operator!=(const FiltersContext&) const;
 };
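Note (an observation, not from the commit): FiltersContext owns a UniqueAVBufferRef, a unique_ptr-style handle, so the struct is movable but not copyable; defaulting the move operations keeps cheap ownership transfer available now that a user-declared constructor exists. A tiny generic sketch of that move-only pattern with placeholder names:

    #include <memory>
    #include <string>
    #include <utility>

    // A context that owns a unique resource: copying is deleted because of the
    // unique_ptr member, while moving transfers ownership cheaply.
    struct OwningContext {
      std::string description;
      std::unique_ptr<int> resource;

      OwningContext() = default;
      OwningContext(OwningContext&&) = default;
      OwningContext& operator=(OwningContext&&) = default;
    };

    int main() {
      OwningContext a;
      a.description = "example";
      a.resource = std::make_unique<int>(42);

      OwningContext b = std::move(a); // OK: ownership moves from a to b.
      // OwningContext c = b;         // Would not compile: copy is deleted.
      return 0;
    }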

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 5 additions & 8 deletions
@@ -243,13 +243,6 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
     return;
   }

-  for (unsigned int i = 0; i < formatContext_->nb_streams; ++i) {
-    // We want to scan and update the metadata of all streams.
-    TORCH_CHECK(
-        formatContext_->streams[i]->discard != AVDISCARD_ALL,
-        "Did you add a stream before you called for a scan?");
-  }
-
   AutoAVPacket autoAVPacket;
   while (true) {
     ReferenceAVPacket packet(autoAVPacket);

@@ -1253,7 +1246,11 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput(
       formatContext_->streams[activeStreamIndex_]->time_base);
   if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
     convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
-  } else if (deviceInterface_) {
+  } else {
+    TORCH_CHECK(
+        deviceInterface_ != nullptr,
+        "No device interface available for video decoding. This ",
+        "shouldn't happen, please report.");
     deviceInterface_->convertAVFrameToFrameOutput(
         streamInfo.videoStreamOptions,
         streamInfo.timeBase,
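Note (illustration, not part of the diff): replacing the "else if (deviceInterface_)" guard with an unconditional "else" plus TORCH_CHECK turns a silently skipped branch into a loud internal-invariant failure. A minimal sketch of that style, assuming a translation unit built against libtorch; convertFrame is a hypothetical function:

    #include <c10/util/Exception.h>

    // TORCH_CHECK throws a c10::Error with the concatenated message when the
    // condition is false, so impossible states fail loudly instead of being
    // silently skipped.
    void convertFrame(void* deviceInterface) {
      TORCH_CHECK(
          deviceInterface != nullptr,
          "No device interface available for video decoding. ",
          "This shouldn't happen, please report.");
      // ... safe to use deviceInterface from here on ...
    }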

src/torchcodec/decoders/_video_decoder.py

Lines changed: 12 additions & 0 deletions
@@ -247,6 +247,12 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
         Returns:
             FrameBatch: The frames at the given indices.
         """
+        if isinstance(indices, torch.Tensor):
+            # TODO we should avoid converting tensors to lists and just let the
+            # core ops and C++ code natively accept tensors. See
+            # https://github.com/pytorch/torchcodec/issues/879
+            indices = indices.to(torch.int).tolist()
+
         data, pts_seconds, duration_seconds = core.get_frames_at_indices(
             self._decoder, frame_indices=indices
         )

@@ -322,6 +328,12 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
         Returns:
             FrameBatch: The frames that are played at ``seconds``.
         """
+        if isinstance(seconds, torch.Tensor):
+            # TODO we should avoid converting tensors to lists and just let the
+            # core ops and C++ code natively accept tensors. See
+            # https://github.com/pytorch/torchcodec/issues/879
+            seconds = seconds.to(torch.float).tolist()
+
         data, pts_seconds, duration_seconds = core.get_frames_by_pts(
             self._decoder, timestamps=seconds
         )
