Skip to content

Commit 04af57f

Browse files
authored
Generalize logic to recreate colorspace conversion objects (#436)
1 parent 9653969 commit 04af57f

File tree

3 files changed

+96
-76
lines changed

3 files changed

+96
-76
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 83 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,16 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
204204
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
205205
}
206206

207-
bool VideoDecoder::SwsContextKey::operator==(
208-
const VideoDecoder::SwsContextKey& other) {
207+
bool VideoDecoder::DecodedFrameContext::operator==(
208+
const VideoDecoder::DecodedFrameContext& other) {
209209
return decodedWidth == other.decodedWidth && decodedHeight == other.decodedHeight &&
210210
decodedFormat == other.decodedFormat &&
211-
outputWidth == other.outputWidth && outputHeight == other.outputHeight;
211+
expectedWidth == other.expectedWidth &&
212+
expectedHeight == other.expectedHeight;
212213
}
213214

214-
bool VideoDecoder::SwsContextKey::operator!=(
215-
const VideoDecoder::SwsContextKey& other) {
215+
bool VideoDecoder::DecodedFrameContext::operator!=(
216+
const VideoDecoder::DecodedFrameContext& other) {
216217
return !(*this == other);
217218
}
218219

@@ -313,17 +314,14 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
313314
return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length));
314315
}
315316

