Skip to content

Commit ae60012

Browse files
author
pytorchbot
committed
2025-03-19 nightly release (23c73ea)
1 parent 20e811f commit ae60012

File tree

8 files changed

+61
-82
lines changed

8 files changed

+61
-82
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
1717
void convertAVFrameToFrameOutputOnCuda(
1818
const torch::Device& device,
1919
[[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
20-
[[maybe_unused]] VideoDecoder::AVFrameStream& avFrameStream,
20+
[[maybe_unused]] UniqueAVFrame& avFrame,
2121
[[maybe_unused]] VideoDecoder::FrameOutput& frameOutput,
2222
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
2323
throwUnsupportedDeviceError(device);

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,15 @@ void initializeContextOnCuda(
190190
void convertAVFrameToFrameOutputOnCuda(
191191
const torch::Device& device,
192192
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
193-
VideoDecoder::AVFrameStream& avFrameStream,
193+
UniqueAVFrame& avFrame,
194194
VideoDecoder::FrameOutput& frameOutput,
195195
std::optional<torch::Tensor> preAllocatedOutputTensor) {
196-
AVFrame* avFrame = avFrameStream.avFrame.get();
197-
198196
TORCH_CHECK(
199197
avFrame->format == AV_PIX_FMT_CUDA,
200198
"Expected format to be AV_PIX_FMT_CUDA, got " +
201199
std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
202200
auto frameDims =
203-
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
201+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
204202
int height = frameDims.height;
205203
int width = frameDims.width;
206204
torch::Tensor& dst = frameOutput.data;

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void initializeContextOnCuda(
3232
void convertAVFrameToFrameOutputOnCuda(
3333
const torch::Device& device,
3434
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
35-
VideoDecoder::AVFrameStream& avFrameStream,
35+
UniqueAVFrame& avFrame,
3636
VideoDecoder::FrameOutput& frameOutput,
3737
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
3838

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,11 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode) {
4848
return std::string(errorBuffer);
4949
}
5050

51-
int64_t getDuration(const UniqueAVFrame& frame) {
52-
return getDuration(frame.get());
53-
}
54-
55-
int64_t getDuration(const AVFrame* frame) {
51+
int64_t getDuration(const UniqueAVFrame& avFrame) {
5652
#if LIBAVUTIL_VERSION_MAJOR < 58
57-
return frame->pkt_duration;
53+
return avFrame->pkt_duration;
5854
#else
59-
return frame->duration;
55+
return avFrame->duration;
6056
#endif
6157
}
6258

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
140140
// struct member representing duration has changed across the versions we
141141
// support.
142142
int64_t getDuration(const UniqueAVFrame& frame);
143-
int64_t getDuration(const AVFrame* frame);
144143

145144
int getNumChannels(const UniqueAVFrame& avFrame);
146145
int getNumChannels(const UniqueAVCodecContext& avCodecContext);

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
583583
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
584584
std::optional<torch::Tensor> preAllocatedOutputTensor) {
585585
validateActiveStream();
586-
AVFrameStream avFrameStream = decodeAVFrame(
587-
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
588-
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
586+
UniqueAVFrame avFrame = decodeAVFrame(
587+
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
588+
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
589589
}
590590

591591
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
@@ -715,8 +715,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
715715
}
716716

717717
setCursorPtsInSeconds(seconds);
718-
AVFrameStream avFrameStream =
719-
decodeAVFrame([seconds, this](AVFrame* avFrame) {
718+
UniqueAVFrame avFrame =
719+
decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
720720
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
721721
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
722722
double frameEndTime = ptsToSeconds(
@@ -735,7 +735,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
735735
});
736736

737737
// Convert the frame to tensor.
738-
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream);
738+
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame);
739739
frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
740740
return frameOutput;
741741
}
@@ -891,14 +891,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
891891
auto finished = false;
892892
while (!finished) {
893893
try {
894-
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
895-
return startPts < avFrame->pts + getDuration(avFrame);
896-
});
897-
// TODO: it's not great that we are getting a FrameOutput, which is
898-
// intended for videos. We should consider bypassing
899-
// convertAVFrameToFrameOutput and directly call
900-
// convertAudioAVFrameToFrameOutputOnCPU.
901-
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
894+
UniqueAVFrame avFrame =
895+
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
896+
return startPts < avFrame->pts + getDuration(avFrame);
897+
});
898+
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
902899
firstFramePtsSeconds =
903900
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
904901
frames.push_back(frameOutput.data);
@@ -1035,8 +1032,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
10351032
// LOW-LEVEL DECODING
10361033
// --------------------------------------------------------------------------
10371034

1038-
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
1039-
std::function<bool(AVFrame*)> filterFunction) {
1035+
UniqueAVFrame VideoDecoder::decodeAVFrame(
1036+
std::function<bool(const UniqueAVFrame&)> filterFunction) {
10401037
validateActiveStream();
10411038

10421039
resetDecodeStats();
@@ -1064,7 +1061,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
10641061

10651062
decodeStats_.numFramesReceivedByDecoder++;
10661063
// Is this the kind of frame we're looking for?
1067-
if (status == AVSUCCESS && filterFunction(avFrame.get())) {
1064+
if (status == AVSUCCESS && filterFunction(avFrame)) {
10681065
// Yes, this is the frame we'll return; break out of the decoding loop.
10691066
break;
10701067
} else if (status == AVSUCCESS) {
@@ -1150,37 +1147,35 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
11501147
streamInfo.lastDecodedAvFramePts = avFrame->pts;
11511148
streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
11521149

1153-
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
1150+
return avFrame;
11541151
}
11551152

11561153
// --------------------------------------------------------------------------
11571154
// AVFRAME <-> FRAME OUTPUT CONVERSION
11581155
// --------------------------------------------------------------------------
11591156

11601157
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
1161-
VideoDecoder::AVFrameStream& avFrameStream,
1158+
UniqueAVFrame& avFrame,
11621159
std::optional<torch::Tensor> preAllocatedOutputTensor) {
11631160
// Convert the frame to tensor.
11641161
FrameOutput frameOutput;
1165-
int streamIndex = avFrameStream.streamIndex;
1166-
AVFrame* avFrame = avFrameStream.avFrame.get();
1167-
frameOutput.streamIndex = streamIndex;
1168-
auto& streamInfo = streamInfos_[streamIndex];
1162+
auto& streamInfo = streamInfos_[activeStreamIndex_];
11691163
frameOutput.ptsSeconds = ptsToSeconds(
1170-
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
1164+
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
11711165
frameOutput.durationSeconds = ptsToSeconds(
1172-
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1166+
getDuration(avFrame),
1167+
formatContext_->streams[activeStreamIndex_]->time_base);
11731168
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
11741169
convertAudioAVFrameToFrameOutputOnCPU(
1175-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1170+
avFrame, frameOutput, preAllocatedOutputTensor);
11761171
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
11771172
convertAVFrameToFrameOutputOnCPU(
1178-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1173+
avFrame, frameOutput, preAllocatedOutputTensor);
11791174
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
11801175
convertAVFrameToFrameOutputOnCuda(
11811176
streamInfo.videoStreamOptions.device,
11821177
streamInfo.videoStreamOptions,
1183-
avFrameStream,
1178+
avFrame,
11841179
frameOutput,
11851180
preAllocatedOutputTensor);
11861181
} else {
@@ -1201,14 +1196,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12011196
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
12021197
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
12031198
void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
1204-
VideoDecoder::AVFrameStream& avFrameStream,
1199+
UniqueAVFrame& avFrame,
12051200
FrameOutput& frameOutput,
12061201
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1207-
AVFrame* avFrame = avFrameStream.avFrame.get();
12081202
auto& streamInfo = streamInfos_[activeStreamIndex_];
12091203

12101204
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
1211-
streamInfo.videoStreamOptions, *avFrame);
1205+
streamInfo.videoStreamOptions, avFrame);
12121206
int expectedOutputHeight = frameDims.height;
12131207
int expectedOutputWidth = frameDims.width;
12141208

@@ -1302,7 +1296,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
13021296
}
13031297

13041298
int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
1305-
const AVFrame* avFrame,
1299+
const UniqueAVFrame& avFrame,
13061300
torch::Tensor& outputTensor) {
13071301
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
13081302
SwsContext* swsContext = activeStreamInfo.swsContext.get();
@@ -1322,11 +1316,11 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
13221316
}
13231317

13241318
torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
1325-
const AVFrame* avFrame) {
1319+
const UniqueAVFrame& avFrame) {
13261320
FilterGraphContext& filterGraphContext =
13271321
streamInfos_[activeStreamIndex_].filterGraphContext;
13281322
int status =
1329-
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
1323+
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame.get());
13301324
if (status < AVSUCCESS) {
13311325
throw std::runtime_error("Failed to add frame to buffer source context");
13321326
}
@@ -1350,25 +1344,25 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13501344
}
13511345

13521346
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1353-
VideoDecoder::AVFrameStream& avFrameStream,
1347+
UniqueAVFrame& srcAVFrame,
13541348
FrameOutput& frameOutput,
13551349
std::optional<torch::Tensor> preAllocatedOutputTensor) {
13561350
TORCH_CHECK(
13571351
!preAllocatedOutputTensor.has_value(),
13581352
"pre-allocated audio tensor not supported yet.");
13591353

13601354
AVSampleFormat sourceSampleFormat =
1361-
static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
1355+
static_cast<AVSampleFormat>(srcAVFrame->format);
13621356
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13631357

13641358
UniqueAVFrame convertedAVFrame;
13651359
if (sourceSampleFormat != desiredSampleFormat) {
13661360
convertedAVFrame = convertAudioAVFrameSampleFormat(
1367-
avFrameStream.avFrame, sourceSampleFormat, desiredSampleFormat);
1361+
srcAVFrame, sourceSampleFormat, desiredSampleFormat);
13681362
}
13691363
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
13701364
? convertedAVFrame
1371-
: avFrameStream.avFrame;
1365+
: srcAVFrame;
13721366

