Skip to content

Commit 8e7072f

Browse files
committed
Handle swscale correctly
1 parent 781f956 commit 8e7072f

File tree

4 files changed

+21
-13
lines changed

4 files changed

+21
-13
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,12 @@ void CpuDeviceInterface::initialize(
6565
// choose color conversion library publicly; we only use this ability
6666
// internally.
6767

68-
// If any transforms are not swscale compatible, then we can't use swscale.
69-
bool areTransformsSwScaleCompatible = true;
70-
for (const auto& transform : transforms) {
71-
areTransformsSwScaleCompatible =
72-
areTransformsSwScaleCompatible && transform->isSwScaleCompatible();
73-
}
68+
// We can only use swscale when we have a single resize transform. Note that
69+
// this means swscale will not support the case of having several,
70+
// back-to-base resizes. There's no strong reason to even do that, but if
71+
// someone does, it's more correct to implement that with filtergraph.
72+
bool areTransformsSwScaleCompatible = transforms.empty() ||
73+
(transforms.size() == 1 && transforms[0]->isResize());
7474

7575
// swscale requires widths to be multiples of 32:
7676
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
@@ -95,8 +95,14 @@ void CpuDeviceInterface::initialize(
9595
(userRequestedSwScale || isWidthSwScaleCompatible)) {
9696
colorConversionLibrary_ = ColorConversionLibrary::SWSCALE;
9797

98-
// SCOTT NEXT TODO: set swsFlags_
99-
98+
// We established above that if the transforms are swscale compatible and
99+
// non-empty, then they must have only one transforms, and that transform is
100+
// ResizeTransform.
101+
if (!transforms.empty()) {
102+
auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
103+
TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
104+
swsFlags_ = resize->getSwsFlags();
105+
}
100106
} else {
101107
colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH;
102108

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class CpuDeviceInterface : public DeviceInterface {
7272
FrameDims outputDims_;
7373

7474
// If we use swscale for resizing, the flags control the resizing algorithm.
75-
// We exclusively get the value from the ResizeTransform.
76-
int swsFlags_ = 0;
75+
// We default to bilinear. Users can override this with a ResizeTransform.
76+
int swsFlags_ = SWS_BILINEAR;
7777

7878
// The copy filter just copies the input to the output. Computationally, it
7979
// should be a no-op. If we get no user-provided transforms, we will use the

src/torchcodec/_core/Transform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ std::optional<FrameDims> ResizeTransform::getOutputFrameDims() const {
5656
return FrameDims(width_, height_);
5757
}
5858

59-
bool ResizeTransform::isSwScaleCompatible() const {
59+
bool ResizeTransform::isResize() const {
6060
return true;
6161
}
6262

src/torchcodec/_core/Transform.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ class Transform {
2828
return std::nullopt;
2929
}
3030

31-
virtual bool isSwScaleCompatible() const {
31+
// The ResizeTransform is special, because it is the only transform that
32+
// swscale can handle.
33+
virtual bool isResize() const {
3234
return false;
3335
}
3436
};
@@ -47,7 +49,7 @@ class ResizeTransform : public Transform {
4749

4850
std::string getFilterGraphCpu() const override;
4951
std::optional<FrameDims> getOutputFrameDims() const override;
50-
bool isSwScaleCompatible() const override;
52+
bool isResize() const override;
5153

5254
int getSwsFlags() const;
5355

0 commit comments

Comments
 (0)