
Commit 0e61aba

Full 180

1 parent f853e3a commit 0e61aba
14 files changed: +23 -65 lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp
Lines changed: 12 additions & 14 deletions

@@ -46,8 +46,7 @@ void CpuDeviceInterface::initializeVideo(
   // We calculate this value during initilization but we don't refer to it until
   // getColorConversionLibrary() is called. Calculating this value during
   // initialization saves us from having to save all of the transforms.
-  areTransformsSwScaleCompatible_ = transforms.empty() ||
-      (transforms.size() == 1 && transforms[0]->isResize());
+  areTransformsSwScaleCompatible_ = transforms.empty();
 
   // Note that we do not expose this capability in the public API, only through
   // the core API.
@@ -57,16 +56,6 @@ void CpuDeviceInterface::initializeVideo(
   userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
       ColorConversionLibrary::SWSCALE;
 
-  // We can only use swscale when we have a single resize transform. Note that
-  // we actually decide on whether or not to actually use swscale at the last
-  // possible moment, when we actually convert the frame. This is because we
-  // need to know the actual frame dimensions.
-  if (transforms.size() == 1 && transforms[0]->isResize()) {
-    auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
-    TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
-    swsFlags_ = resize->getSwsFlags();
-  }
-
   // If we have any transforms, replace filters_ with the filter strings from
   // the transforms. As noted above, we decide between swscale and filtergraph
   // when we actually decode a frame.
@@ -83,7 +72,7 @@ void CpuDeviceInterface::initializeVideo(
     // Note that we ensure that the transforms come BEFORE the format
     // conversion. This means that the transforms are applied in the frame's
     // original pixel format and colorspace.
-    filters_ = filters.str() + "," + filters_;
+    filters_ += "," + filters.str();
   }
 
   initialized_ = true;
@@ -221,6 +210,11 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
+  TORCH_CHECK(
+      avFrame->height == outputDims.height &&
+          avFrame->width == outputDims.width,
+      "Input dimensions are not equal to output dimensions; resize for sws_scale() is not yet supported.");
+
   // We need to compare the current frame context with our previous frame
   // context. If they are different, then we need to re-create our colorspace
   // conversion objects. We create our colorspace conversion objects late so
@@ -237,7 +231,11 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
 
   if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
     swsContext_ = createSwsContext(
-        swsFrameContext, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
+        swsFrameContext,
+        avFrame->colorspace,
+        /*outputFormat=*/AV_PIX_FMT_RGB24,
+        /*swsFlags=*/0); // We don't set any flags because we don't yet use
+                         // sws_scale() for resizing.
     prevSwsFrameContext_ = swsFrameContext;
   }
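
The upshot of these hunks is that the swscale path now does pixel-format conversion only: resizing always goes through filtergraph, so by the time sws_scale() runs, the frame dimensions already equal the output dimensions. For reference, here is a minimal standalone sketch of that kind of same-size RGB24 conversion with FFmpeg's C API. It is not torchcodec's createSwsContext (which also configures the colorspace from the SwsFrameContext); the scaler flag is irrelevant here because nothing is scaled.

extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

#include <cstdint>
#include <stdexcept>
#include <vector>

// Convert one decoded AVFrame to packed RGB24 at its original size, which is
// the only job the swscale path has after this commit.
std::vector<uint8_t> convertFrameToRgb24(const AVFrame* src) {
  SwsContext* ctx = sws_getContext(
      src->width, src->height, static_cast<AVPixelFormat>(src->format),
      src->width, src->height, AV_PIX_FMT_RGB24, // same dims: no resize
      SWS_BILINEAR, nullptr, nullptr, nullptr);
  if (!ctx) {
    throw std::runtime_error("sws_getContext failed");
  }

  std::vector<uint8_t> rgb(static_cast<size_t>(src->width) * src->height * 3);
  uint8_t* dstData[4] = {rgb.data(), nullptr, nullptr, nullptr};
  int dstLinesize[4] = {3 * src->width, 0, 0, 0};

  sws_scale(ctx, src->data, src->linesize, 0, src->height, dstData, dstLinesize);
  sws_freeContext(ctx);
  return rgb;
}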

src/torchcodec/_core/CpuDeviceInterface.h
Lines changed: 1 addition & 0 deletions

@@ -93,6 +93,7 @@ class CpuDeviceInterface : public DeviceInterface {
   // initialization, we convert the user-supplied transforms into this string of
   // filters.
   //
+  // TODO: make sure Scott corrects the below:
   // Note that we start with just the format conversion, and then we ensure that
   // the user-supplied filters always happen BEFORE the format conversion. We
   // want the user-supplied filters to operate on frames in their original pixel
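
The TODO exists because, after the filters_ += "," + filters.str(); change above, the format conversion now comes first and the user-supplied transform filters are appended after it, so this comment block is stale. A tiny sketch of the new assembly order follows; the concrete "format=rgb24" seed and the "scale=..." transform string are illustrative assumptions, not values copied from torchcodec.

#include <iostream>
#include <string>

int main() {
  // filters_ starts out as just the format conversion (per the header comment
  // above); transform filters are then appended after it, mirroring
  // filters_ += "," + filters.str();
  std::string filters = "format=rgb24";
  const std::string transformFilters = "scale=640:360:flags=bilinear";

  filters += "," + transformFilters;

  // Prints: format=rgb24,scale=640:360:flags=bilinear
  // i.e. the transforms now run after the conversion to RGB24, which is what
  // the updated test/generate_reference_resources.py mirrors.
  std::cout << filters << "\n";
  return 0;
}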

src/torchcodec/_core/FFMPEGCommon.h
Lines changed: 2 additions & 2 deletions

@@ -275,7 +275,7 @@ struct SwsFrameContext {
 UniqueSwsContext createSwsContext(
     const SwsFrameContext& swsFrameContext,
     AVColorSpace colorspace,
-    AVPixelFormat outputFormat = AV_PIX_FMT_RGB24,
-    int swsFlags = SWS_BILINEAR);
+    AVPixelFormat outputFormat,
+    int swsFlags);
 
 } // namespace facebook::torchcodec

src/torchcodec/_core/Transform.cpp
Lines changed: 2 additions & 20 deletions

@@ -25,21 +25,11 @@ std::string toFilterGraphInterpolation(
   }
 }
 
-int toSwsInterpolation(ResizeTransform::InterpolationMode mode) {
-  switch (mode) {
-    case ResizeTransform::InterpolationMode::BILINEAR:
-      return SWS_BILINEAR;
-    default:
-      TORCH_CHECK(
-          false,
-          "Unknown interpolation mode: " +
-              std::to_string(static_cast<int>(mode)));
-  }
-}
-
 } // namespace
 
 std::string ResizeTransform::getFilterGraphCpu() const {
+  // Note that we turn on gamma correct scaling. This produces results that are
+  // closer to what TorchVision's resize produces.
   return "scale=" + std::to_string(outputDims_.width) + ":" +
       std::to_string(outputDims_.height) +
       ":flags=" + toFilterGraphInterpolation(interpolationMode_);
@@ -49,14 +39,6 @@ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
   return outputDims_;
 }
 
-bool ResizeTransform::isResize() const {
-  return true;
-}
-
-int ResizeTransform::getSwsFlags() const {
-  return toSwsInterpolation(interpolationMode_);
-}
-
 CropTransform::CropTransform(const FrameDims& dims, int x, int y)
     : outputDims_(dims), x_(x), y_(y) {
   TORCH_CHECK(x_ >= 0, "Crop x position must be >= 0, got: ", x_);
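
With toSwsInterpolation() and getSwsFlags() gone, the interpolation mode is expressed only as a scale-filter option in the filtergraph string. Below is a small sketch of what getFilterGraphCpu() produces for a concrete 640x360 resize; the "bilinear" scaler name is an assumption about what toFilterGraphInterpolation() returns for BILINEAR. Since libavfilter's scale filter hands its flags option to swscale internally, the same algorithm is still used, just selected by name rather than by an SWS_* flag.

#include <iostream>
#include <string>

int main() {
  const int width = 640;
  const int height = 360;
  const std::string interpolation = "bilinear"; // assumed mapping for BILINEAR

  const std::string filter = "scale=" + std::to_string(width) + ":" +
      std::to_string(height) + ":flags=" + interpolation;

  // Prints: scale=640:360:flags=bilinear
  // After this commit, resizing happens only through this filtergraph path,
  // never through sws_scale() directly.
  std::cout << filter << "\n";
  return 0;
}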

src/torchcodec/_core/Transform.h
Lines changed: 0 additions & 9 deletions

@@ -29,12 +29,6 @@ class Transform {
     return std::nullopt;
   }
 
-  // The ResizeTransform is special, because it is the only transform that
-  // swscale can handle.
-  virtual bool isResize() const {
-    return false;
-  }
-
   // The validity of some transforms depends on the characteristics of the
   // AVStream they're being applied to. For example, some transforms will
   // specify coordinates inside a frame, we need to validate that those are
@@ -58,9 +52,6 @@ class ResizeTransform : public Transform {
 
   std::string getFilterGraphCpu() const override;
   std::optional<FrameDims> getOutputFrameDims() const override;
-  bool isResize() const override;
-
-  int getSwsFlags() const;
 
  private:
  FrameDims outputDims_;

test/generate_reference_resources.py
Lines changed: 3 additions & 4 deletions

@@ -52,14 +52,13 @@ def generate_frame_by_index(
     output_bmp = f"{base_path}.bmp"
 
     # Note that we have an exlicit format conversion to rgb24 in our filtergraph
-    # specification, and we always place the user-supplied filters BEFORE the
+    # specification, and we always place the user-supplied filters AFTER the
     # format conversion. We do this to ensure that the filters are applied in
-    # the pixel format and colorspace of the input frames. This behavior matches
-    # the TorchCodec implementation.
+    # RGB24 colorspace, which matches TorchCodec's behavior.
     select = f"select='eq(n\\,{frame_index})'"
     format = "format=rgb24"
     if filters is not None:
-        filtergraph = ",".join([select, filters, format])
+        filtergraph = ",".join([select, format, filters])
     else:
         filtergraph = ",".join([select, format])

4 binary files changed (contents not shown).
