
#include "src/torchcodec/_core/CpuDeviceInterface.h"

+extern "C" {
+#include <libavfilter/buffersink.h>
+#include <libavfilter/buffersrc.h>
+}
+
namespace facebook::torchcodec {
namespace {

@@ -17,6 +22,20 @@ bool g_cpu = registerDeviceInterface(

} // namespace

+bool CpuDeviceInterface::DecodedFrameContext::operator==(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return decodedWidth == other.decodedWidth &&
+      decodedHeight == other.decodedHeight &&
+      decodedFormat == other.decodedFormat &&
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
+}
+
+bool CpuDeviceInterface::DecodedFrameContext::operator!=(
+    const CpuDeviceInterface::DecodedFrameContext& other) {
+  return !(*this == other);
+}
+
CpuDeviceInterface::CpuDeviceInterface(
    const torch::Device& device,
    const AVRational& timeBase)
@@ -26,4 +45,321 @@ CpuDeviceInterface::CpuDeviceInterface(
  }
}

+// Note [preAllocatedOutputTensor with swscale and filtergraph]:
+// Callers may pass a pre-allocated tensor, where the output.data tensor will
+// be stored. This parameter is honored in any case, but it only leads to a
+// speed-up when swscale is used. With swscale, we can tell ffmpeg to place
+// the decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We
+// haven't yet found a way to do that with filtergraph.
+// TODO: Figure out whether that's possible!
+// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
+// the `dimension_order` parameter. It's up to callers to re-shape it if
+// needed. See the usage sketch after this function.
+void CpuDeviceInterface::convertAVFrameToFrameOutput(
+    const VideoStreamOptions& videoStreamOptions,
+    UniqueAVFrame& avFrame,
+    FrameOutput& frameOutput,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
+  int expectedOutputHeight = frameDims.height;
+  int expectedOutputWidth = frameDims.width;
+
+  if (preAllocatedOutputTensor.has_value()) {
+    auto shape = preAllocatedOutputTensor.value().sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected pre-allocated tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+  }
+
+  torch::Tensor outputTensor;
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(avFrame->format);
+  auto frameContext = DecodedFrameContext{
+      avFrame->width,
+      avFrame->height,
+      frameFormat,
+      avFrame->sample_aspect_ratio,
+      expectedOutputWidth,
+      expectedOutputHeight};
+
+  // By default, we want to use swscale for color conversion because it is
+  // faster. However, it has width requirements, so we may need to fall back
+  // to filtergraph. We also need to respect what was requested from the
+  // options; we respect the options unconditionally, so it's possible for
+  // swscale's width requirements to be violated. We don't expose the ability
+  // to choose a color conversion library publicly; we only use this ability
+  // internally.
+
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  // so we fall back to filtergraph if the width is not a multiple of 32.
+  auto defaultLibrary = (expectedOutputWidth % 32 == 0)
+      ? ColorConversionLibrary::SWSCALE
+      : ColorConversionLibrary::FILTERGRAPH;
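+  // For example, a 1920-wide output (1920 % 32 == 0) uses swscale, while a
+  // 1918-wide output falls back to filtergraph.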
+
+  ColorConversionLibrary colorConversionLibrary =
+      videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
+
+  if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+
+    if (!swsContext_ || prevFrameContext_ != frameContext) {
+      createSwsContext(frameContext, avFrame->colorspace);
+      prevFrameContext_ = frameContext;
+    }
+    int resultHeight =
+        convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
+
+    frameOutput.data = outputTensor;
+  } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
+    if (!filterGraphContext_.filterGraph || prevFrameContext_ != frameContext) {
+      createFilterGraph(frameContext, videoStreamOptions);
+      prevFrameContext_ = frameContext;
+    }
+    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
+
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
+
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      frameOutput.data = preAllocatedOutputTensor.value();
+    } else {
+      frameOutput.data = outputTensor;
+    }
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(colorConversionLibrary)));
+  }
+}
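+
+// A minimal usage sketch (illustrative only; `cpuInterface`, `avFrame`, and
+// `videoStreamOptions` are assumed to exist in the caller):
+//
+//   FrameOutput frameOutput;
+//   auto preAllocated = allocateEmptyHWCTensor(height, width, torch::kCPU);
+//   cpuInterface.convertAVFrameToFrameOutput(
+//       videoStreamOptions, avFrame, frameOutput, preAllocated);
+//
+// With swscale, frameOutput.data aliases preAllocated (the frame is decoded
+// straight into it); with filtergraph, the result is copied into it.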
+
+int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
+    const UniqueAVFrame& avFrame,
+    torch::Tensor& outputTensor) {
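+  // The destination is a single packed RGB24 plane backed by the output
+  // tensor's memory: one data pointer and a row stride of width * 3 bytes.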
+  uint8_t* pointers[4] = {
+      outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
+  int expectedOutputWidth = outputTensor.sizes()[1];
+  int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
+  int resultHeight = sws_scale(
+      swsContext_.get(),
+      avFrame->data,
+      avFrame->linesize,
+      0,
+      avFrame->height,
+      pointers,
+      linesizes);
+  return resultHeight;
+}
+
+torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
+    const UniqueAVFrame& avFrame) {
+  int status = av_buffersrc_write_frame(
+      filterGraphContext_.sourceContext, avFrame.get());
+  if (status < AVSUCCESS) {
+    throw std::runtime_error("Failed to add frame to buffer source context");
+  }
+
+  UniqueAVFrame filteredAVFrame(av_frame_alloc());
+  status = av_buffersink_get_frame(
+      filterGraphContext_.sinkContext, filteredAVFrame.get());
+  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
+
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
+  int height = frameDims.height;
+  int width = frameDims.width;
+  std::vector<int64_t> shape = {height, width, 3};
+  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
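+  // strides[0] is the frame's linesize[0] rather than width * 3, since FFmpeg
+  // may pad each row. Ownership of the AVFrame is transferred to the tensor
+  // below: the deleter frees the frame only once the tensor is destroyed, so
+  // the tensor can view the frame's data without copying it.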
+  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
+  auto deleter = [filteredAVFramePtr](void*) {
+    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
+  };
+  return torch::from_blob(
+      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
+}
+
+void CpuDeviceInterface::createFilterGraph(
+    const DecodedFrameContext& frameContext,
+    const VideoStreamOptions& videoStreamOptions) {
+  filterGraphContext_.filterGraph.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraphContext_.filterGraph.get() != nullptr);
+
+  if (videoStreamOptions.ffmpegThreadCount.has_value()) {
+    filterGraphContext_.filterGraph->nb_threads =
+        videoStreamOptions.ffmpegThreadCount.value();
+  }
+
+  const AVFilter* buffersrc = avfilter_get_by_name("buffer");
+  const AVFilter* buffersink = avfilter_get_by_name("buffersink");
+
+  std::stringstream filterArgs;
+  filterArgs << "video_size=" << frameContext.decodedWidth << "x"
+             << frameContext.decodedHeight;
+  filterArgs << ":pix_fmt=" << frameContext.decodedFormat;
+  filterArgs << ":time_base=" << timeBase_.num << "/" << timeBase_.den;
+  filterArgs << ":pixel_aspect=" << frameContext.decodedAspectRatio.num << "/"
+             << frameContext.decodedAspectRatio.den;
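+  // e.g. for a 1280x720 yuv420p stream with a 1/25 time base and square
+  // pixels (yuv420p is pixel format 0), filterArgs is:
+  //   video_size=1280x720:pix_fmt=0:time_base=1/25:pixel_aspect=1/1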
+
+  int status = avfilter_graph_create_filter(
+      &filterGraphContext_.sourceContext,
+      buffersrc,
+      "in",
+      filterArgs.str().c_str(),
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        std::string("Failed to create filter graph: ") + filterArgs.str() +
+        ": " + getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  status = avfilter_graph_create_filter(
+      &filterGraphContext_.sinkContext,
+      buffersink,
+      "out",
+      nullptr,
+      nullptr,
+      filterGraphContext_.filterGraph.get());
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to create filter graph: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
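+  // Constrain the sink to emit packed RGB24; the option list is terminated
+  // by AV_PIX_FMT_NONE.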
+  enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
+
+  status = av_opt_set_int_list(
+      filterGraphContext_.sinkContext,
+      "pix_fmts",
+      pix_fmts,
+      AV_PIX_FMT_NONE,
+      AV_OPT_SEARCH_CHILDREN);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to set output pixel formats: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  UniqueAVFilterInOut outputs(avfilter_inout_alloc());
+  UniqueAVFilterInOut inputs(avfilter_inout_alloc());
+
+  outputs->name = av_strdup("in");
+  outputs->filter_ctx = filterGraphContext_.sourceContext;
+  outputs->pad_idx = 0;
+  outputs->next = nullptr;
+  inputs->name = av_strdup("out");
+  inputs->filter_ctx = filterGraphContext_.sinkContext;
+  inputs->pad_idx = 0;
+  inputs->next = nullptr;
+
+  std::stringstream description;
+  description << "scale=" << frameContext.expectedWidth << ":"
+              << frameContext.expectedHeight;
+  description << ":sws_flags=bilinear";
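+  // e.g. for a 640x480 output this yields "scale=640:480:sws_flags=bilinear".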
+
+  AVFilterInOut* outputsTmp = outputs.release();
+  AVFilterInOut* inputsTmp = inputs.release();
+  status = avfilter_graph_parse_ptr(
+      filterGraphContext_.filterGraph.get(),
+      description.str().c_str(),
+      &inputsTmp,
+      &outputsTmp,
+      nullptr);
+  outputs.reset(outputsTmp);
+  inputs.reset(inputsTmp);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to parse filter description: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+
+  status =
+      avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
+  if (status < 0) {
+    throw std::runtime_error(
+        "Failed to configure filter graph: " +
+        getFFMPEGErrorStringFromErrorCode(status));
+  }
+}
+
+void CpuDeviceInterface::createSwsContext(
+    const DecodedFrameContext& frameContext,
+    const enum AVColorSpace colorspace) {
+  SwsContext* swsContext = sws_getContext(
+      frameContext.decodedWidth,
+      frameContext.decodedHeight,
+      frameContext.decodedFormat,
+      frameContext.expectedWidth,
+      frameContext.expectedHeight,
+      AV_PIX_FMT_RGB24,
+      SWS_BILINEAR,
+      nullptr,
+      nullptr,
+      nullptr);
+  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
+
+  int* invTable = nullptr;
+  int* table = nullptr;
+  int srcRange, dstRange, brightness, contrast, saturation;
+  int ret = sws_getColorspaceDetails(
+      swsContext,
+      &invTable,
+      &srcRange,
+      &table,
+      &dstRange,
+      &brightness,
+      &contrast,
+      &saturation);
+  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
+
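+  // sws_getCoefficients() returns the YUV<->RGB coefficient table for the
+  // given colorspace (e.g. AVCOL_SPC_BT709). Passing it as both the inverse
+  // and forward table keeps the ranges and brightness/contrast/saturation
+  // queried above while honoring the frame's encoded colorspace.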
+  const int* colorspaceTable = sws_getCoefficients(colorspace);
+  ret = sws_setColorspaceDetails(
+      swsContext,
+      colorspaceTable,
+      srcRange,
+      colorspaceTable,
+      dstRange,
+      brightness,
+      contrast,
+      saturation);
+  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
+
+  swsContext_.reset(swsContext);
+}
+
} // namespace facebook::torchcodec