Skip to content

Commit 13590bb

Browse files
committed
Use cuda filters to support 10-bit videos
For: #776 Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 4a842f9 commit 13590bb

File tree

5 files changed

+133
-30
lines changed

5 files changed

+133
-30
lines changed

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,121 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) {
199199
return;
200200
}
201201

202+
std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
203+
const VideoStreamOptions& videoStreamOptions,
204+
const UniqueAVFrame& avFrame,
205+
const AVRational& timeBase) {
206+
// We need FFmpeg filters to handle those conversion cases which are not
207+
// directly implemented in CUDA or CPU device interface (in case of a
208+
// fallback).
209+
enum AVPixelFormat frameFormat =
210+
static_cast<enum AVPixelFormat>(avFrame->format);
211+
212+
// Input frame is on CPU, we will just pass it to CPU device interface, so
213+
// skipping filters context as CPU device interface will handle everything for
214+
// us.
215+
if (avFrame->format != AV_PIX_FMT_CUDA) {
216+
return nullptr;
217+
}
218+
219+
TORCH_CHECK(
220+
avFrame->hw_frames_ctx != nullptr,
221+
"The AVFrame does not have a hw_frames_ctx. "
222+
"That's unexpected, please report this to the TorchCodec repo.");
223+
224+
auto hwFramesCtx =
225+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
226+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
227+
228+
// NV12 conversion is implemented directly with NPP, no need for filters.
229+
if (actualFormat == AV_PIX_FMT_NV12) {
230+
return nullptr;
231+
}
232+
233+
auto frameDims =
234+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
235+
int height = frameDims.height;
236+
int width = frameDims.width;
237+
238+
AVPixelFormat outputFormat;
239+
std::stringstream filters;
240+
241+
unsigned version_int = avfilter_version();
242+
if (version_int < AV_VERSION_INT(8, 0, 103)) {
243+
// Color conversion support ('format=' option) was added to scale_cuda from
244+
// n5.0. With the earlier version of ffmpeg we have no choice but use CPU
245+
// filters. See:
246+
// https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95
247+
outputFormat = AV_PIX_FMT_RGB24;
248+
249+
filters << "hwdownload,format=" << av_pix_fmt_desc_get(actualFormat)->name;
250+
filters << ",scale=" << width << ":" << height;
251+
filters << ":sws_flags=bilinear";
252+
} else {
253+
// Actual output color format will be set via filter options
254+
outputFormat = AV_PIX_FMT_CUDA;
255+
256+
filters << "scale_cuda=" << width << ":" << height;
257+
filters << ":format=nv12:interp_algo=bilinear";
258+
}
259+
260+
return std::make_unique<FiltersContext>(
261+
avFrame->width,
262+
avFrame->height,
263+
frameFormat,
264+
avFrame->sample_aspect_ratio,
265+
width,
266+
height,
267+
outputFormat,
268+
filters.str(),
269+
timeBase,
270+
av_buffer_ref(avFrame->hw_frames_ctx));
271+
}
272+
202273
void CudaDeviceInterface::convertAVFrameToFrameOutput(
203274
const VideoStreamOptions& videoStreamOptions,
204275
[[maybe_unused]] const AVRational& timeBase,
205-
UniqueAVFrame& avFrame,
276+
UniqueAVFrame& avInputFrame,
206277
FrameOutput& frameOutput,
207278
std::optional<torch::Tensor> preAllocatedOutputTensor) {
279+
std::unique_ptr<FiltersContext> newFiltersContext =
280+
initializeFiltersContext(videoStreamOptions, avInputFrame, timeBase);
281+
UniqueAVFrame avFilteredFrame;
282+
if (newFiltersContext) {
283+
// We need to compare the current filter context with our previous filter
284+
// context. If they are different, then we need to re-create a filter
285+
// graph. We create a filter graph late so that we don't have to depend
286+
// on the unreliable metadata in the header. And we sometimes re-create
287+
// it because it's possible for frame resolution to change mid-stream.
288+
// Finally, we want to reuse the filter graph as much as possible for
289+
// performance reasons.
290+
if (!filterGraph_ || filtersContext_ != newFiltersContext) {
291+
filterGraph_ =
292+
std::make_unique<FilterGraph>(*newFiltersContext, videoStreamOptions);
293+
filtersContext_ = std::move(newFiltersContext);
294+
}
295+
avFilteredFrame = filterGraph_->convert(avInputFrame);
296+
297+
// If this check fails it means the frame wasn't
298+
// reshaped to its expected dimensions by filtergraph.
299+
TORCH_CHECK(
300+
(avFilteredFrame->width == filtersContext_->outputWidth) &&
301+
(avFilteredFrame->height == filtersContext_->outputHeight),
302+
"Expected frame from filter graph of ",
303+
filtersContext_->outputWidth,
304+
"x",
305+
filtersContext_->outputHeight,
306+
", got ",
307+
avFilteredFrame->width,
308+
"x",
309+
avFilteredFrame->height);
310+
}
311+
312+
UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame;
313+
314+
// The filtered frame might be on CPU if CPU fallback has happened on filter
315+
// graph level. For example, that's how we handle color format conversion
316+
// on FFmpeg 4.4 where scale_cuda did not have this support implemented yet.
208317
if (avFrame->format != AV_PIX_FMT_CUDA) {
209318
// The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
210319
// the GPU. In this branch, the frame is on the CPU: this is what NVDEC
@@ -232,8 +341,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
232341
// Above we checked that the AVFrame was on GPU, but that's not enough, we
233342
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
234343
// because this is what the NPP color conversion routines expect.
235-
// TODO: we should investigate how we can perform color conversion for
236-
// non-8bit videos. This is supported on CPU.
237344
TORCH_CHECK(
238345
avFrame->hw_frames_ctx != nullptr,
239346
"The AVFrame does not have a hw_frames_ctx. "
@@ -242,16 +349,14 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
242349
auto hwFramesCtx =
243350
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
244351
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
352+
245353
TORCH_CHECK(
246354
actualFormat == AV_PIX_FMT_NV12,
247355
"The AVFrame is ",
248356
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
249357
: "unknown"),
250-
", but we expected AV_PIX_FMT_NV12. This typically happens when "
251-
"the video isn't 8bit, which is not supported on CUDA at the moment. "
252-
"Try using the CPU device instead. "
253-
"If the video is 10bit, we are tracking 10bit support in "
254-
"https://github.com/pytorch/torchcodec/issues/776");
358+
", but we expected AV_PIX_FMT_NV12. "
359+
"That's unexpected, please report this to the TorchCodec repo.");
255360

256361
auto frameDims =
257362
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);

