Skip to content

Commit 26095f3

Browse files
committed
Use separate frame/filter contexts for sws and FilterGraph
Signed-off-by: Dmitry Rogozhkin <[email protected]>
1 parent 028b612 commit 26095f3

File tree

4 files changed

+70
-38
lines changed

4 files changed

+70
-38
lines changed

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 54 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ static bool g_cpu = registerDeviceInterface(
1515

1616
} // namespace
1717

18+
bool CpuDeviceInterface::SwsFrameContext::operator==(
19+
const CpuDeviceInterface::SwsFrameContext& other) const {
20+
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
21+
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
22+
outputHeight == other.outputHeight;
23+
}
24+
25+
bool CpuDeviceInterface::SwsFrameContext::operator!=(
26+
const CpuDeviceInterface::SwsFrameContext& other) const {
27+
return !(*this == other);
28+
}
29+
1830
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
1931
: DeviceInterface(device) {
2032
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
@@ -56,31 +68,8 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
5668
}
5769

5870
torch::Tensor outputTensor;
59-
// We need to compare the current frame context with our previous frame
60-
// context. If they are different, then we need to re-create our colorspace
61-
// conversion objects. We create our colorspace conversion objects late so
62-
// that we don't have to depend on the unreliable metadata in the header.
63-
// And we sometimes re-create them because it's possible for frame
64-
// resolution to change mid-stream. Finally, we want to reuse the colorspace
65-
// conversion objects as much as possible for performance reasons.
6671
enum AVPixelFormat frameFormat =
6772
static_cast<enum AVPixelFormat>(avFrame->format);
68-
FiltersContext filtersContext;
69-
70-
filtersContext.inputWidth = avFrame->width;
71-
filtersContext.inputHeight = avFrame->height;
72-
filtersContext.inputFormat = frameFormat;
73-
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
74-
filtersContext.outputWidth = expectedOutputWidth;
75-
filtersContext.outputHeight = expectedOutputHeight;
76-
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
77-
filtersContext.timeBase = timeBase;
78-
79-
std::stringstream filters;
80-
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
81-
filters << ":sws_flags=bilinear";
82-
83-
filtersContext.filters = filters.str();
8473

8574
// By default, we want to use swscale for color conversion because it is
8675
// faster. However, it has width requirements, so we may need to fall back
@@ -101,12 +90,27 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
10190
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
10291

10392
if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
93+
// We need to compare the current frame context with our previous frame
94+
// context. If they are different, then we need to re-create our colorspace
95+
// conversion objects. We create our colorspace conversion objects late so
96+
// that we don't have to depend on the unreliable metadata in the header.
97+
// And we sometimes re-create them because it's possible for frame
98+
// resolution to change mid-stream. Finally, we want to reuse the colorspace
99+
// conversion objects as much as possible for performance reasons.
100+
SwsFrameContext swsFrameContext;
101+
102+
swsFrameContext.inputWidth = avFrame->width;
103+
swsFrameContext.inputHeight = avFrame->height;
104+
swsFrameContext.inputFormat = frameFormat;
105+
swsFrameContext.outputWidth = expectedOutputWidth;
106+
swsFrameContext.outputHeight = expectedOutputHeight;
107+
104108
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
105109
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
106110

107-
if (!swsContext_ || prevFiltersContext_ != filtersContext) {
108-
createSwsContext(filtersContext, avFrame->colorspace);
109-
prevFiltersContext_ = std::move(filtersContext);
111+
if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
112+
createSwsContext(swsFrameContext, avFrame->colorspace);
113+
prevSwsFrameContext_ = swsFrameContext;
110114
}
111115
int resultHeight =
112116
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
@@ -122,6 +126,23 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
122126