316-
void VideoDecoder::initializeFilterGraph(
317+
void VideoDecoder::createFilterGraph(
317318
StreamInfo& streamInfo,
318319
int expectedOutputHeight,
319320
int expectedOutputWidth) {
320321
FilterState& filterState = streamInfo.filterState;
321-
if (filterState.filterGraph) {
322-
return;
323-
}
324-
325322
filterState.filterGraph.reset(avfilter_graph_alloc());
326323
TORCH_CHECK(filterState.filterGraph.get() != nullptr);
324+
327325
if (streamInfo.options.ffmpegThreadCount.has_value()) {
328326
filterState.filterGraph->nb_threads =
329327
streamInfo.options.ffmpegThreadCount.value();
@@ -921,12 +919,32 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
921919

922920
torch::Tensor outputTensor;
923921
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
922+
// We need to compare the current frame context with our previous frame
923+
// context. If they are different, then we need to re-create our colorspace
924+
// conversion objects. We create our colorspace conversion objects late so
925+
// that we don't have to depend on the unreliable metadata in the header.
926+
// And we sometimes re-create them because it's possible for frame
927+
// resolution to change mid-stream. Finally, we want to reuse the colorspace
928+
// conversion objects as much as possible for performance reasons.
929+
enum AVPixelFormat frameFormat =
930+
static_cast<enum AVPixelFormat>(frame->format);
931+
auto frameContext = DecodedFrameContext{
932+
frame->width,
933+
frame->height,
934+
frameFormat,
935+
expectedOutputWidth,
936+
expectedOutputHeight};
937+
924938
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
925939
outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
926940
expectedOutputHeight, expectedOutputWidth, torch::kCPU));
927941

942+
if (!streamInfo.swsContext || streamInfo.prevFrame != frameContext) {
943+
createSwsContext(streamInfo, frameContext, frame->colorspace);
944+
streamInfo.prevFrame = frameContext;
945+
}
928946
int resultHeight =
929-
convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
947+
convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
930948
// If this check failed, it would mean that the frame wasn't reshaped to
931949
// the expected height.
932950
// TODO: Can we do the same check for width?
@@ -941,16 +959,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
941959
} else if (
942960
streamInfo.colorConversionLibrary ==
943961
ColorConversionLibrary::FILTERGRAPH) {
944-
// Note that is a lazy init; we initialize filtergraph the first time
945-
// we have a raw decoded frame. We do this lazily because up until this
946-
// point, we really don't know what the resolution of the frames are
947-
// without modification. In theory, we should be able to get that from the
948-
// stream metadata, but in practice, we have encountered videos where the
949-
// stream metadata had a different resolution from the actual resolution
950-
// of the raw decoded frames.
951-
if (!streamInfo.filterState.filterGraph) {
952-
initializeFilterGraph(
962+
if (!streamInfo.filterState.filterGraph ||
963+
streamInfo.prevFrame != frameContext) {
964+
createFilterGraph(
953965
streamInfo, expectedOutputHeight, expectedOutputWidth);
966+
streamInfo.prevFrame = frameContext;
954967
}
955968
outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
956969

@@ -1351,7 +1364,53 @@ double VideoDecoder::getPtsSecondsForFrame(
13511364
return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
13521365
}
13531366

1354-
int VideoDecoder::convertFrameToBufferUsingSwsScale(
1367+
void VideoDecoder::createSwsContext(
1368+
StreamInfo& streamInfo,
1369+
const DecodedFrameContext& frameContext,
1370+
const enum AVColorSpace colorspace) {
1371+
SwsContext* swsContext = sws_getContext(
1372+
frameContext.decodedWidth,
1373+
frameContext.decodedHeight,
1374+
frameContext.decodedFormat,
1375+
frameContext.expectedWidth,
1376+
frameContext.expectedHeight,
1377+
AV_PIX_FMT_RGB24,
1378+
SWS_BILINEAR,
1379+
nullptr,
1380+
nullptr,
1381+
nullptr);
1382+
TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
1383+
1384+
int* invTable = nullptr;
1385+
int* table = nullptr;
1386+
int srcRange, dstRange, brightness, contrast, saturation;
1387+
int ret = sws_getColorspaceDetails(
1388+
swsContext,
1389+
&invTable,
1390+
&srcRange,
1391+
&table,
1392+
&dstRange,
1393+
&brightness,
1394+
&contrast,
1395+
&saturation);
1396+
TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
1397+
1398+
const int* colorspaceTable = sws_getCoefficients(colorspace);
1399+
ret = sws_setColorspaceDetails(
1400+
swsContext,
1401+
colorspaceTable,
1402+
srcRange,
1403+
colorspaceTable,
1404+
dstRange,
1405+
brightness,
1406+
contrast,
1407+
saturation);
1408+
TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
1409+
1410+
streamInfo.swsContext.reset(swsContext);
1411+
}
1412+
1413+
int VideoDecoder::convertFrameToTensorUsingSwsScale(
13551414
int streamIndex,
13561415
const AVFrame* frame,
13571416
torch::Tensor& outputTensor) {
@@ -1361,50 +1420,6 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(
13611420

13621421
int expectedOutputHeight = outputTensor.sizes()[0];
13631422
int expectedOutputWidth = outputTensor.sizes()[1];
1364-
auto curFrameSwsContextKey = SwsContextKey{
1365-
frame->width,
1366-
frame->height,
1367-
frameFormat,
1368-
expectedOutputWidth,
1369-
expectedOutputHeight};
1370-
if (activeStream.swsContext.get() == nullptr ||
1371-
activeStream.swsContextKey != curFrameSwsContextKey) {
1372-
SwsContext* swsContext = sws_getContext(
1373-
frame->width,
1374-
frame->height,
1375-
frameFormat,
1376-
expectedOutputWidth,
1377-
expectedOutputHeight,
1378-
AV_PIX_FMT_RGB24,
1379-
SWS_BILINEAR,
1380-
nullptr,
1381-
nullptr,
1382-
nullptr);
1383-
int* invTable = nullptr;
1384-
int* table = nullptr;
1385-
int srcRange, dstRange, brightness, contrast, saturation;
1386-
sws_getColorspaceDetails(
1387-
swsContext,
1388-
&invTable,
1389-
&srcRange,
1390-
&table,
1391-
&dstRange,
1392-
&brightness,
1393-
&contrast,
1394-
&saturation);
1395-
const int* colorspaceTable = sws_getCoefficients(frame->colorspace);
1396-
sws_setColorspaceDetails(
1397-
swsContext,
1398-
colorspaceTable,
1399-
srcRange,
1400-
colorspaceTable,
1401-
dstRange,
1402-
brightness,
1403-
contrast,
1404-
saturation);
1405-
activeStream.swsContextKey = curFrameSwsContextKey;
1406-
activeStream.swsContext.reset(swsContext);
1407-
}
14081423
SwsContext* swsContext = activeStream.swsContext.get();
14091424
uint8_t* pointers[4] = {
14101425
outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
@@ -1428,10 +1443,12 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
14281443
if (ffmpegStatus < AVSUCCESS) {
14291444
throw std::runtime_error("Failed to add frame to buffer source context");
14301445
}
1446+
14311447
UniqueAVFrame filteredFrame(av_frame_alloc());
14321448
ffmpegStatus =
14331449
av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
14341450
TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
1451+
14351452
auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
14361453
int height = frameDims.height;
14371454
int width = frameDims.width;
@@ -1441,9 +1458,8 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
14411458
auto deleter = [filteredFramePtr](void*) {
14421459
UniqueAVFrame frameToDelete(filteredFramePtr);
14431460
};
1444-
torch::Tensor tensor = torch::from_blob(
1461+
return torch::from_blob(
14451462
filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
1446-
return tensor;
14471463
}
14481464

14491465
VideoDecoder::~VideoDecoder() {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,14 @@ class VideoDecoder {
313313
AVFilterContext* sourceContext = nullptr;
314314
AVFilterContext* sinkContext = nullptr;
315315
};
316-
struct SwsContextKey {
316+
struct DecodedFrameContext {
317317
int decodedWidth;
318318
int decodedHeight;
319319
AVPixelFormat decodedFormat;
320-
int outputWidth;
321-
int outputHeight;
322-
bool operator==(const SwsContextKey&);
323-
bool operator!=(const SwsContextKey&);
320+
int expectedWidth;
321+
int expectedHeight;
322+
bool operator==(const DecodedFrameContext&);
323+
bool operator!=(const DecodedFrameContext&);
324324
};
325325
// Stores information for each stream.
326326
struct StreamInfo {
@@ -342,7 +342,7 @@ class VideoDecoder {
342342
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
343343
std::vector<FrameInfo> keyFrames;
344344
std::vector<FrameInfo> allFrames;
345-
SwsContextKey swsContextKey;
345+
DecodedFrameContext prevFrame;
346346
UniqueSwsContext swsContext;
347347
};
348348
// Returns the key frame index of the presentation timestamp using FFMPEG's
@@ -371,10 +371,14 @@ class VideoDecoder {
371371
void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex);
372372
// Creates and initializes a filter graph for a stream. The filter graph can
373373
// do rescaling and color conversion.
374-
void initializeFilterGraph(
374+
void createFilterGraph(
375375
StreamInfo& streamInfo,
376376
int expectedOutputHeight,
377377
int expectedOutputWidth);
378+
void createSwsContext(
379+
StreamInfo& streamInfo,
380+
const DecodedFrameContext& frameContext,
381+
const enum AVColorSpace colorspace);
378382
void maybeSeekToBeforeDesiredPts();
379383
RawDecodedOutput getDecodedOutputWithFilter(
380384
std::function<bool(int, AVFrame*)>);
@@ -389,7 +393,7 @@ class VideoDecoder {
389393
torch::Tensor convertFrameToTensorUsingFilterGraph(
390394
int streamIndex,
391395
const AVFrame* frame);
392-
int convertFrameToBufferUsingSwsScale(
396+
int convertFrameToTensorUsingSwsScale(
393397
int streamIndex,
394398
const AVFrame* frame,
395399
torch::Tensor& outputTensor);

version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.0.4a0
1+
0.1.2a0

0 commit comments

Comments
 (0)