Commit 64f4595

Handle frame names
1 parent 32a0f8f commit 64f4595

File tree

2 files changed: +43 -43 lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 39 additions & 39 deletions
@@ -802,7 +802,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
     maybeDesiredPts_ = std::nullopt;
   }
   // Need to get the next frame or error from PopFrame.
-  UniqueAVFrame frame(av_frame_alloc());
+  UniqueAVFrame avFrame(av_frame_alloc());
   int ffmpegStatus = AVSUCCESS;
   bool reachedEOF = false;
   int frameStreamIndex = -1;
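
Note: UniqueAVFrame above is torchcodec's RAII alias for an owned AVFrame*; its definition is not part of this diff. A minimal sketch of what such an alias looks like, assuming a deleter that calls av_frame_free() (the alias name matches the diff, the exact definition lives elsewhere in the repo):

extern "C" {
#include <libavutil/frame.h>
}
#include <memory>

// Deleter that frees both the AVFrame struct and any buffers it references.
struct AVFrameDeleter {
  void operator()(AVFrame* avFrame) const {
    av_frame_free(&avFrame);
  }
};
using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;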
@@ -812,7 +812,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
     for (int streamIndex : activeStreamIndices_) {
       StreamInfo& streamInfo = streamInfos_[streamIndex];
       ffmpegStatus =
-          avcodec_receive_frame(streamInfo.codecContext.get(), frame.get());
+          avcodec_receive_frame(streamInfo.codecContext.get(), avFrame.get());
       bool gotNonRetriableError =
           ffmpegStatus != AVSUCCESS && ffmpegStatus != AVERROR(EAGAIN);
       if (gotNonRetriableError) {
@@ -829,7 +829,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
     }
     decodeStats_.numFramesReceivedByDecoder++;
     bool gotNeededFrame = ffmpegStatus == AVSUCCESS &&
-        filterFunction(frameStreamIndex, frame.get());
+        filterFunction(frameStreamIndex, avFrame.get());
     if (gotNeededFrame) {
       break;
     } else if (ffmpegStatus == AVSUCCESS) {
@@ -897,11 +897,11 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
   // av_receive_frame() or the user will have seeked to a different location in
   // the file and that will flush the decoder.
   StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-  activeStreamInfo.currentPts = frame->pts;
-  activeStreamInfo.currentDuration = getDuration(frame);
+  activeStreamInfo.currentPts = avFrame->pts;
+  activeStreamInfo.currentDuration = getDuration(avFrame);
   RawDecodedOutput rawOutput;
   rawOutput.streamIndex = frameStreamIndex;
-  rawOutput.frame = std::move(frame);
+  rawOutput.frame = std::move(avFrame);
   return rawOutput;
 }
 
@@ -911,14 +911,14 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   // Convert the frame to tensor.
   DecodedOutput output;
   int streamIndex = rawOutput.streamIndex;
-  AVFrame* frame = rawOutput.frame.get();
+  AVFrame* avFrame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streamInfos_[streamIndex];
   TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.ptsSeconds =
-      ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
+      ptsToSeconds(avFrame->pts, formatContext_->streams[streamIndex]->time_base);
   output.durationSeconds = ptsToSeconds(
-      getDuration(frame), formatContext_->streams[streamIndex]->time_base);
+      getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
   // TODO: we should fold preAllocatedOutputTensor into RawDecodedOutput.
   if (streamInfo.options.device.type() == torch::kCPU) {
     convertAVFrameToDecodedOutputOnCPU(
@@ -951,11 +951,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
-  AVFrame* frame = rawOutput.frame.get();
+  AVFrame* avFrame = rawOutput.frame.get();
   auto& streamInfo = streamInfos_[streamIndex];
 
   auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
+      getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *avFrame);
   int expectedOutputHeight = frameDims.height;
   int expectedOutputWidth = frameDims.width;
 
@@ -981,10 +981,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   // resolution to change mid-stream. Finally, we want to reuse the colorspace
   // conversion objects as much as possible for performance reasons.
   enum AVPixelFormat frameFormat =
-      static_cast<enum AVPixelFormat>(frame->format);
+      static_cast<enum AVPixelFormat>(avFrame->format);
   auto frameContext = DecodedFrameContext{
-      frame->width,
-      frame->height,
+      avFrame->width,
+      avFrame->height,
       frameFormat,
       expectedOutputWidth,
       expectedOutputHeight};
@@ -994,11 +994,11 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
         expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
   if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
-    createSwsContext(streamInfo, frameContext, frame->colorspace);
+    createSwsContext(streamInfo, frameContext, avFrame->colorspace);
     streamInfo.prevFrameContext = frameContext;
   }
   int resultHeight =
-      convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+      convertAVFrameToTensorUsingSwsScale(streamIndex, avFrame, outputTensor);
   // If this check failed, it would mean that the frame wasn't reshaped to
   // the expected height.
   // TODO: Can we do the same check for width?
@@ -1018,7 +1018,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
       createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
       streamInfo.prevFrameContext = frameContext;
     }
-    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+    outputTensor = convertAVFrameToTensorUsingFilterGraph(streamIndex, avFrame);
 
     // Similarly to above, if this check fails it means the frame wasn't
     // reshaped to its expected dimensions by filtergraph.
@@ -1064,11 +1064,11 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
 
   setCursorPtsInSeconds(seconds);
   RawDecodedOutput rawOutput = getDecodedOutputWithFilter(
-      [seconds, this](int frameStreamIndex, AVFrame* frame) {
+      [seconds, this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& streamInfo = streamInfos_[frameStreamIndex];
-        double frameStartTime = ptsToSeconds(frame->pts, streamInfo.timeBase);
+        double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
         double frameEndTime =
-            ptsToSeconds(frame->pts + getDuration(frame), streamInfo.timeBase);
+            ptsToSeconds(avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
         if (frameStartTime > seconds) {
           // FFMPEG seeked past the frame we are looking for even though we
           // set max_ts to be our needed timestamp in avformat_seek_file()
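
Note: the ptsToSeconds() calls in this hunk convert a pts expressed in stream time-base ticks into seconds. A minimal sketch of that conversion, assuming the helper simply scales by the stream's AVRational time base (torchcodec's actual helper may differ in details):

extern "C" {
#include <libavutil/rational.h>
}
#include <cstdint>

// A time base of {1, 90000} means one pts tick is 1/90000 of a second,
// so pts = 90000 corresponds to 1.0 s.
double ptsToSeconds(int64_t pts, AVRational timeBase) {
  return static_cast<double>(pts) * timeBase.num / timeBase.den;
}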
@@ -1443,9 +1443,9 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
 
 VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
   auto rawOutput =
-      getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* frame) {
+      getDecodedOutputWithFilter([this](int frameStreamIndex, AVFrame* avFrame) {
         StreamInfo& activeStreamInfo = streamInfos_[frameStreamIndex];
-        return frame->pts >= activeStreamInfo.discardFramesBeforePts;
+        return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
       });
   return rawOutput;
 }
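
Note: both call sites above pass a predicate into getDecodedOutputWithFilter(), and the decode loop runs until the predicate accepts a frame. A standalone sketch of a predicate shaped like the lambda in getNextRawDecodedOutputNoDemux(), with a hypothetical hard-coded threshold standing in for the StreamInfo lookup:

extern "C" {
#include <libavutil/frame.h>
}
#include <cstdint>
#include <functional>

using FrameFilter = std::function<bool(int streamIndex, AVFrame* avFrame)>;

int main() {
  const int64_t discardFramesBeforePts = 9000;  // hypothetical threshold
  // Accept the first frame whose pts is at or past the threshold.
  FrameFilter nextFrameFilter =
      [discardFramesBeforePts](int, AVFrame* avFrame) {
        return avFrame->pts >= discardFramesBeforePts;
      };

  AVFrame* avFrame = av_frame_alloc();
  avFrame->pts = 9001;
  const bool accepted = nextFrameFilter(0, avFrame);  // true: 9001 >= 9000
  av_frame_free(&avFrame);
  return accepted ? 0 : 1;
}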
@@ -1534,9 +1534,9 @@ void VideoDecoder::createSwsContext(
   streamInfo.swsContext.reset(swsContext);
 }
 
-int VideoDecoder::convertFrameToTensorUsingSwsScale(
+int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
     int streamIndex,
-    const AVFrame* frame,
+    const AVFrame* avFrame,
     torch::Tensor& outputTensor) {
   StreamInfo& activeStreamInfo = streamInfos_[streamIndex];
   SwsContext* swsContext = activeStreamInfo.swsContext.get();
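
Note: the next hunk renames the sws_scale path of this helper. For reference, a reduced sketch of the call pattern it uses, assuming a SwsContext already configured to convert the source format to AV_PIX_FMT_RGB24 at the expected output size (the helper name and signature here are illustrative, not torchcodec's):

extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}
#include <cstdint>

// Returns the number of output rows written; the caller compares this
// against the expected height as a sanity check, as the code above does.
int scaleToRGB24(SwsContext* swsContext, const AVFrame* avFrame,
                 uint8_t* dst, int dstWidth) {
  uint8_t* pointers[4] = {dst, nullptr, nullptr, nullptr};
  int linesizes[4] = {dstWidth * 3, 0, 0, 0};  // 3 bytes per RGB24 pixel
  return sws_scale(
      swsContext,
      avFrame->data,
      avFrame->linesize,
      0,                // start at the first source row
      avFrame->height,  // process the full frame
      pointers,
      linesizes);
}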
@@ -1546,40 +1546,40 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
   int linesizes[4] = {expectedOutputWidth * 3, 0, 0, 0};
   int resultHeight = sws_scale(
       swsContext,
-      frame->data,
-      frame->linesize,
+      avFrame->data,
+      avFrame->linesize,
       0,
-      frame->height,
+      avFrame->height,
       pointers,
       linesizes);
   return resultHeight;
 }
 
-torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
+torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
     int streamIndex,
-    const AVFrame* frame) {
+    const AVFrame* avFrame) {
   FilterState& filterState = streamInfos_[streamIndex].filterState;
-  int ffmpegStatus = av_buffersrc_write_frame(filterState.sourceContext, frame);
+  int ffmpegStatus = av_buffersrc_write_frame(filterState.sourceContext, avFrame);
   if (ffmpegStatus < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
 
-  UniqueAVFrame filteredFrame(av_frame_alloc());
+  UniqueAVFrame filteredAVFrame(av_frame_alloc());
   ffmpegStatus =
-      av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
-  TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
+      av_buffersink_get_frame(filterState.sinkContext, filteredAVFrame.get());
+  TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
 
-  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredFrame.get());
+  auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
   int height = frameDims.height;
   int width = frameDims.width;
   std::vector<int64_t> shape = {height, width, 3};
-  std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
-  AVFrame* filteredFramePtr = filteredFrame.release();
-  auto deleter = [filteredFramePtr](void*) {
-    UniqueAVFrame frameToDelete(filteredFramePtr);
+  std::vector<int64_t> strides = {filteredAVFrame->linesize[0], 3, 1};
+  AVFrame* filteredAVFramePtr = filteredAVFrame.release();
+  auto deleter = [filteredAVFramePtr](void*) {
+    UniqueAVFrame avFrameToDelete(filteredAVFramePtr);
   };
   return torch::from_blob(
-      filteredFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
+      filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
 }
 
 VideoDecoder::~VideoDecoder() {
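
Note on the end of convertAVFrameToTensorUsingFilterGraph() above: rather than copying the filtered frame, it wraps the frame's RGB24 plane in a tensor and transfers ownership of the AVFrame into the tensor's deleter. A self-contained sketch of that handoff, assuming the UniqueAVFrame definition sketched earlier (the function name here is illustrative):

extern "C" {
#include <libavutil/frame.h>
}
#include <torch/torch.h>
#include <memory>
#include <vector>

struct AVFrameDeleter {
  void operator()(AVFrame* avFrame) const { av_frame_free(&avFrame); }
};
using UniqueAVFrame = std::unique_ptr<AVFrame, AVFrameDeleter>;

// Wrap an RGB24 AVFrame's pixel plane in a tensor without copying. The
// AVFrame must outlive the tensor, so ownership moves into the deleter:
// when the tensor's storage is released, the lambda re-wraps the raw
// pointer in a UniqueAVFrame, which frees it on scope exit.
torch::Tensor rgb24FrameToTensor(UniqueAVFrame avFrame) {
  std::vector<int64_t> shape = {avFrame->height, avFrame->width, 3};
  // linesize[0] is the row stride in bytes; it may exceed width * 3
  // when rows are padded.
  std::vector<int64_t> strides = {avFrame->linesize[0], 3, 1};
  AVFrame* avFramePtr = avFrame.release();
  auto deleter = [avFramePtr](void*) {
    UniqueAVFrame avFrameToDelete(avFramePtr);
  };
  return torch::from_blob(
      avFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
}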

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 4 additions & 4 deletions
@@ -393,12 +393,12 @@ class VideoDecoder {
       int streamIndex,
       AVCodecContext* codecContext);
   void populateVideoMetadataFromStreamIndex(int streamIndex);
-  torch::Tensor convertFrameToTensorUsingFilterGraph(
+  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
       int streamIndex,
-      const AVFrame* frame);
-  int convertFrameToTensorUsingSwsScale(
+      const AVFrame* avFrame);
+  int convertAVFrameToTensorUsingSwsScale(
       int streamIndex,
-      const AVFrame* frame,
+      const AVFrame* avFrame,
       torch::Tensor& outputTensor);
   DecodedOutput convertAVFrameToDecodedOutput(
       RawDecodedOutput& rawOutput,
