@@ -51,31 +51,70 @@ void CpuDeviceInterface::initialize(
     const VideoStreamOptions& videoStreamOptions,
     const std::vector<std::unique_ptr<Transform>>& transforms,
     const AVRational& timeBase,
-    const FrameDims& outputDims) {
+    [[maybe_unused]] const FrameDims& metadataDims,
+    const std::optional<FrameDims>& resizedOutputDims) {
   videoStreamOptions_ = videoStreamOptions;
   timeBase_ = timeBase;
-  outputDims_ = outputDims;
-
-  // We want to use swscale for color conversion if possible because it is
-  // faster than filtergraph. The following are the conditions we need to meet
-  // to use it.
+  resizedOutputDims_ = resizedOutputDims;
 
   // We can only use swscale when we have a single resize transform. Note that
   // this means swscale will not support the case of having several,
   // back-to-back resizes. There's no strong reason to even do that, but if
   // someone does, it's more correct to implement that with filtergraph.
-  bool areTransformsSwScaleCompatible = transforms.empty() ||
+  //
+  // We calculate this value during initialization, but we don't refer to it
+  // until getColorConversionLibrary() is called. Calculating it during
+  // initialization saves us from having to store all of the transforms.
+  areTransformsSwScaleCompatible_ = transforms.empty() ||
       (transforms.size() == 1 && transforms[0]->isResize());
 
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0;
-
   // Note that we do not expose this capability in the public API, only through
   // the core API.
-  bool userRequestedSwScale = videoStreamOptions_.colorConversionLibrary ==
+  //
+  // Same as above, we calculate this value during initialization and refer to
+  // it in getColorConversionLibrary().
+  userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
       ColorConversionLibrary::SWSCALE;
 
+  // If we have a single resize transform, record its swscale flags now. Note
+  // that we decide whether or not to actually use swscale at the last
+  // possible moment, when we convert the frame, because only then do we know
+  // the actual frame dimensions.
+  if (transforms.size() == 1 && transforms[0]->isResize()) {
+    auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
+    TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
+    swsFlags_ = resize->getSwsFlags();
+  }
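+  // E.g., a bilinear ResizeTransform would typically map to FFmpeg's
+  // SWS_BILINEAR flag here.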
+
+  // If we have any transforms, replace filters_ with the filter strings from
+  // the transforms. As noted above, we decide between swscale and filtergraph
+  // when we actually decode a frame.
+  std::stringstream filters;
+  bool first = true;
+  for (const auto& transform : transforms) {
+    if (!first) {
+      filters << ",";
+    }
+    filters << transform->getFilterGraphCpu();
+    first = false;
+  }
+  if (!transforms.empty()) {
+    filters_ = filters.str();
+  }
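+  // For example, a single resize transform targeting 640x480 would typically
+  // contribute an FFmpeg filter string along the lines of "scale=640:480";
+  // the exact string is whatever getFilterGraphCpu() returns.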
+
+  initialized_ = true;
+}
+
+ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
+    const FrameDims& outputDims) {
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;
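+  // E.g., a 640-pixel-wide output is swscale-compatible (640 % 32 == 0),
+  // while a 642-pixel-wide output is not.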
+
+  // We want to use swscale for color conversion if possible because it is
+  // faster than filtergraph. The following are the conditions we need to meet
+  // to use it.
+  //
   // Note that we treat the transform limitation differently from the width
   // limitation. That is, we consider the transforms being compatible with
   // swscale as a hard requirement. If the transforms are not compatible,
@@ -86,38 +125,12 @@ void CpuDeviceInterface::initialize(
   // behavior. Since we don't expose the ability to choose swscale or
   // filtergraph in our public API, this is probably okay. It's also the only
   // way that we can be certain we are testing one versus the other.
-  if (areTransformsSwScaleCompatible &&
-      (userRequestedSwScale || isWidthSwScaleCompatible)) {
-    colorConversionLibrary_ = ColorConversionLibrary::SWSCALE;
-
-    // We established above that if the transforms are swscale compatible and
-    // non-empty, then they must have only one transform, and that transform is
-    // ResizeTransform.
-    if (!transforms.empty()) {
-      auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
-      TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!")
-      swsFlags_ = resize->getSwsFlags();
-    }
+  if (areTransformsSwScaleCompatible_ &&
+      (userRequestedSwScale_ || isWidthSwScaleCompatible)) {
+    return ColorConversionLibrary::SWSCALE;
   } else {
-    colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH;
-
-    // If we have any transforms, replace filters_ with the filter strings from
-    // the transforms.
-    std::stringstream filters;
-    bool first = true;
-    for (const auto& transform : transforms) {
-      if (!first) {
-        filters << ",";
-      }
-      filters << transform->getFilterGraphCpu();
-      first = false;
-    }
-    if (!transforms.empty()) {
-      filters_ = filters.str();
-    }
+    return ColorConversionLibrary::FILTERGRAPH;
   }
-
-  initialized_ = true;
 }
 
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
@@ -134,24 +147,42 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
+
+  // Note that we ignore the dimensions from the metadata; we don't even bother
+  // storing them. The resized dimensions take priority. If we don't have any,
+  // then we use the dimensions of the actual decoded frame. We use the actual
+  // decoded frame and not the metadata for two reasons:
+  //
+  //   1. Metadata may be wrong. If we have access to more accurate
+  //      information, we should use it.
+  //   2. Video streams can have variable resolution. This fact is not
+  //      captured in the stream metadata.
+  //
+  // Both cases cause problems for our batch APIs, as we allocate
+  // FrameBatchOutputs based on the stream metadata. But single-frame APIs can
+  // still work in such situations, so they should.
+  auto outputDims =
+      resizedOutputDims_.value_or(FrameDims(avFrame->width, avFrame->height));
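+  // E.g., with no resize requested, a 1920x1080 decoded frame yields
+  // 1920x1080 outputDims regardless of what the stream metadata claimed.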
+
   if (preAllocatedOutputTensor.has_value()) {
     auto shape = preAllocatedOutputTensor.value().sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == outputDims_.height) &&
-        (shape[1] == outputDims_.width) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims.height) &&
+        (shape[1] == outputDims.width) && (shape[2] == 3),
         "Expected pre-allocated tensor of shape ",
-        outputDims_.height,
+        outputDims.height,
         "x",
-        outputDims_.width,
+        outputDims.width,
         "x3, got ",
         shape);
   }
 
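+  // The conversion library is chosen per frame, now that the actual output
+  // dimensions are known.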
+  auto colorConversionLibrary = getColorConversionLibrary(outputDims);
   torch::Tensor outputTensor;
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) {
+  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
     // We need to compare the current frame context with our previous frame
     // context. If they are different, then we need to re-create our colorspace
     // conversion objects. We create our colorspace conversion objects late so
@@ -163,11 +194,11 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
         avFrame->width,
         avFrame->height,
         frameFormat,
-        outputDims_.width,
-        outputDims_.height);
+        outputDims.width,
+        outputDims.height);
 
     outputTensor = preAllocatedOutputTensor.value_or(
-        allocateEmptyHWCTensor(outputDims_, torch::kCPU));
+        allocateEmptyHWCTensor(outputDims, torch::kCPU));
 
     if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
       createSwsContext(swsFrameContext, avFrame->colorspace);
@@ -180,42 +211,42 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     // the expected height.
     // TODO: Can we do the same check for width?
     TORCH_CHECK(
-        resultHeight == outputDims_.height,
-        "resultHeight != outputDims_.height: ",
+        resultHeight == outputDims.height,
+        "resultHeight != outputDims.height: ",
         resultHeight,
         " != ",
-        outputDims_.height);
+        outputDims.height);
 
     frameOutput.data = outputTensor;
-  } else if (colorConversionLibrary_ == ColorConversionLibrary::FILTERGRAPH) {
+  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
     FiltersContext filtersContext(
         avFrame->width,
         avFrame->height,
         frameFormat,
         avFrame->sample_aspect_ratio,
-        outputDims_.width,
-        outputDims_.height,
+        outputDims.width,
+        outputDims.height,
         AV_PIX_FMT_RGB24,
         filters_,
         timeBase_);
 
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
+    if (!filterGraph_ || prevFiltersContext_ != filtersContext) {
+      filterGraph_ =
           std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
       prevFiltersContext_ = std::move(filtersContext);
     }
-    outputTensor = rgbAVFrameToTensor(filterGraphContext_->convert(avFrame));
+    outputTensor = rgbAVFrameToTensor(filterGraph_->convert(avFrame));
 
     // Similarly to above, if this check fails it means the frame wasn't
     // reshaped to its expected dimensions by filtergraph.
     auto shape = outputTensor.sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == outputDims_.height) &&
-        (shape[1] == outputDims_.width) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims.height) &&
+        (shape[1] == outputDims.width) && (shape[2] == 3),
         "Expected output tensor of shape ",
-        outputDims_.height,
+        outputDims.height,
         "x",
-        outputDims_.width,
+        outputDims.width,
         "x3, got ",
         shape);
 
@@ -231,7 +262,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     TORCH_CHECK(
         false,
         "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary_));
+        static_cast<int>(colorConversionLibrary));
   }
 }
 
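For reference, the swscale-versus-filtergraph decision introduced above boils
down to the following standalone sketch. The stand-in types and the
chooseLibrary() helper are hypothetical simplifications, not torchcodec APIs;
only the selection logic mirrors getColorConversionLibrary().

#include <cstdio>

// Simplified stand-ins for the types used in the diff.
struct FrameDims {
  int width;
  int height;
};
enum class ColorConversionLibrary { SWSCALE, FILTERGRAPH };

// swscale is preferred because it is faster, but it requires the output width
// to be a multiple of 32 unless the user explicitly requested swscale, and
// the transforms must be swscale-compatible (at most one resize).
ColorConversionLibrary chooseLibrary(
    bool areTransformsSwScaleCompatible,
    bool userRequestedSwScale,
    const FrameDims& outputDims) {
  bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;
  if (areTransformsSwScaleCompatible &&
      (userRequestedSwScale || isWidthSwScaleCompatible)) {
    return ColorConversionLibrary::SWSCALE;
  }
  return ColorConversionLibrary::FILTERGRAPH;
}

int main() {
  // 640 % 32 == 0, so swscale (0) is selected.
  std::printf("%d\n", static_cast<int>(chooseLibrary(true, false, {640, 480})));
  // 642 % 32 != 0, so filtergraph (1) is selected.
  std::printf("%d\n", static_cast<int>(chooseLibrary(true, false, {642, 480})));
}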