13731367
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
13741368
TORCH_CHECK(
@@ -1944,10 +1938,10 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
19441938

19451939
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
19461940
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
1947-
const AVFrame& avFrame) {
1941+
const UniqueAVFrame& avFrame) {
19481942
return FrameDims(
1949-
videoStreamOptions.height.value_or(avFrame.height),
1950-
videoStreamOptions.width.value_or(avFrame.width));
1943+
videoStreamOptions.height.value_or(avFrame->height),
1944+
videoStreamOptions.width.value_or(avFrame->width));
19511945
}
19521946

19531947
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,13 @@ class VideoDecoder {
153153
// They are the equivalent of the user-facing Frame and FrameBatch classes in
154154
// Python. They contain RGB decoded frames along with some associated data
155155
// like PTS and duration.
156+
// FrameOutput is also relevant for audio decoding, typically as the output of
157+
// getNextFrame(), or as a temporary output variable.
156158
struct FrameOutput {
157-
torch::Tensor data; // 3D: of shape CHW or HWC.
158-
int streamIndex;
159+
// data shape is:
160+
// - 3D (C, H, W) or (H, W, C) for videos
161+
// - 2D (numChannels, numSamples) for audio
162+
torch::Tensor data;
159163
double ptsSeconds;
160164
double durationSeconds;
161165
};
@@ -244,23 +248,6 @@ class VideoDecoder {
244248
// These are APIs that should be private, but that are effectively exposed for
245249
// practical reasons, typically for testing purposes.
246250

247-
// This struct is needed because AVFrame doesn't retain the streamIndex. Only
248-
// the AVPacket knows its stream. This is what the low-level private decoding
249-
// entry points return. The AVFrameStream is then converted to a FrameOutput
250-
// with convertAVFrameToFrameOutput. It should be private, but is currently
251-
// used by DeviceInterface.
252-
struct AVFrameStream {
253-
// The actual decoded output as a unique pointer to an AVFrame.
254-
// Usually, this is a YUV frame. It'll be converted to RGB in
255-
// convertAVFrameToFrameOutput.
256-
UniqueAVFrame avFrame;
257-
// The stream index of the decoded frame.
258-
int streamIndex;
259-
260-
explicit AVFrameStream(UniqueAVFrame&& a, int s)
261-
: avFrame(std::move(a)), streamIndex(s) {}
262-
};
263-
264251
// Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
265252
// can move it back to private.
266253
FrameOutput getFrameAtIndexInternal(
@@ -376,31 +363,33 @@ class VideoDecoder {
376363

377364
void maybeSeekToBeforeDesiredPts();
378365

379-
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
366+
UniqueAVFrame decodeAVFrame(
367+
std::function<bool(const UniqueAVFrame&)> filterFunction);
380368

381369
FrameOutput getNextFrameInternal(
382370
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
383371

384372
torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);
385373

386374
FrameOutput convertAVFrameToFrameOutput(
387-
AVFrameStream& avFrameStream,
375+
UniqueAVFrame& avFrame,
388376
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
389377

390378
void convertAVFrameToFrameOutputOnCPU(
391-
AVFrameStream& avFrameStream,
379+
UniqueAVFrame& avFrame,
392380
FrameOutput& frameOutput,
393381
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
394382

395383
void convertAudioAVFrameToFrameOutputOnCPU(
396-
AVFrameStream& avFrameStream,
384+
UniqueAVFrame& srcAVFrame,
397385
FrameOutput& frameOutput,
398386
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
399387

400-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
388+
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
389+
const UniqueAVFrame& avFrame);
401390

402391
int convertAVFrameToTensorUsingSwsScale(
403-
const AVFrame* avFrame,
392+
const UniqueAVFrame& avFrame,
404393
torch::Tensor& outputTensor);
405394

406395
UniqueAVFrame convertAudioAVFrameSampleFormat(
@@ -568,7 +557,7 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
568557

569558
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
570559
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
571-
const AVFrame& avFrame);
560+
const UniqueAVFrame& avFrame);
572561

573562
torch::Tensor allocateEmptyHWCTensor(
574563
int height,

test/decoders/test_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,12 +650,15 @@ def test_next(self, asset):
650650
frame_index = 0
651651
while True:
652652
try:
653-
frame, *_ = get_next_frame(decoder)
653+
frame, pts_seconds, duration_seconds = get_next_frame(decoder)
654654
except IndexError:
655655
break
656656
torch.testing.assert_close(
657657
frame, asset.get_frame_data_by_index(frame_index)
658658
)
659+
frame_info = asset.get_frame_info(frame_index)
660+
assert pts_seconds == frame_info.pts_seconds
661+
assert duration_seconds == frame_info.duration_seconds
659662
frame_index += 1
660663

661664
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)