src/torchcodec/_core/CudaDeviceInterface.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <npp.h>
1010
#include "src/torchcodec/_core/DeviceInterface.h"
11+
#include "src/torchcodec/_core/FilterGraph.h"
1112

1213
namespace facebook::torchcodec {
1314

@@ -30,8 +31,17 @@ class CudaDeviceInterface : public DeviceInterface {
3031
std::nullopt) override;
3132

3233
private:
34+
std::unique_ptr<FiltersContext> initializeFiltersContext(
35+
const VideoStreamOptions& videoStreamOptions,
36+
const UniqueAVFrame& avFrame,
37+
const AVRational& timeBase);
38+
3339
UniqueAVBufferRef ctx_;
3440
std::unique_ptr<NppStreamContext> nppCtx_;
41+
// Current filter context. Used to know whether a new FilterGraph
42+
// should be created to process the next frame.
43+
std::unique_ptr<FiltersContext> filtersContext_;
44+
std::unique_ptr<FilterGraph> filterGraph_;
3545
};
3646

3747
} // namespace facebook::torchcodec

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ FiltersContext::FiltersContext(
2222
int outputHeight,
2323
AVPixelFormat outputFormat,
2424
const std::string& filtergraphStr,
25-
AVRational timeBase)
25+
AVRational timeBase,
26+
AVBufferRef* hwFramesCtx)
2627
: inputWidth(inputWidth),
2728
inputHeight(inputHeight),
2829
inputFormat(inputFormat),
@@ -31,7 +32,8 @@ FiltersContext::FiltersContext(
3132
outputHeight(outputHeight),
3233
outputFormat(outputFormat),
3334
filtergraphStr(filtergraphStr),
34-
timeBase(timeBase) {}
35+
timeBase(timeBase),
36+
hwFramesCtx(hwFramesCtx) {}
3537

3638
bool operator==(const AVRational& lhs, const AVRational& rhs) {
3739
return lhs.num == rhs.num && lhs.den == rhs.den;

src/torchcodec/_core/FilterGraph.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ struct FiltersContext {
3535
int outputHeight,
3636
AVPixelFormat outputFormat,
3737
const std::string& filtergraphStr,
38-
AVRational timeBase);
38+
AVRational timeBase,
39+
AVBufferRef* hwFramesCtx = nullptr);
3940

4041
bool operator==(const FiltersContext&) const;
4142
bool operator!=(const FiltersContext&) const;

test/test_decoders.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,22 +1225,6 @@ def test_full_and_studio_range_bt709_video(self, asset):
12251225
elif cuda_version_used_for_building_torch() == (12, 8):
12261226
assert psnr(gpu_frame, cpu_frame) > 20
12271227

1228-
@needs_cuda
1229-
def test_10bit_videos_cuda(self):
1230-
# Assert that we raise proper error on different kinds of 10bit videos.
1231-
1232-
# TODO we should investigate how to support 10bit videos on GPU.
1233-
# See https://github.com/pytorch/torchcodec/issues/776
1234-
1235-
asset = H265_10BITS
1236-
1237-
decoder = VideoDecoder(asset.path, device="cuda")
1238-
with pytest.raises(
1239-
RuntimeError,
1240-
match="The AVFrame is p010le, but we expected AV_PIX_FMT_NV12.",
1241-
):
1242-
decoder.get_frame_at(0)
1243-
12441228
@needs_cuda
12451229
def test_10bit_gpu_fallsback_to_cpu(self):
12461230
# Test for 10-bit videos that aren't supported by NVDEC: we decode and
@@ -1272,12 +1256,13 @@ def test_10bit_gpu_fallsback_to_cpu(self):
12721256
frames_cpu = decoder_cpu.get_frames_at(frame_indices).data
12731257
assert_frames_equal(frames_gpu.cpu(), frames_cpu)
12741258

1259+
@pytest.mark.parametrize("device", all_supported_devices())
12751260
@pytest.mark.parametrize("asset", (H264_10BITS, H265_10BITS))
1276-
def test_10bit_videos_cpu(self, asset):
1277-
# This just validates that we can decode 10-bit videos on CPU.
1261+
def test_10bit_videos(self, device, asset):
1262+
# This just validates that we can decode 10-bit videos.
12781263
# TODO validate against the ref that the decoded frames are correct
12791264

1280-
decoder = VideoDecoder(asset.path)
1265+
decoder = VideoDecoder(asset.path, device=device)
12811266
decoder.get_frame_at(10)
12821267

12831268
def setup_frame_mappings(tmp_path, file, stream_index):

0 commit comments

Comments
 (0)