Skip to content

Commit 3141c18

Browse files
committed
Remove AVFrameStream struct
1 parent 5713507 commit 3141c18

File tree

5 files changed

+52
-73
lines changed

5 files changed

+52
-73
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
1717
void convertAVFrameToFrameOutputOnCuda(
1818
const torch::Device& device,
1919
[[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
20-
[[maybe_unused]] VideoDecoder::AVFrameStream& avFrameStream,
20+
[[maybe_unused]] UniqueAVFrame& avFrame,
2121
[[maybe_unused]] VideoDecoder::FrameOutput& frameOutput,
2222
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
2323
throwUnsupportedDeviceError(device);

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,15 @@ void initializeContextOnCuda(
190190
void convertAVFrameToFrameOutputOnCuda(
191191
const torch::Device& device,
192192
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
193-
VideoDecoder::AVFrameStream& avFrameStream,
193+
UniqueAVFrame& avFrame,
194194
VideoDecoder::FrameOutput& frameOutput,
195195
std::optional<torch::Tensor> preAllocatedOutputTensor) {
196-
AVFrame* avFrame = avFrameStream.avFrame.get();
197-
198196
TORCH_CHECK(
199197
avFrame->format == AV_PIX_FMT_CUDA,
200198
"Expected format to be AV_PIX_FMT_CUDA, got " +
201199
std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
202-
auto frameDims =
203-
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
200+
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
201+
videoStreamOptions, *(avFrame.get()));
204202
int height = frameDims.height;
205203
int width = frameDims.width;
206204
torch::Tensor& dst = frameOutput.data;

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void initializeContextOnCuda(
3232
void convertAVFrameToFrameOutputOnCuda(
3333
const torch::Device& device,
3434
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
35-
VideoDecoder::AVFrameStream& avFrameStream,
35+
UniqueAVFrame& avFrame,
3636
VideoDecoder::FrameOutput& frameOutput,
3737
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
3838

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
583583
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
584584
std::optional<torch::Tensor> preAllocatedOutputTensor) {
585585
validateActiveStream();
586-
AVFrameStream avFrameStream = decodeAVFrame(
586+
UniqueAVFrame avFrame = decodeAVFrame(
587587
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
588-
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
588+
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
589589
}
590590

591591
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
@@ -715,27 +715,26 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
715715
}
716716

717717
setCursorPtsInSeconds(seconds);
718-
AVFrameStream avFrameStream =
719-
decodeAVFrame([seconds, this](AVFrame* avFrame) {
720-
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
721-
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
722-
double frameEndTime = ptsToSeconds(
723-
avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
724-
if (frameStartTime > seconds) {
725-
// FFMPEG seeked past the frame we are looking for even though we
726-
// set max_ts to be our needed timestamp in avformat_seek_file()
727-
// in maybeSeekToBeforeDesiredPts().
728-
// This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
729-
// In this case we return the very next frame instead of throwing an
730-
// exception.
731-
// TODO: Maybe log to stderr for Debug builds?
732-
return true;
733-
}
734-
return seconds >= frameStartTime && seconds < frameEndTime;
735-
});
718+
UniqueAVFrame avFrame = decodeAVFrame([seconds, this](AVFrame* avFrame) {
719+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
720+
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
721+
double frameEndTime =
722+
ptsToSeconds(avFrame->pts + getDuration(avFrame), streamInfo.timeBase);
723+
if (frameStartTime > seconds) {
724+
// FFMPEG seeked past the frame we are looking for even though we
725+
// set max_ts to be our needed timestamp in avformat_seek_file()
726+
// in maybeSeekToBeforeDesiredPts().
727+
// This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137
728+
// In this case we return the very next frame instead of throwing an
729+
// exception.
730+
// TODO: Maybe log to stderr for Debug builds?
731+
return true;
732+
}
733+
return seconds >= frameStartTime && seconds < frameEndTime;
734+
});
736735

737736
// Convert the frame to tensor.
738-
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream);
737+
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame);
739738
frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
740739
return frameOutput;
741740
}
@@ -891,14 +890,14 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
891890
auto finished = false;
892891
while (!finished) {
893892
try {
894-
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
893+
UniqueAVFrame avFrame = decodeAVFrame([startPts](AVFrame* avFrame) {
895894
return startPts < avFrame->pts + getDuration(avFrame);
896895
});
897896
// TODO: it's not great that we are getting a FrameOutput, which is
898897
// intended for videos. We should consider bypassing
899898
// convertAVFrameToFrameOutput and directly call
900899
// convertAudioAVFrameToFrameOutputOnCPU.
901-
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
900+
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
902901
firstFramePtsSeconds =
903902
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
904903
frames.push_back(frameOutput.data);
@@ -1035,7 +1034,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
10351034
// LOW-LEVEL DECODING
10361035
// --------------------------------------------------------------------------
10371036

1038-
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
1037+
UniqueAVFrame VideoDecoder::decodeAVFrame(
10391038
std::function<bool(AVFrame*)> filterFunction) {
10401039
validateActiveStream();
10411040

@@ -1150,37 +1149,36 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
11501149
streamInfo.lastDecodedAvFramePts = avFrame->pts;
11511150
streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
11521151

1153-
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
1152+
return avFrame;
11541153
}
11551154

11561155
// --------------------------------------------------------------------------
11571156
// AVFRAME <-> FRAME OUTPUT CONVERSION
11581157
// --------------------------------------------------------------------------
11591158

11601159
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
1161-
VideoDecoder::AVFrameStream& avFrameStream,
1160+
UniqueAVFrame& avFrame,
11621161
std::optional<torch::Tensor> preAllocatedOutputTensor) {
11631162
// Convert the frame to tensor.
11641163
FrameOutput frameOutput;
1165-
int streamIndex = avFrameStream.streamIndex;
1166-
AVFrame* avFrame = avFrameStream.avFrame.get();
1167-
frameOutput.streamIndex = streamIndex;
1168-
auto& streamInfo = streamInfos_[streamIndex];
1164+
frameOutput.streamIndex = activeStreamIndex_;
1165+
auto& streamInfo = streamInfos_[activeStreamIndex_];
11691166
frameOutput.ptsSeconds = ptsToSeconds(
1170-
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
1167+
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
11711168
frameOutput.durationSeconds = ptsToSeconds(
1172-
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1169+
getDuration(avFrame),
1170+
formatContext_->streams[activeStreamIndex_]->time_base);
11731171
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
11741172
convertAudioAVFrameToFrameOutputOnCPU(
1175-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1173+
avFrame, frameOutput, preAllocatedOutputTensor);
11761174
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
11771175
convertAVFrameToFrameOutputOnCPU(
1178-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1176+
avFrame, frameOutput, preAllocatedOutputTensor);
11791177
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
11801178
convertAVFrameToFrameOutputOnCuda(
11811179
streamInfo.videoStreamOptions.device,
11821180
streamInfo.videoStreamOptions,
1183-
avFrameStream,
1181+
avFrame,
11841182
frameOutput,
11851183
preAllocatedOutputTensor);
11861184
} else {
@@ -1201,14 +1199,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12011199
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
12021200
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
12031201
void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
1204-
VideoDecoder::AVFrameStream& avFrameStream,
1202+
UniqueAVFrame& avFrame,
12051203
FrameOutput& frameOutput,
12061204
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1207-
AVFrame* avFrame = avFrameStream.avFrame.get();
12081205
auto& streamInfo = streamInfos_[activeStreamIndex_];
12091206

12101207
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
1211-
streamInfo.videoStreamOptions, *avFrame);
1208+
streamInfo.videoStreamOptions, *(avFrame.get()));
12121209
int expectedOutputHeight = frameDims.height;
12131210
int expectedOutputWidth = frameDims.width;
12141211

@@ -1251,7 +1248,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
12511248
streamInfo.prevFrameContext = frameContext;
12521249
}
12531250
int resultHeight =
1254-
convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor);
1251+
convertAVFrameToTensorUsingSwsScale(avFrame.get(), outputTensor);
12551252
// If this check failed, it would mean that the frame wasn't reshaped to
12561253
// the expected height.
12571254
// TODO: Can we do the same check for width?
@@ -1271,7 +1268,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
12711268
createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
12721269
streamInfo.prevFrameContext = frameContext;
12731270
}
1274-
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame);
1271+
outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame.get());
12751272

12761273
// Similarly to above, if this check fails it means the frame wasn't
12771274
// reshaped to its expected dimensions by filtergraph.
@@ -1350,25 +1347,25 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13501347
}
13511348

