Skip to content

Commit ebdd402

Browse files
author
pytorchbot
committed
2025-03-13 nightly release (d75fc58)
1 parent ef4cbfe commit ebdd402

16 files changed: +541 −55 lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ format you want. Refer to Nvidia's GPU support matrix for more details
152152
the CUDA Toolkit.
153153

154154
2. Install or compile FFmpeg with NVDEC support.
155-
TorchCodec with CUDA should work with FFmpeg versions in [5, 7].
155+
TorchCodec with CUDA should work with FFmpeg versions in [4, 7].
156156

157157
If FFmpeg is not already installed, or you need a more recent version, an
158158
easy way to install it is to use `conda`:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[project]
2-
name = "TorchCodec"
2+
name = "torchcodec"
33
description = "A video decoder for PyTorch"
44
readme = "README.md"
55
requires-python = ">=3.8"

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,22 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63-
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) {
63+
int getNumChannels(const AVFrame* avFrame) {
6464
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
6565
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
66-
int numChannels = avCodecContext->ch_layout.nb_channels;
66+
return avFrame->ch_layout.nb_channels;
6767
#else
68-
int numChannels = avCodecContext->channels;
68+
return av_get_channel_layout_nb_channels(avFrame->channel_layout);
6969
#endif
70+
}
7071

71-
return static_cast<int64_t>(numChannels);
72+
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
73+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
74+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
75+
return avCodecContext->ch_layout.nb_channels;
76+
#else
77+
return avCodecContext->channels;
78+
#endif
7279
}
7380

7481
AVIOBytesContext::AVIOBytesContext(

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139139
int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142-
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext);
142+
int getNumChannels(const AVFrame* avFrame);
143+
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
143144

144145
// Returns true if sws_scale can handle unaligned data.
145146
bool canSwsScaleHandleUnalignedData();

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 124 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <cstdint>
99
#include <cstdio>
1010
#include <iostream>
11+
#include <limits>
1112
#include <sstream>
1213
#include <stdexcept>
1314
#include <string_view>
@@ -552,7 +553,8 @@ void VideoDecoder::addAudioStream(int streamIndex) {
552553
containerMetadata_.allStreamMetadata[activeStreamIndex_];
553554
streamMetadata.sampleRate =
554555
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
555-
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
556+
streamMetadata.numChannels =
557+
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
556558
}
557559

558560
// --------------------------------------------------------------------------
@@ -567,6 +569,7 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
567569