123127
frameOutput.data = outputTensor;
124128
} else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
129+
FiltersContext filtersContext;
130+
131+
filtersContext.inputWidth = avFrame->width;
132+
filtersContext.inputHeight = avFrame->height;
133+
filtersContext.inputFormat = frameFormat;
134+
filtersContext.inputAspectRatio = avFrame->sample_aspect_ratio;
135+
filtersContext.outputWidth = expectedOutputWidth;
136+
filtersContext.outputHeight = expectedOutputHeight;
137+
filtersContext.outputFormat = AV_PIX_FMT_RGB24;
138+
filtersContext.timeBase = timeBase;
139+
140+
std::stringstream filters;
141+
filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
142+
filters << ":sws_flags=bilinear";
143+
144+
filtersContext.filtergraphStr = filters.str();
145+
125146
if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
126147
filterGraphContext_ =
127148
std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
@@ -196,15 +217,15 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
196217
}
197218

198219
void CpuDeviceInterface::createSwsContext(
199-
const FiltersContext& filtersContext,
220+
const SwsFrameContext& swsFrameContext,
200221
const enum AVColorSpace colorspace) {
201222
SwsContext* swsContext = sws_getContext(
202-
filtersContext.inputWidth,
203-
filtersContext.inputHeight,
204-
filtersContext.inputFormat,
205-
filtersContext.outputWidth,
206-
filtersContext.outputHeight,
207-
filtersContext.outputFormat,
223+
swsFrameContext.inputWidth,
224+
swsFrameContext.inputHeight,
225+
swsFrameContext.inputFormat,
226+
swsFrameContext.outputWidth,
227+
swsFrameContext.outputHeight,
228+
AV_PIX_FMT_RGB24,
208229
SWS_BILINEAR,
209230
nullptr,
210231
nullptr,

src/torchcodec/_core/CpuDeviceInterface.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,18 @@ class CpuDeviceInterface : public DeviceInterface {
4242
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
4343
const UniqueAVFrame& avFrame);
4444

45+
struct SwsFrameContext {
46+
int inputWidth;
47+
int inputHeight;
48+
AVPixelFormat inputFormat;
49+
int outputWidth;
50+
int outputHeight;
51+
bool operator==(const SwsFrameContext&) const;
52+
bool operator!=(const SwsFrameContext&) const;
53+
};
54+
4555
void createSwsContext(
46-
const FiltersContext& filtersContext,
56+
const SwsFrameContext& swsFrameContext,
4757
const enum AVColorSpace colorspace);
4858

4959
// color-conversion fields. Only one of FilterGraphContext and
@@ -53,6 +63,7 @@ class CpuDeviceInterface : public DeviceInterface {
5363

5464
// Used to know whether a new FilterGraphContext or UniqueSwsContext should
5565
// be created before decoding a new frame.
66+
SwsFrameContext prevSwsFrameContext_;
5667
FiltersContext prevFiltersContext_;
5768
};
5869

src/torchcodec/_core/FilterGraph.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ bool FiltersContext::operator==(const FiltersContext& other) {
2121
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
2222
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
2323
outputHeight == other.outputHeight &&
24-
outputFormat == other.outputFormat && filters == other.filters &&
25-
timeBase == other.timeBase &&
24+
outputFormat == other.outputFormat &&
25+
filtergraphStr == other.filtergraphStr && timeBase == other.timeBase &&
2626
hwFramesCtx.get() == other.hwFramesCtx.get();
2727
}
2828

@@ -108,7 +108,7 @@ FilterGraph::FilterGraph(
108108
AVFilterInOut* inputsTmp = inputs.release();
109109
status = avfilter_graph_parse_ptr(
110110
filterGraph_.get(),
111-
filtersContext.filters.c_str(),
111+
filtersContext.filtergraphStr.c_str(),
112112
&inputsTmp,
113113
&outputsTmp,
114114
nullptr);

src/torchcodec/_core/FilterGraph.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct FiltersContext {
2020
int outputHeight = 0;
2121
AVPixelFormat outputFormat = AV_PIX_FMT_NONE;
2222

23-
std::string filters;
23+
std::string filtergraphStr;
2424
AVRational timeBase = {0, 0};
2525
UniqueAVBufferRef hwFramesCtx;
2626

0 commit comments

Comments
 (0)