13521349
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1353-
VideoDecoder::AVFrameStream& avFrameStream,
1350+
UniqueAVFrame& srcAVFrame,
13541351
FrameOutput& frameOutput,
13551352
std::optional<torch::Tensor> preAllocatedOutputTensor) {
13561353
TORCH_CHECK(
13571354
!preAllocatedOutputTensor.has_value(),
13581355
"pre-allocated audio tensor not supported yet.");
13591356

13601357
AVSampleFormat sourceSampleFormat =
1361-
static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
1358+
static_cast<AVSampleFormat>(srcAVFrame->format);
13621359
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13631360

13641361
UniqueAVFrame convertedAVFrame;
13651362
if (sourceSampleFormat != desiredSampleFormat) {
13661363
convertedAVFrame = convertAudioAVFrameSampleFormat(
1367-
avFrameStream.avFrame, sourceSampleFormat, desiredSampleFormat);
1364+
srcAVFrame, sourceSampleFormat, desiredSampleFormat);
13681365
}
13691366
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
13701367
? convertedAVFrame
1371-
: avFrameStream.avFrame;
1368+
: srcAVFrame;
13721369

13731370
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
13741371
TORCH_CHECK(

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -244,23 +244,6 @@ class VideoDecoder {
244244
// These are APIs that should be private, but that are effectively exposed for
245245
// practical reasons, typically for testing purposes.
246246

247-
// This struct is needed because AVFrame doesn't retain the streamIndex. Only
248-
// the AVPacket knows its stream. This is what the low-level private decoding
249-
// entry points return. The AVFrameStream is then converted to a FrameOutput
250-
// with convertAVFrameToFrameOutput. It should be private, but is currently
251-
// used by DeviceInterface.
252-
struct AVFrameStream {
253-
// The actual decoded output as a unique pointer to an AVFrame.
254-
// Usually, this is a YUV frame. It'll be converted to RGB in
255-
// convertAVFrameToFrameOutput.
256-
UniqueAVFrame avFrame;
257-
// The stream index of the decoded frame.
258-
int streamIndex;
259-
260-
explicit AVFrameStream(UniqueAVFrame&& a, int s)
261-
: avFrame(std::move(a)), streamIndex(s) {}
262-
};
263-
264247
// Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
265248
// can move it back to private.
266249
FrameOutput getFrameAtIndexInternal(
@@ -376,28 +359,29 @@ class VideoDecoder {
376359

377360
void maybeSeekToBeforeDesiredPts();
378361

379-
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
362+
UniqueAVFrame decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
380363

381364
FrameOutput getNextFrameInternal(
382365
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
383366

384367
torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);
385368

386369
FrameOutput convertAVFrameToFrameOutput(
387-
AVFrameStream& avFrameStream,
370+
UniqueAVFrame& avFrame,
388371
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
389372

390373
void convertAVFrameToFrameOutputOnCPU(
391-
AVFrameStream& avFrameStream,
374+
UniqueAVFrame& avFrame,
392375
FrameOutput& frameOutput,
393376
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
394377

395378
void convertAudioAVFrameToFrameOutputOnCPU(
396-
AVFrameStream& avFrameStream,
379+
UniqueAVFrame& avFrame,
397380
FrameOutput& frameOutput,
398381
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
399382

400-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
383+
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
384+
const AVFrame* srcAVFrame);
401385

402386
int convertAVFrameToTensorUsingSwsScale(
403387
const AVFrame* avFrame,

0 commit comments

Comments
 (0)