568570
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
569571
std::optional<torch::Tensor> preAllocatedOutputTensor) {
572+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
570573
AVFrameStream avFrameStream = decodeAVFrame(
571574
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
572575
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
@@ -685,6 +688,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) {
685688
}
686689

687690
VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
691+
validateActiveStream(AVMEDIA_TYPE_VIDEO);
688692
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
689693
double frameStartTime =
690694
ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase);
@@ -757,7 +761,6 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
757761
double startSeconds,
758762
double stopSeconds) {
759763
validateActiveStream(AVMEDIA_TYPE_VIDEO);
760-
761764
const auto& streamMetadata =
762765
containerMetadata_.allStreamMetadata[activeStreamIndex_];
763766
TORCH_CHECK(
@@ -835,6 +838,74 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
835838
return frameBatchOutput;
836839
}
837840

841+
VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
842+
double startSeconds,
843+
std::optional<double> stopSecondsOptional) {
844+
validateActiveStream(AVMEDIA_TYPE_AUDIO);
845+
846+
double stopSeconds =
847+
stopSecondsOptional.value_or(std::numeric_limits<double>::max());
848+
849+
TORCH_CHECK(
850+
startSeconds <= stopSeconds,
851+
"Start seconds (" + std::to_string(startSeconds) +
852+
") must be less than or equal to stop seconds (" +
853+
std::to_string(stopSeconds) + ").");
854+
855+
if (startSeconds == stopSeconds) {
856+
// For consistency with video
857+
return AudioFramesOutput{torch::empty({0}), 0.0};
858+
}
859+
860+
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
861+
862+
auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase);
863+
if (startPts < streamInfo.lastDecodedAvFramePts +
864+
streamInfo.lastDecodedAvFrameDuration) {
865+
// If we need to seek backwards, then we have to seek back to the beginning
866+
// of the stream.
867+
// TODO-AUDIO: document why this is needed in a big comment.
868+
setCursorPtsInSeconds(INT64_MIN);
869+
}
870+
871+
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
872+
// cat(). This would save a copy. We know the duration of the output and the
873+
// sample rate, so in theory we know the number of output samples.
874+
std::vector<torch::Tensor> frames;
875+
876+
double firstFramePtsSeconds = std::numeric_limits<double>::max();
877+
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
878+
auto finished = false;
879+
while (!finished) {
880+
try {
881+
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
882+
return startPts < avFrame->pts + getDuration(avFrame);
883+
});
884+
// TODO: it's not great that we are getting a FrameOutput, which is
885+
// intended for videos. We should consider bypassing
886+
// convertAVFrameToFrameOutput and directly call
887+
// convertAudioAVFrameToFrameOutputOnCPU.
888+
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
889+
firstFramePtsSeconds =
890+
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
891+
frames.push_back(frameOutput.data);
892+
} catch (const EndOfFileException& e) {
893+
finished = true;
894+
}
895+
896+
// If stopSeconds is in [begin, end] of the last decoded frame, we should
897+
// stop decoding more frames. Note that if we were to use [begin, end),
898+
// which may seem more natural, then we would decode the frame starting at
899+
// stopSeconds, which isn't what we want!
900+
auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts +
901+
streamInfo.lastDecodedAvFrameDuration;
902+
finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts &&
903+
(stopPts <= lastDecodedAvFrameEnd);
904+
}
905+
906+
return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
907+
}
908+
838909
// --------------------------------------------------------------------------
839910
// SEEKING APIs
840911
// --------------------------------------------------------------------------
@@ -871,6 +942,12 @@ I P P P I P P P I P P I P P I P
871942
(2) is more efficient than (1) if there is an I frame between x and y.
872943
*/
873944
bool VideoDecoder::canWeAvoidSeeking() const {
945+
const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_);
946+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
947+
// For audio, we only need to seek if a backwards seek was requested within
948+
// getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called.
949+
return !cursorWasJustSet_;
950+
}
874951
int64_t lastDecodedAvFramePts =
875952
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
876953
if (cursor_ < lastDecodedAvFramePts) {
@@ -897,7 +974,7 @@ bool VideoDecoder::canWeAvoidSeeking() const {
897974
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
898975
// the comment of canWeAvoidSeeking() for details.
899976
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
900-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
977+
validateActiveStream();
901978
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
902979

903980
decodeStats_.numSeeksAttempted++;
@@ -942,7 +1019,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
9421019

9431020
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
9441021
std::function<bool(AVFrame*)> filterFunction) {
945-
validateActiveStream(AVMEDIA_TYPE_VIDEO);
1022+
validateActiveStream();
9461023

9471024
resetDecodeStats();
9481025

@@ -1071,13 +1148,14 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10711148
AVFrame* avFrame = avFrameStream.avFrame.get();
10721149
frameOutput.streamIndex = streamIndex;
10731150
auto& streamInfo = streamInfos_[streamIndex];
1074-
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
10751151
frameOutput.ptsSeconds = ptsToSeconds(
10761152
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
10771153
frameOutput.durationSeconds = ptsToSeconds(
10781154
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1079-
// TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
1080-
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
1155+
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1156+
convertAudioAVFrameToFrameOutputOnCPU(
1157+
avFrameStream, frameOutput, preAllocatedOutputTensor);
1158+
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
10811159
convertAVFrameToFrameOutputOnCPU(
10821160
avFrameStream, frameOutput, preAllocatedOutputTensor);
10831161
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
@@ -1253,6 +1331,45 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
12531331
filteredAVFramePtr->data[0], shape, strides, deleter, {torch::kUInt8});
12541332
}
12551333

1334+
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1335+
VideoDecoder::AVFrameStream& avFrameStream,
1336+
FrameOutput& frameOutput,
1337+
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1338+
TORCH_CHECK(
1339+
!preAllocatedOutputTensor.has_value(),
1340+
"pre-allocated audio tensor not supported yet.");
1341+
1342+
const AVFrame* avFrame = avFrameStream.avFrame.get();
1343+
1344+
auto numSamples = avFrame->nb_samples; // per channel
1345+
auto numChannels = getNumChannels(avFrame);
1346+
torch::Tensor outputData =
1347+
torch::empty({numChannels, numSamples}, torch::kFloat32);
1348+
1349+
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
1350+
// TODO-AUDIO Implement all formats.
1351+
switch (format) {
1352+
case AV_SAMPLE_FMT_FLTP: {
1353+
uint8_t* outputChannelData = static_cast<uint8_t*>(outputData.data_ptr());
1354+
auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
1355+
for (auto channel = 0; channel < numChannels;
1356+
++channel, outputChannelData += numBytesPerChannel) {
1357+
memcpy(
1358+
outputChannelData,
1359+
avFrame->extended_data[channel],
1360+
numBytesPerChannel);
1361+
}
1362+
break;
1363+
}
1364+
default:
1365+
TORCH_CHECK(
1366+
false,
1367+
"Unsupported audio format (yet!): ",
1368+
av_get_sample_fmt_name(format));
1369+
}
1370+
frameOutput.data = outputData;
1371+
}
1372+
12561373
// --------------------------------------------------------------------------
12571374
// OUTPUT ALLOCATION AND SHAPE CONVERSION
12581375
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ class VideoDecoder {
170170
const StreamMetadata& streamMetadata);
171171
};
172172

