Commit 1541ab8

Add resize transform tests; make transforms happen before color conversion
1 parent 3a2df84 commit 1541ab8

15 files changed: +153, -135 lines

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 13 additions & 4 deletions
@@ -94,6 +94,12 @@ void CpuDeviceInterface::initializeVideo(
   // If we have any transforms, replace filters_ with the filter strings from
   // the transforms. As noted above, we decide between swscale and filtergraph
   // when we actually decode a frame.
+  //
+  // Note: We explicitly add the format conversion filter at the end to ensure
+  // that color conversion happens AFTER the transforms, not before. This
+  // matches the behavior of the reference generation in the test suite.
+  // Without this, FFmpeg's automatic format negotiation might insert the
+  // conversion before the transforms, which would produce different results.
   std::stringstream filters;
   bool first = true;
   for (const auto& transform : transforms) {
@@ -104,7 +110,10 @@ void CpuDeviceInterface::initializeVideo(
     first = false;
   }
   if (!transforms.empty()) {
-    filters_ = filters.str();
+    // Note that we ensure that the transforms come BEFORE the format
+    // conversion. This means that the transforms are applied in the frame's
+    // original pixel format and colorspace.
+    filters_ = filters.str() + filters_;
   }

   initialized_ = true;
@@ -324,17 +333,17 @@ void CpuDeviceInterface::createSwsContext(
 torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame,
     const FrameDims& outputDims) {
-  enum AVPixelFormat frameFormat =
+  enum AVPixelFormat avFrameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);

   FiltersContext filtersContext(
       avFrame->width,
       avFrame->height,
-      frameFormat,
+      avFrameFormat,
       avFrame->sample_aspect_ratio,
       outputDims.width,
       outputDims.height,
-      AV_PIX_FMT_RGB24,
+      /*outputFormat=*/AV_PIX_FMT_RGB24,
       filters_,
       timeBase_);

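To make the new ordering concrete, here is a minimal Python sketch of the string assembly the hunk above performs: user transforms form a comma-separated chain, and the preconfigured format=rgb24 conversion stays at the tail. The separator handling is an assumption, since the loop body that uses the first flag is elided from this hunk; the crop string is borrowed from the test resources further down.

    # Sketch of the assembly in CpuDeviceInterface::initializeVideo: filters_
    # starts as the format conversion, and user transforms are prepended so
    # they run in the frame's original pixel format and colorspace.
    def build_filtergraph_string(transform_filters: list[str]) -> str:
        filters_ = "format=rgb24"  # the default set in CpuDeviceInterface.h
        if transform_filters:
            # Assumed separator handling; the loop with its `first` flag is elided above.
            filters_ = ",".join(transform_filters) + "," + filters_
        return filters_

    assert build_filtergraph_string([]) == "format=rgb24"
    assert (
        build_filtergraph_string(["crop=300:200:50:35:exact=1"])
        == "crop=300:200:50:35:exact=1,format=rgb24"
    )
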
src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 25 additions & 8 deletions
@@ -109,15 +109,32 @@ class CpuDeviceInterface : public DeviceInterface {
   UniqueSwsContext swsContext_;
   SwsFrameContext prevSwsFrameContext_;

-  // The filter we supply to filterGraph_, if it is used. The default is the
-  // copy filter, which just copies the input to the output. Computationally, it
-  // should be a no-op. If we get no user-provided transforms, we will use the
-  // copy filter. Otherwise, we will construct the string from the transforms.
+  // We pass the filters to FFmpeg's filtergraph API. It is a simple pipeline
+  // of what FFmpeg calls "filters" to apply to decoded frames before returning
+  // them. In the PyTorch ecosystem, we call these "transforms". During
+  // initialization, we convert the user-supplied transforms into this string of
+  // filters.
   //
-  // Note that even if we only use the copy filter, we still get the desired
-  // colorspace conversion. We construct the filtergraph with its output sink
-  // set to RGB24.
-  std::string filters_ = "copy";
+  // Note that we start with the format conversion, and then we ensure that the
+  // user-supplied filters always happen BEFORE the format conversion. We want
+  // the user-supplied filters to operate on frames in their original pixel
+  // format and colorspace.
+  //
+  // The reason why is not obvious: when users do not need to perform any
+  // transforms, or the only transform they apply is a single resize, we can
+  // sometimes just call swscale directly; see getColorConversionLibrary() for
+  // the full conditions. A single call to swscale's sws_scale() will always do
+  // the scaling (resize) in the frame's original pixel format and colorspace.
+  // In order for calling swscale directly to be an optimization, we must make
+  // sure that the behavior between calling it directly and using filtergraph
+  // is identical.
+  //
+  // If we had to apply transforms in the output pixel format and colorspace,
+  // we could achieve that by calling sws_scale() twice: once to do the resize
+  // and another time to do the format conversion. But that goes against the
+  // whole point of calling sws_scale() directly, as it's a performance
+  // optimization.
+  std::string filters_ = "format=rgb24";

   // The flags we supply to swsContext_, if it used. The flags control the
   // resizing algorithm. We default to bilinear. Users can override this with a
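The equivalence argument above rests on the fact that resizing before and after RGB conversion produces different pixels. That claim is easy to check from the command line; a hedged sketch, assuming an ffmpeg binary on PATH and some local input.mp4 (neither is part of this commit):

    # Decode one frame twice: resized BEFORE vs AFTER the rgb24 conversion.
    # scale-then-format resizes in the source pixel format (typically yuv420p);
    # format-then-scale resizes in rgb24. The two outputs generally differ.
    import subprocess

    def first_frame(vf: str, out: str) -> None:
        subprocess.run(
            ["ffmpeg", "-y", "-i", "input.mp4", "-vf", vf, "-frames:v", "1", out],
            check=True,
        )

    first_frame("scale=480:270:flags=bilinear,format=rgb24", "resize_first.bmp")
    first_frame("format=rgb24,scale=480:270:flags=bilinear", "convert_first.bmp")

    with open("resize_first.bmp", "rb") as a, open("convert_first.bmp", "rb") as b:
        print("identical" if a.read() == b.read() else "pixel values differ")
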

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 21 additions & 24 deletions
@@ -399,68 +399,65 @@ SwrContext* createSwrContext(
   return swrContext;
 }

-AVFilterContext* createBuffersinkFilter(
+AVFilterContext* createAVFilterContextWithOptions(
     AVFilterGraph* filterGraph,
-    enum AVPixelFormat outputFormat) {
-  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-  TORCH_CHECK(buffersink != nullptr, "Failed to get buffersink filter.");
-
-  AVFilterContext* sinkContext = nullptr;
-  int status;
+    const AVFilter* buffer,
+    const enum AVPixelFormat outputFormat) {
+  AVFilterContext* avFilterContext = nullptr;
   const char* filterName = "out";

-  enum AVPixelFormat pix_fmts[] = {outputFormat, AV_PIX_FMT_NONE};
+  enum AVPixelFormat pixFmts[] = {outputFormat, AV_PIX_FMT_NONE};

   // av_opt_set_int_list was replaced by av_opt_set_array() in FFmpeg 8.
 #if LIBAVUTIL_VERSION_MAJOR >= 60 // FFmpeg >= 8
   // Output options like pixel_formats must be set before filter init
-  sinkContext =
-      avfilter_graph_alloc_filter(filterGraph, buffersink, filterName);
+  avFilterContext =
+      avfilter_graph_alloc_filter(filterGraph, buffer, filterName);
   TORCH_CHECK(
-      sinkContext != nullptr, "Failed to allocate buffersink filter context.");
+      avFilterContext != nullptr, "Failed to allocate buffer filter context.");

   // When setting pix_fmts, only the first element is used, so nb_elems = 1
   // AV_PIX_FMT_NONE acts as a terminator for the array in av_opt_set_int_list
-  status = av_opt_set_array(
-      sinkContext,
+  int status = av_opt_set_array(
+      avFilterContext,
       "pixel_formats",
       AV_OPT_SEARCH_CHILDREN,
       0, // start_elem
       1, // nb_elems
       AV_OPT_TYPE_PIXEL_FMT,
-      pix_fmts);
+      pixFmts);
   TORCH_CHECK(
       status >= 0,
-      "Failed to set pixel format for buffersink filter: ",
+      "Failed to set pixel format for buffer filter: ",
       getFFMPEGErrorStringFromErrorCode(status));

-  status = avfilter_init_str(sinkContext, nullptr);
+  status = avfilter_init_str(avFilterContext, nullptr);
   TORCH_CHECK(
       status >= 0,
-      "Failed to initialize buffersink filter: ",
+      "Failed to initialize buffer filter: ",
       getFFMPEGErrorStringFromErrorCode(status));
 #else // FFmpeg <= 7
   // For older FFmpeg versions, create filter and then set options
-  status = avfilter_graph_create_filter(
-      &sinkContext, buffersink, filterName, nullptr, nullptr, filterGraph);
+  int status = avfilter_graph_create_filter(
+      &avFilterContext, buffer, filterName, nullptr, nullptr, filterGraph);
   TORCH_CHECK(
       status >= 0,
-      "Failed to create buffersink filter: ",
+      "Failed to create buffer filter: ",
       getFFMPEGErrorStringFromErrorCode(status));

   status = av_opt_set_int_list(
-      sinkContext,
+      avFilterContext,
       "pix_fmts",
-      pix_fmts,
+      pixFmts,
       AV_PIX_FMT_NONE,
       AV_OPT_SEARCH_CHILDREN);
   TORCH_CHECK(
       status >= 0,
-      "Failed to set pixel formats for buffersink filter: ",
+      "Failed to set pixel formats for buffer filter: ",
       getFFMPEGErrorStringFromErrorCode(status));
 #endif

-  return sinkContext;
+  return avFilterContext;
 }

 UniqueAVFrame convertAudioAVFrameSamples(
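Which side of the LIBAVUTIL_VERSION_MAJOR >= 60 gate above applies depends on the FFmpeg build being linked. A small diagnostic sketch: avutil_version() and the top-16-bits major extraction are libavutil's public API, while the library name lookup is platform-dependent, so treat this as illustration only:

    # Probe which branch of the version gate a given installation would take.
    import ctypes
    import ctypes.util

    libname = ctypes.util.find_library("avutil")
    assert libname is not None, "libavutil not found"
    avutil = ctypes.CDLL(libname)
    avutil.avutil_version.restype = ctypes.c_uint

    # AV_VERSION_INT packs the major version into the top 16 bits.
    major = avutil.avutil_version() >> 16
    print(f"libavutil major version: {major}")
    print("av_opt_set_array" if major >= 60 else "av_opt_set_int_list")
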

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 3 additions & 2 deletions
@@ -246,8 +246,9 @@ int64_t computeSafeDuration(
     const AVRational& frameRate,
     const AVRational& timeBase);

-AVFilterContext* createBuffersinkFilter(
+AVFilterContext* createAVFilterContextWithOptions(
     AVFilterGraph* filterGraph,
-    enum AVPixelFormat outputFormat);
+    const AVFilter* buffer,
+    const enum AVPixelFormat outputFormat);

 } // namespace facebook::torchcodec

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 16 additions & 7 deletions
@@ -63,8 +63,8 @@ FilterGraph::FilterGraph(
     filterGraph_->nb_threads = videoStreamOptions.ffmpegThreadCount.value();
   }

-  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
-
+  // Configure the source context.
+  const AVFilter* bufferSrc = avfilter_get_by_name("buffer");
   UniqueAVBufferSrcParameters srcParams(av_buffersrc_parameters_alloc());
   TORCH_CHECK(srcParams, "Failed to allocate buffersrc params");

@@ -78,7 +78,7 @@ FilterGraph::FilterGraph(
   }

   sourceContext_ =
-      avfilter_graph_alloc_filter(filterGraph_.get(), buffersrc, "in");
+      avfilter_graph_alloc_filter(filterGraph_.get(), bufferSrc, "in");
   TORCH_CHECK(sourceContext_, "Failed to allocate filter graph");

   int status = av_buffersrc_parameters_set(sourceContext_, srcParams.get());
@@ -93,23 +93,31 @@ FilterGraph::FilterGraph(
       "Failed to create filter graph : ",
       getFFMPEGErrorStringFromErrorCode(status));

-  sinkContext_ =
-      createBuffersinkFilter(filterGraph_.get(), filtersContext.outputFormat);
+  // Configure the sink context.
+  const AVFilter* bufferSink = avfilter_get_by_name("buffersink");
+  TORCH_CHECK(bufferSink != nullptr, "Failed to get buffersink filter.");
+
+  sinkContext_ = createAVFilterContextWithOptions(
+      filterGraph_.get(), bufferSink, filtersContext.outputFormat);
   TORCH_CHECK(
       sinkContext_ != nullptr, "Failed to create and configure buffersink");

+  // Create the filtergraph nodes based on the source and sink contexts.
   UniqueAVFilterInOut outputs(avfilter_inout_alloc());
-  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
-
   outputs->name = av_strdup("in");
   outputs->filter_ctx = sourceContext_;
   outputs->pad_idx = 0;
   outputs->next = nullptr;
+
+  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
   inputs->name = av_strdup("out");
   inputs->filter_ctx = sinkContext_;
   inputs->pad_idx = 0;
   inputs->next = nullptr;

+  // Create the filtergraph specified by the filtergraph string in the context
+  // of the inputs and outputs. Note the dance we have to do with release and
+  // resetting the output and input nodes because FFmpeg modifies them in place.
   AVFilterInOut* outputsTmp = outputs.release();
   AVFilterInOut* inputsTmp = inputs.release();
   status = avfilter_graph_parse_ptr(
@@ -126,6 +134,7 @@
       getFFMPEGErrorStringFromErrorCode(status),
       ", provided filters: " + filtersContext.filtergraphStr);

+  // Check filtergraph validity and configure links and formats.
   status = avfilter_graph_config(filterGraph_.get(), nullptr);
   TORCH_CHECK(
       status >= 0,
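For reference, the endpoint labels the code strdup's ("in" for the buffer source, "out" for the buffersink) are the same labels the ffmpeg CLI uses around a -vf expression, and they can be written out explicitly to make the correspondence visible. A sketch assuming ffmpeg on PATH; input.mp4 and out.bmp are placeholders, not repo files:

    # CLI analogue of the graph built above: buffer source -> parsed filter
    # chain -> buffersink, with the chain running before the rgb24 conversion.
    import subprocess

    filtergraph = "[in]crop=300:200:50:35:exact=1,format=rgb24[out]"
    subprocess.run(
        ["ffmpeg", "-y", "-i", "input.mp4", "-vf", filtergraph, "-frames:v", "1", "out.bmp"],
        check=True,
    )
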
src/torchcodec/decoders/_video_decoder.py

Lines changed: 1 addition & 21 deletions
@@ -8,7 +8,7 @@
 import json
 import numbers
 from pathlib import Path
-from typing import Any, List, Literal, Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union

 import torch
 from torch import device as torch_device, Tensor
@@ -103,7 +103,6 @@ def __init__(
         dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
         num_ffmpeg_threads: int = 1,
         device: Optional[Union[str, torch_device]] = "cpu",
-        transforms: List[Any] = [],  # TRANSFORMS TODO: what is the user-facing type?
         seek_mode: Literal["exact", "approximate"] = "exact",
         custom_frame_mappings: Optional[
             Union[str, bytes, io.RawIOBase, io.BufferedReader]
@@ -149,16 +148,13 @@ def __init__(

         device_variant = _get_cuda_backend()

-        transform_specs = make_transform_specs(transforms)
-
         core.add_video_stream(
             self._decoder,
             stream_index=stream_index,
             dimension_order=dimension_order,
             num_threads=num_ffmpeg_threads,
             device=device,
             device_variant=device_variant,
-            transform_specs=transform_specs,
             custom_frame_mappings=custom_frame_mappings_data,
         )

@@ -436,22 +432,6 @@ def _get_and_validate_stream_metadata(
     )


-def make_transform_specs(transforms: List[Any]) -> str:
-    from torchvision.transforms import v2
-
-    transform_specs = []
-    for transform in transforms:
-        if isinstance(transform, v2.Resize):
-            if len(transform.size) != 2:
-                raise ValueError(
-                    f"Resize transform must have a (height, width) pair for the size, got {transform.size}."
-                )
-            transform_specs.append(f"resize, {transform.size[0]}, {transform.size[1]}")
-        else:
-            raise ValueError(f"Unsupported transform {transform}.")
-    return ";".join(transform_specs)
-
-
 def _read_custom_frame_mappings(
     custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
 ) -> tuple[Tensor, Tensor, Tensor]:
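The removed helper still documents the spec format the core layer parses: one "resize, height, width" entry per transform, entries joined by ";". A sketch of exercising the core API directly, under the assumptions that core.add_video_stream keeps its transform_specs parameter (only the public wrapper plumbing is removed in this commit) and that create_from_file is the matching constructor; the file path is a placeholder:

    # Build a spec the way the removed make_transform_specs did and hand it
    # straight to the core API, bypassing the (removed) public transforms arg.
    from torchcodec import _core as core

    decoder = core.create_from_file("input.mp4")  # placeholder path
    core.add_video_stream(
        decoder,
        transform_specs="resize, 270, 480",  # "resize, height, width"
    )
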

test/generate_reference_resources.py

Lines changed: 21 additions & 8 deletions
@@ -52,15 +52,15 @@ def generate_frame_by_index(
     output_bmp = f"{base_path}.bmp"

     # Note that we have an exlicit format conversion to rgb24 in our filtergraph specification,
-    # which always happens BEFORE any of the filters that we receive as input. We do this to
-    # ensure that the color conversion happens BEFORE the filters, matching the behavior of the
+    # which always happens AFTER any of the filters that we receive as input. We do this to
+    # ensure that the color conversion happens AFTER the filters, matching the behavior of the
     # torchcodec filtergraph implementation.
-    #
-    # Not doing this would result in the color conversion happening AFTER the filters, which
-    # would result in different color values for the same frame.
-    filtergraph = f"select='eq(n\\,{frame_index})',format=rgb24"
+    select = f"select='eq(n\\,{frame_index})'"
+    format = "format=rgb24"
     if filters is not None:
-        filtergraph = filtergraph + f",{filters}"
+        filtergraph = ",".join([select, filters, format])
+    else:
+        filtergraph = ",".join([select, format])

     cmd = [
         "ffmpeg",
@@ -99,7 +99,7 @@ def generate_frame_by_timestamp(
     convert_image_to_tensor(output_path)


-def generate_nasa_13013_references():
+def generate_nasa_13013_references_by_index():
     # Note: The naming scheme used here must match the naming scheme used to load
     # tensors in ./utils.py.
     streams = [0, 3]
@@ -108,13 +108,17 @@ def generate_nasa_13013_references():
         for frame in frames:
             generate_frame_by_index(NASA_VIDEO, frame_index=frame, stream_index=stream)

+
+def generate_nasa_13013_references_by_timestamp():
     # Extract individual frames at specific timestamps, including the last frame of the video.
     seek_timestamp = [6.0, 6.1, 10.0, 12.979633]
     timestamp_name = [f"{seek_timestamp:06f}" for seek_timestamp in seek_timestamp]
     for timestamp, name in zip(seek_timestamp, timestamp_name):
         output_bmp = f"{NASA_VIDEO.path}.time{name}.bmp"
         generate_frame_by_timestamp(NASA_VIDEO.path, timestamp, output_bmp)

+
+def generate_nasa_13013_references_crop():
     # Extract frames with specific filters. We have tests that assume these exact filters.
     frames = [0, 15, 200, 389]
     crop_filter = "crop=300:200:50:35:exact=1"
@@ -123,6 +127,8 @@
                 NASA_VIDEO, frame_index=frame, stream_index=3, filters=crop_filter
             )

+
+def generate_nasa_13013_references_resize():
     frames = [17, 230, 389]
     # Note that the resize algorithm passed to flags is exposed to users,
     # but bilinear is the default we use.
@@ -133,6 +139,13 @@
         )


+def generate_nasa_13013_references():
+    generate_nasa_13013_references_by_index()
+    generate_nasa_13013_references_by_timestamp()
+    generate_nasa_13013_references_crop()
+    generate_nasa_13013_references_resize()
+
+
 def generate_h265_video_references():
     # This video was generated by running the following:
     #   conda install -c conda-forge x265
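As a sanity check, the refactored assembly above yields reference filtergraphs with the user filter between frame selection and color conversion. A small sketch reproducing the string for one of the crop references, using only values that appear in the hunks above:

    # Reproduce generate_frame_by_index's assembly for frame 389 of the crop
    # references: selection first, then the user filter, then color conversion.
    frame_index = 389
    filters = "crop=300:200:50:35:exact=1"

    select = f"select='eq(n\\,{frame_index})'"
    fmt = "format=rgb24"
    filtergraph = ",".join([select, filters, fmt]) if filters else ",".join([select, fmt])

    assert filtergraph == "select='eq(n\\,389)',crop=300:200:50:35:exact=1,format=rgb24"
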
3 binary files not shown.
