@@ -204,15 +204,16 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
   frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
 }
 
-bool VideoDecoder::SwsContextKey::operator==(
-    const VideoDecoder::SwsContextKey& other) {
+bool VideoDecoder::DecodedFrameContext::operator==(
+    const VideoDecoder::DecodedFrameContext& other) {
   return decodedWidth == other.decodedWidth && decodedHeight == other.decodedHeight &&
       decodedFormat == other.decodedFormat &&
-      outputWidth == other.outputWidth && outputHeight == other.outputHeight;
+      expectedWidth == other.expectedWidth &&
+      expectedHeight == other.expectedHeight;
 }
 
-bool VideoDecoder::SwsContextKey::operator!=(
-    const VideoDecoder::SwsContextKey& other) {
+bool VideoDecoder::DecodedFrameContext::operator!=(
+    const VideoDecoder::DecodedFrameContext& other) {
   return !(*this == other);
 }
 
@@ -313,17 +314,14 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
   return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length));
 }
 
-void VideoDecoder::initializeFilterGraph(
+void VideoDecoder::createFilterGraph(
     StreamInfo& streamInfo,
     int expectedOutputHeight,
     int expectedOutputWidth) {
   FilterState& filterState = streamInfo.filterState;
-  if (filterState.filterGraph) {
-    return;
-  }
-
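+  // Unconditionally (re)create the graph; callers now decide when a rebuild
+  // is needed by comparing DecodedFrameContext values.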
   filterState.filterGraph.reset(avfilter_graph_alloc());
   TORCH_CHECK(filterState.filterGraph.get() != nullptr);
+
   if (streamInfo.options.ffmpegThreadCount.has_value()) {
     filterState.filterGraph->nb_threads =
         streamInfo.options.ffmpegThreadCount.value();
@@ -921,12 +919,32 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
 
   torch::Tensor outputTensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
+    // We need to compare the current frame context with our previous frame
+    // context. If they are different, then we need to re-create our colorspace
+    // conversion objects. We create our colorspace conversion objects late so
+    // that we don't have to depend on the unreliable metadata in the header.
+    // And we sometimes re-create them because it's possible for frame
+    // resolution to change mid-stream. Finally, we want to reuse the colorspace
+    // conversion objects as much as possible for performance reasons.
+    enum AVPixelFormat frameFormat =
+        static_cast<enum AVPixelFormat>(frame->format);
+    auto frameContext = DecodedFrameContext{
+        frame->width,
+        frame->height,
+        frameFormat,
+        expectedOutputWidth,
+        expectedOutputHeight};
+
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
       outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
           expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
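+      // The cached sws context is only valid for the exact decoded and
+      // requested frame properties it was created with, so rebuild it when
+      // the frame context changes.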
+      if (!streamInfo.swsContext || streamInfo.prevFrame != frameContext) {
+        createSwsContext(streamInfo, frameContext, frame->colorspace);
+        streamInfo.prevFrame = frameContext;
+      }
       int resultHeight =
-          convertFrameToBufferUsingSwsScale(streamIndex, frame, outputTensor);
+          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
       // If this check failed, it would mean that the frame wasn't reshaped to
       // the expected height.
       // TODO: Can we do the same check for width?
@@ -941,16 +959,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      // Note that is a lazy init; we initialize filtergraph the first time
-      // we have a raw decoded frame. We do this lazily because up until this
-      // point, we really don't know what the resolution of the frames are
-      // without modification. In theory, we should be able to get that from the
-      // stream metadata, but in practice, we have encountered videos where the
-      // stream metadata had a different resolution from the actual resolution
-      // of the raw decoded frames.
-      if (!streamInfo.filterState.filterGraph) {
-        initializeFilterGraph(
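+      // Same caching policy as the swscale path above: rebuild the filter
+      // graph only when the decoded or requested frame properties change.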
+      if (!streamInfo.filterState.filterGraph ||
+          streamInfo.prevFrame != frameContext) {
+        createFilterGraph(
             streamInfo, expectedOutputHeight, expectedOutputWidth);
+        streamInfo.prevFrame = frameContext;
       }
       outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
@@ -1351,7 +1364,53 @@ double VideoDecoder::getPtsSecondsForFrame(
   return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase);
 }
 
-int VideoDecoder::convertFrameToBufferUsingSwsScale(
+void VideoDecoder::createSwsContext(
+    StreamInfo& streamInfo,
+    const DecodedFrameContext& frameContext,
+    const enum AVColorSpace colorspace) {
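+  // Scale from the decoded frame's geometry and pixel format to the
+  // requested output size, converting to packed RGB24 with bilinear
+  // filtering.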
+  SwsContext* swsContext = sws_getContext(
+      frameContext.decodedWidth,
+      frameContext.decodedHeight,
+      frameContext.decodedFormat,
+      frameContext.expectedWidth,
+      frameContext.expectedHeight,
+      AV_PIX_FMT_RGB24,
+      SWS_BILINEAR,
+      nullptr,
+      nullptr,
+      nullptr);
+  TORCH_CHECK(swsContext, "sws_getContext() returned nullptr");
+
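+  // Read back the context's current colorspace details so that the range,
+  // brightness, contrast, and saturation values can be preserved when we
+  // swap in the coefficients matching the frame's colorspace below.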
+  int* invTable = nullptr;
+  int* table = nullptr;
+  int srcRange, dstRange, brightness, contrast, saturation;
+  int ret = sws_getColorspaceDetails(
+      swsContext,
+      &invTable,
+      &srcRange,
+      &table,
+      &dstRange,
+      &brightness,
+      &contrast,
+      &saturation);
+  TORCH_CHECK(ret != -1, "sws_getColorspaceDetails returned -1");
+
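+  // sws_getCoefficients() maps an AVColorSpace value (e.g. BT.601, BT.709)
+  // to the coefficient table that swscale expects.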
+  const int* colorspaceTable = sws_getCoefficients(colorspace);
+  ret = sws_setColorspaceDetails(
+      swsContext,
+      colorspaceTable,
+      srcRange,
+      colorspaceTable,
+      dstRange,
+      brightness,
+      contrast,
+      saturation);
+  TORCH_CHECK(ret != -1, "sws_setColorspaceDetails returned -1");
+
+  streamInfo.swsContext.reset(swsContext);
+}
+
+int VideoDecoder::convertFrameToTensorUsingSwsScale(
     int streamIndex,
     const AVFrame* frame,
     torch::Tensor& outputTensor) {
@@ -1361,50 +1420,6 @@ int VideoDecoder::convertFrameToBufferUsingSwsScale(
 
   int expectedOutputHeight = outputTensor.sizes()[0];
   int expectedOutputWidth = outputTensor.sizes()[1];
-  auto curFrameSwsContextKey = SwsContextKey{
-      frame->width,
-      frame->height,
-      frameFormat,
-      expectedOutputWidth,
-      expectedOutputHeight};
-  if (activeStream.swsContext.get() == nullptr ||
-      activeStream.swsContextKey != curFrameSwsContextKey) {
-    SwsContext* swsContext = sws_getContext(
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight,
-        AV_PIX_FMT_RGB24,
-        SWS_BILINEAR,
-        nullptr,
-        nullptr,
-        nullptr);
-    int* invTable = nullptr;
-    int* table = nullptr;
-    int srcRange, dstRange, brightness, contrast, saturation;
-    sws_getColorspaceDetails(
-        swsContext,
-        &invTable,
-        &srcRange,
-        &table,
-        &dstRange,
-        &brightness,
-        &contrast,
-        &saturation);
-    const int* colorspaceTable = sws_getCoefficients(frame->colorspace);
-    sws_setColorspaceDetails(
-        swsContext,
-        colorspaceTable,
-        srcRange,
-        colorspaceTable,
-        dstRange,
-        brightness,
-        contrast,
-        saturation);
-    activeStream.swsContextKey = curFrameSwsContextKey;
-    activeStream.swsContext.reset(swsContext);
-  }
   SwsContext* swsContext = activeStream.swsContext.get();
   uint8_t* pointers[4] = {
       outputTensor.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
@@ -1428,10 +1443,12 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   if (ffmpegStatus < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
+
   UniqueAVFrame filteredFrame(av_frame_alloc());
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
+
   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
@@ -1441,9 +1458,8 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   auto deleter = [filteredFramePtr](void*) {
     UniqueAVFrame frameToDelete(filteredFramePtr);
   };
-  torch::Tensor tensor = torch::from_blob(
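+  // from_blob() wraps the frame's pixel data without copying; the deleter
+  // above frees the AVFrame only once the returned tensor is destroyed.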
+  return torch::from_blob(
       filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
-  return tensor;
 }
 
 VideoDecoder::~VideoDecoder() {