173+
struct AudioFramesOutput {
174+
torch::Tensor data; // shape is (numChannels, numSamples)
175+
double ptsSeconds;
176+
};
177+
173178
// Places the cursor at the first frame on or after the position in seconds.
174179
// Calling getNextFrame() will return the first frame at
175180
// or after this position.
@@ -221,6 +226,11 @@ class VideoDecoder {
221226
double startSeconds,
222227
double stopSeconds);
223228

229+
// TODO-AUDIO: Should accept sampleRate
230+
AudioFramesOutput getFramesPlayedInRangeAudio(
231+
double startSeconds,
232+
std::optional<double> stopSecondsOptional = std::nullopt);
233+
224234
class EndOfFileException : public std::runtime_error {
225235
public:
226236
explicit EndOfFileException(const std::string& msg)
@@ -379,6 +389,11 @@ class VideoDecoder {
379389
FrameOutput& frameOutput,
380390
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
381391

392+
void convertAudioAVFrameToFrameOutputOnCPU(
393+
AVFrameStream& avFrameStream,
394+
FrameOutput& frameOutput,
395+
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
396+
382397
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
383398

384399
int convertAVFrameToTensorUsingSwsScale(

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ namespace facebook::torchcodec {
2525
// https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native#readme
2626
TORCH_LIBRARY(torchcodec_ns, m) {
2727
m.impl_abstract_pystub(
28-
"torchcodec.decoders._core.video_decoder_ops",
29-
"//pytorch/torchcodec:torchcodec");
28+
"torchcodec.decoders._core.ops", "//pytorch/torchcodec:torchcodec");
3029
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3130
m.def(
3231
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
@@ -48,6 +47,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4847
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4948
m.def(
5049
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
50+
m.def(
51+
"get_frames_by_pts_in_range_audio(Tensor(a!) decoder, *, float start_seconds, float? stop_seconds) -> (Tensor, Tensor)");
5152
m.def(
5253
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
5354
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
@@ -93,6 +94,13 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput(
9394
return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds);
9495
}
9596

97+
OpsAudioFramesOutput makeOpsAudioFramesOutput(
98+
VideoDecoder::AudioFramesOutput& audioFrames) {
99+
return std::make_tuple(
100+
audioFrames.data,
101+
torch::tensor(audioFrames.ptsSeconds, torch::dtype(torch::kFloat64)));
102+
}
103+
96104
VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
97105
if (seekMode == "exact") {
98106
return VideoDecoder::SeekMode::exact;
@@ -289,6 +297,16 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
289297
return makeOpsFrameBatchOutput(result);
290298
}
291299

300+
OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
301+
at::Tensor& decoder,
302+
double start_seconds,
303+
std::optional<double> stop_seconds) {
304+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
305+
auto result =
306+
videoDecoder->getFramesPlayedInRangeAudio(start_seconds, stop_seconds);
307+
return makeOpsAudioFramesOutput(result);
308+
}
309+
292310
std::string quoteValue(const std::string& value) {
293311
return "\"" + value + "\"";
294312
}
@@ -540,6 +558,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
540558
m.impl("get_frames_at_indices", &get_frames_at_indices);
541559
m.impl("get_frames_in_range", &get_frames_in_range);
542560
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
561+
m.impl("get_frames_by_pts_in_range_audio", &get_frames_by_pts_in_range_audio);
543562
m.impl("get_frames_by_pts", &get_frames_by_pts);
544563
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);
545564
m.impl(

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ using OpsFrameOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7474
// single float.
7575
using OpsFrameBatchOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7676

77+
// The elements of this tuple are all tensors that represent the concatenation
78+
// of multiple audio frames:
79+
// 1. The frames data (concatenated)
80+
// 2. A single float value for the pts of the first frame, in seconds.
81+
using OpsAudioFramesOutput = std::tuple<at::Tensor, at::Tensor>;
82+
7783
// Return the frame that is visible at a given timestamp in seconds. Each frame
7884
// in FFMPEG has a presentation timestamp and a duration. The frame visible at a
7985
// given timestamp T has T >= PTS and T < PTS + Duration.
@@ -112,6 +118,11 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
112118
double start_seconds,
113119
double stop_seconds);
114120

121+
OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
122+
at::Tensor& decoder,
123+
double start_seconds,
124+
std::optional<double> stop_seconds = std::nullopt);
125+
115126
// For testing only. We need to implement this operation as a core library
116127
// function because what we're testing is round-tripping pts values as
117128
// double-precision floating point numbers from C++ to Python and back to C++.

0 commit comments

Comments (0)