Skip to content

Commit ee5e41b

Browse files
committed
WIPWIPWIPWIWPWIP
1 parent 22e88d6 commit ee5e41b

File tree

5 files changed

+128
-3
lines changed

5 files changed

+128
-3
lines changed

src/torchcodec/_core/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ function(make_torchcodec_libraries
9696
Encoder.cpp
9797
ValidationUtils.cpp
9898
Transform.cpp
99+
SwsContext.cpp
99100
)
100101

101102
if(ENABLE_CUDA)

src/torchcodec/_core/CUDACommon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ UniqueAVFrame transferCpuFrameToGpuNV12(
367367
cpuFrame, outputDims, cpuFrame->colorspace, AV_PIX_FMT_NV12, SWS_BILINEAR);
368368

369369
int convertedHeight = sws_scale(
370-
swsContext.get(),
370+
swsContext,
371371
cpuFrame->data,
372372
cpuFrame->linesize,
373373
0,

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
215215
const UniqueAVFrame& avFrame,
216216
torch::Tensor& outputTensor,
217217
const FrameDims& outputDims) {
218-
// Get or create swscale context. The SwsContext class manages caching
218+
// Get or create swscale context. The SwsScaler class manages caching
219219
// and recreation logic internally based on frame properties.
220220
auto swsContext = swsCtx_.getOrCreateContext(
221221
avFrame, outputDims, avFrame->colorspace, AV_PIX_FMT_RGB24, swsFlags_);
@@ -225,7 +225,7 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
225225
int expectedOutputWidth = outputTensor.sizes()[1];
226226
int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
227227
int resultHeight = sws_scale(
228-
swsContext.get(),
228+
swsContext,
229229
avFrame->data,
230230
avFrame->linesize,
231231
0,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "src/torchcodec/_core/SwsContext.h"
8+
#include "src/torchcodec/_core/FFMPEGCommon.h"
9+
10+
extern "C" {
11+
#include <libswscale/swscale.h>
12+
}
13+
14+
namespace facebook::torchcodec {
15+
16+
SwsFrameContext::SwsFrameContext(
17+
int inputWidth,
18+
int inputHeight,
19+
AVPixelFormat inputFormat,
20+
int outputWidth,
21+
int outputHeight)
22+
: inputWidth(inputWidth),
23+
inputHeight(inputHeight),
24+
inputFormat(inputFormat),
25+
outputWidth(outputWidth),
26+
outputHeight(outputHeight) {}
27+
28+
bool SwsFrameContext::operator==(const SwsFrameContext& other) const {
29+
return inputWidth == other.inputWidth && inputHeight == other.inputHeight &&
30+
inputFormat == other.inputFormat && outputWidth == other.outputWidth &&
31+
outputHeight == other.outputHeight;
32+
}
33+
34+
bool SwsFrameContext::operator!=(const SwsFrameContext& other) const {
35+
return !(*this == other);
36+
}
37+
38+
SwsContext* SwsScaler::getOrCreateContext(
39+
const UniqueAVFrame& avFrame,
40+
const FrameDims& outputDims,
41+
AVColorSpace colorspace,
42+
AVPixelFormat outputFormat,
43+
int swsFlags) {
44+
enum AVPixelFormat frameFormat =
45+
static_cast<enum AVPixelFormat>(avFrame->format);
46+
47+
SwsFrameContext currentFrameContext(
48+
avFrame->width,
49+
avFrame->height,
50+
frameFormat,
51+
outputDims.width,
52+
outputDims.height);
53+
54+
// Recreate swscale context only if frame properties changed
55+
if (!swsContext_ || prevFrameContext_ != currentFrameContext) {
56+
swsContext_ = createSwsContext(
57+
currentFrameContext, colorspace, outputFormat, swsFlags);
58+
prevFrameContext_ = currentFrameContext;
59+
}
60+
61+
return swsContext_.get();
62+
}
63+
64+
} // namespace facebook::torchcodec

src/torchcodec/_core/SwsContext.h

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
extern "C" {
10+
#include <libswscale/swscale.h>
11+
}
12+
13+
#include "src/torchcodec/_core/Frame.h"
14+
15+
namespace facebook::torchcodec {
16+
17+
// Context describing frame properties needed for swscale conversion.
18+
// Used to detect when swscale context needs to be recreated.
19+
struct SwsFrameContext {
20+
int inputWidth;
21+
int inputHeight;
22+
AVPixelFormat inputFormat;
23+
int outputWidth;
24+
int outputHeight;
25+
26+
SwsFrameContext(
27+
int inputWidth,
28+
int inputHeight,
29+
AVPixelFormat inputFormat,
30+
int outputWidth,
31+
int outputHeight);
32+
33+
bool operator==(const SwsFrameContext& other) const;
34+
bool operator!=(const SwsFrameContext& other) const;
35+
};
36+
37+
// Manages swscale context creation and caching across multiple frame conversions.
38+
// Reuses the context as long as frame properties remain the same.
39+
class SwsScaler {
40+
public:
41+
SwsScaler() = default;
42+
~SwsScaler() = default;
43+
44+
// Get or create a swscale context for the given frame and output dimensions.
45+
// Reuses cached context if frame properties haven't changed.
46+
// Returns a raw pointer to the internal swscale context. The pointer is valid
47+
// as long as this SwsScaler object is alive.
48+
SwsContext* getOrCreateContext(
49+
const UniqueAVFrame& avFrame,
50+
const FrameDims& outputDims,
51+
AVColorSpace colorspace,
52+
AVPixelFormat outputFormat,
53+
int swsFlags = SWS_BILINEAR);
54+
55+
private:
56+
UniqueSwsContext swsContext_;
57+
SwsFrameContext prevFrameContext_ = SwsFrameContext(0, 0, AV_PIX_FMT_NONE, 0, 0);
58+
};
59+
60+
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)