@@ -46,6 +46,74 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+void CpuDeviceInterface::initialize(
+    [[maybe_unused]] AVCodecContext* codecContext,
+    const VideoStreamOptions& videoStreamOptions,
+    const std::vector<std::unique_ptr<Transform>>& transforms,
+    const AVRational& timeBase,
+    const FrameDims& outputDims) {
+  videoStreamOptions_ = videoStreamOptions;
+  timeBase_ = timeBase;
+  outputDims_ = outputDims;
+
+  // TODO: rationalize comment below with new stuff.
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability to
+  // choose color conversion library publicly; we only use this ability
+  // internally.
+
+  // If any transforms are not swscale compatible, then we can't use swscale.
+  bool areTransformsSwScaleCompatible = true;
+  for (const auto& transform : transforms) {
+    areTransformsSwScaleCompatible =
+        areTransformsSwScaleCompatible && transform->isSwScaleCompatible();
+  }
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0;
+
+  bool userRequestedSwScale =
+      videoStreamOptions_.colorConversionLibrary.has_value() &&
+      videoStreamOptions_.colorConversionLibrary.value() ==
+          ColorConversionLibrary::SWSCALE;
+
+  // Note that we treat the transform limitation differently from the width
+  // limitation. That is, we consider the transforms being compatible with
+  // sws_scale as a hard requirement. If the transforms are not compatible,
+  // then we will end up not applying the transforms, and that is wrong.
+  //
+  // The width requirement, however, is a soft requirement. Even if we don't
+  // meet it, we let the user override it. We have tests that depend on this
+  // behavior. Since we don't expose the ability to choose swscale or
+  // filtergraph in our public API, this is probably okay. It's also the only
+  // way that we can be certain we are testing one versus the other.
+  if (areTransformsSwScaleCompatible &&
+      (userRequestedSwScale || isWidthSwScaleCompatible)) {
+    colorConversionLibrary_ = ColorConversionLibrary::SWSCALE;
+  } else {
+    colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH;
+
+    // If we have any transforms, replace filters_ with the filter strings from
+    // the transforms.
+    std::stringstream filters;
+    bool first = true;
+    for (const auto& transform : transforms) {
+      if (!first) {
+        filters << ",";
+      }
+      filters << transform->getFilterGraphCpu();
+      first = false;
+    }
+    if (!transforms.empty()) {
+      filters_ = filters.str();
+    }
+  }
+}
+
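To make the selection rule above easy to check at a glance, here is a minimal, self-contained sketch of the same decision (a hypothetical free function, not part of this commit): transform compatibility acts as a hard gate, while the width rule is soft and can be overridden by an explicit request for swscale.

```cpp
#include <optional>

enum class ColorConversionLibrary { SWSCALE, FILTERGRAPH };

// Mirrors the decision in initialize(): all transforms must support swscale
// (hard requirement); the width-multiple-of-32 rule is soft and is bypassed
// when the user explicitly requested SWSCALE.
ColorConversionLibrary chooseLibrary(
    bool areTransformsSwScaleCompatible,
    int outputWidth,
    std::optional<ColorConversionLibrary> requested) {
  bool userRequestedSwScale = requested.has_value() &&
      requested.value() == ColorConversionLibrary::SWSCALE;
  bool isWidthSwScaleCompatible = (outputWidth % 32) == 0;
  if (areTransformsSwScaleCompatible &&
      (userRequestedSwScale || isWidthSwScaleCompatible)) {
    return ColorConversionLibrary::SWSCALE;
  }
  return ColorConversionLibrary::FILTERGRAPH;
}

// chooseLibrary(true, 640, std::nullopt) -> SWSCALE (640 % 32 == 0)
// chooseLibrary(true, 642, std::nullopt) -> FILTERGRAPH (width not aligned)
// chooseLibrary(true, 642, ColorConversionLibrary::SWSCALE) -> SWSCALE
// chooseLibrary(false, 642, ColorConversionLibrary::SWSCALE) -> FILTERGRAPH
```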
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
@@ -56,25 +124,18 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void CpuDeviceInterface::convertAVFrameToFrameOutput(
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
-  int expectedOutputHeight = frameDims.height;
-  int expectedOutputWidth = frameDims.width;
-
   if (preAllocatedOutputTensor.has_value()) {
     auto shape = preAllocatedOutputTensor.value().sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims_.height) &&
+            (shape[1] == outputDims_.width) && (shape[2] == 3),
         "Expected pre-allocated tensor of shape ",
-        expectedOutputHeight,
+        outputDims_.height,
         "x",
-        expectedOutputWidth,
+        outputDims_.width,
         "x3, got ",
         shape);
   }
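As a usage note for the HWC contract checked above, a sketch with standard libtorch calls (shapes illustrative, not from this commit): the pre-allocated buffer must be height x width x 3 uint8, and callers that want channels-first data permute it themselves afterwards.

```cpp
#include <torch/torch.h>

// Allocate the H x W x 3 uint8 buffer the interface expects, regardless of
// the `dimension_order` option.
torch::Tensor makePreAllocatedHWC(int height, int width) {
  return torch::empty(
      {height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU));
}

// After decoding into it, reshape to CHW if that's what downstream code wants:
//   torch::Tensor chw = hwc.permute({2, 0, 1}).contiguous();  // 3 x H x W
```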
@@ -83,25 +144,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(avFrame->format);
 
-  // By default, we want to use swscale for color conversion because it is
-  // faster. However, it has width requirements, so we may need to fall back
-  // to filtergraph. We also need to respect what was requested from the
-  // options; we respect the options unconditionally, so it's possible for
-  // swscale's width requirements to be violated. We don't expose the ability to
-  // choose color conversion library publicly; we only use this ability
-  // internally.
-
-  // swscale requires widths to be multiples of 32:
-  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
-  // so we fall back to filtergraph if the width is not a multiple of 32.
-  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
-      ? ColorConversionLibrary::SWSCALE
-      : ColorConversionLibrary::FILTERGRAPH;
-
-  ColorConversionLibrary colorConversionLibrary =
-      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
-
-  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+  if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) {
     // We need to compare the current frame context with our previous frame
     // context. If they are different, then we need to re-create our colorspace
     // conversion objects. We create our colorspace conversion objects late so
@@ -113,11 +156,11 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
         avFrame->width,
         avFrame->height,
         frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight);
+        outputDims_.width,
+        outputDims_.height);
 
-    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+    outputTensor = preAllocatedOutputTensor.value_or(
+        allocateEmptyHWCTensor(outputDims_, torch::kCPU));
 
     if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) {
       createSwsContext(swsFrameContext, avFrame->colorspace);
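For context on the cached swscale path above: creating the conversion context is the expensive step, which is why it is rebuilt only when the frame context changes between frames. A minimal sketch with FFmpeg's public API (illustrative; `createSwsContext` in this codebase presumably wraps something similar and additionally applies the frame's colorspace):

```cpp
extern "C" {
#include <libswscale/swscale.h>
}

// Build a context converting the decoded frame's geometry/pixel format to
// packed RGB24 at the requested output size, with bilinear scaling.
SwsContext* makeRgb24SwsContext(
    int srcW, int srcH, AVPixelFormat srcFmt, int dstW, int dstH) {
  return sws_getContext(
      srcW, srcH, srcFmt,
      dstW, dstH, AV_PIX_FMT_RGB24,
      SWS_BILINEAR,               // scaling algorithm
      nullptr, nullptr, nullptr); // no src/dst filters, no extra params
}
```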
@@ -129,34 +172,28 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     // the expected height.
     // TODO: Can we do the same check for width?
     TORCH_CHECK(
-        resultHeight == expectedOutputHeight,
-        "resultHeight != expectedOutputHeight: ",
+        resultHeight == outputDims_.height,
+        "resultHeight != outputDims_.height: ",
         resultHeight,
         " != ",
-        expectedOutputHeight);
+        outputDims_.height);
 
     frameOutput.data = outputTensor;
-  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-    // See comment above in swscale branch about the filterGraphContext_
-    // creation.
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
-
+  } else if (colorConversionLibrary_ == ColorConversionLibrary::FILTERGRAPH) {
     FiltersContext filtersContext(
         avFrame->width,
         avFrame->height,
         frameFormat,
         avFrame->sample_aspect_ratio,
-        expectedOutputWidth,
-        expectedOutputHeight,
+        outputDims_.width,
+        outputDims_.height,
         AV_PIX_FMT_RGB24,
-        filters.str(),
-        timeBase);
+        filters_,
+        timeBase_);
 
     if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
       filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
+          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions_);
       prevFiltersContext_ = std::move(filtersContext);
     }
     outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
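On the `filters_` string handed to FiltersContext above: it uses FFmpeg's textual filtergraph syntax, where each filter is `name=arg:arg` and filters are chained with commas, which is exactly what the comma-join in `initialize()` produces. Example values (illustrative only):

```cpp
#include <string>

// The default-style chain the removed inline code used to build:
const std::string scaleOnly = "scale=640:480:sws_flags=bilinear";

// A transform-derived chain: filters are applied left to right.
const std::string cropThenScale = "crop=100:100:0:0,scale=640:480";
```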
@@ -165,12 +202,12 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     // reshaped to its expected dimensions by filtergraph.
     auto shape = outputTensor.sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims_.height) &&
+            (shape[1] == outputDims_.width) && (shape[2] == 3),
         "Expected output tensor of shape ",
-        expectedOutputHeight,
+        outputDims_.height,
         "x",
-        expectedOutputWidth,
+        outputDims_.width,
         "x3, got ",
         shape);
 
@@ -186,7 +223,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
     TORCH_CHECK(
         false,
         "Invalid color conversion library: ",
-        static_cast<int>(colorConversionLibrary));
+        static_cast<int>(colorConversionLibrary_));
   }
 }
 
@@ -214,9 +251,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
 
   TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
 
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
-  int height = frameDims.height;
-  int width = frameDims.width;
+  int height = filteredAVFrame->height;
+  int width = filteredAVFrame->width;
   std::vector<int64_t> shape = {height, width, 3};
   std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
   AVFrame* filteredAVFramePtr = filteredAVFrame.release();
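The `shape`/`strides` setup above typically feeds a zero-copy `torch::from_blob` wrap; the exact call falls outside this hunk, so the following is a hedged sketch of the pattern. The row stride comes from `linesize[0]`, which can exceed `width * 3` due to alignment padding, and the deleter frees the AVFrame once the tensor's storage is released.

```cpp
// Continues from the lines above. The deleter ignores its void* argument
// (the data pointer) and frees the whole captured frame instead, since the
// tensor aliases the frame's pixel buffer.
auto deleter = [filteredAVFramePtr](void*) mutable {
  av_frame_free(&filteredAVFramePtr);
};
torch::Tensor wrapped = torch::from_blob(
    filteredAVFramePtr->data[0],  // start of the packed RGB24 rows
    shape,                        // {height, width, 3}
    strides,                      // {linesize[0], 3, 1}, in uint8 elements
    deleter,
    torch::TensorOptions().dtype(torch::kUInt8));
```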