Skip to content

Commit 49a6614

Browse files
committed
Merge branch 'main' of https://github.com/meta-pytorch/torchcodec into encoding_tutorial
2 parents 9bbeb1f + cac99ae commit 49a6614

File tree

16 files changed

+312
-173
lines changed

16 files changed

+312
-173
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ void tryToValidateCodecOption(
607607
"] for this codec. For more details, run 'ffmpeg -h encoder=",
608608
avCodec.name,
609609
"'");
610-
} catch (const std::invalid_argument& e) {
610+
} catch (const std::invalid_argument&) {
611611
TORCH_CHECK(
612612
false,
613613
"Option ",
@@ -662,7 +662,7 @@ VideoEncoder::~VideoEncoder() {
662662

663663
VideoEncoder::VideoEncoder(
664664
const torch::Tensor& frames,
665-
int frameRate,
665+
double frameRate,
666666
std::string_view fileName,
667667
const VideoStreamOptions& videoStreamOptions)
668668
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
@@ -694,7 +694,7 @@ VideoEncoder::VideoEncoder(
694694

695695
VideoEncoder::VideoEncoder(
696696
const torch::Tensor& frames,
697-
int frameRate,
697+
double frameRate,
698698
std::string_view formatName,
699699
std::unique_ptr<AVIOContextHolder> avioContextHolder,
700700
const VideoStreamOptions& videoStreamOptions)
@@ -787,9 +787,9 @@ void VideoEncoder::initializeEncoder(
787787
avCodecContext_->width = outWidth_;
788788
avCodecContext_->height = outHeight_;
789789
avCodecContext_->pix_fmt = outPixelFormat_;
790-
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
791-
avCodecContext_->time_base = {1, inFrameRate_};
792-
avCodecContext_->framerate = {inFrameRate_, 1};
790+
// TODO-VideoEncoder: Add and utilize output frame_rate option
791+
avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX);
792+
avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate);
793793

794794
// Set flag for containers that require extradata to be in the codec context
795795
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
@@ -833,6 +833,10 @@ void VideoEncoder::initializeEncoder(
833833

834834
// Set the stream time base to encode correct frame timestamps
835835
avStream_->time_base = avCodecContext_->time_base;
836+
// Set the stream frame rate to store correct frame durations for some
837+
// containers (webm, mkv)
838+
avStream_->r_frame_rate = avCodecContext_->framerate;
839+
836840
status = avcodec_parameters_from_context(
837841
avStream_->codecpar, avCodecContext_.get());
838842
TORCH_CHECK(

src/torchcodec/_core/Encoder.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,13 @@ class VideoEncoder {
143143

144144
VideoEncoder(
145145
const torch::Tensor& frames,
146-
int frameRate,
146+
double frameRate,
147147
std::string_view fileName,
148148
const VideoStreamOptions& videoStreamOptions);
149149

150150
VideoEncoder(
151151
const torch::Tensor& frames,
152-
int frameRate,
152+
double frameRate,
153153
std::string_view formatName,
154154
std::unique_ptr<AVIOContextHolder> avioContextHolder,
155155
const VideoStreamOptions& videoStreamOptions);
@@ -172,7 +172,7 @@ class VideoEncoder {
172172
UniqueSwsContext swsContext_;
173173

174174
const torch::Tensor frames_;
175-
int inFrameRate_;
175+
double inFrameRate_;
176176

177177
int inWidth_ = -1;
178178
int inHeight_ = -1;

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ int getNumChannels(const SharedAVCodecContext& avCodecContext) {
158158
#endif
159159
}
160160

161+
// Returns the number of audio channels recorded in `codecpar`.
//
// `codecpar` is asserted to be non-null via TORCH_CHECK. The field that is
// read depends on the libavfilter version seen at compile time:
//   - libavfilter >= 8.44: `codecpar->ch_layout.nb_channels`
//   - older versions:      `codecpar->channels`
//
// NOTE(review): the TORCH_CHECK(...) call below has no trailing semicolon;
// presumably this relies on the macro expanding to a complete statement --
// confirm, and consider adding the semicolon for consistency.
int getNumChannels(const AVCodecParameters* codecpar) {
162+
TORCH_CHECK(codecpar != nullptr, "codecpar is null")
163+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
164+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
165+
return codecpar->ch_layout.nb_channels;
166+
#else
167+
return codecpar->channels;
168+
#endif
169+
}
170+
161171
void setDefaultChannelLayout(
162172
UniqueAVCodecContext& avCodecContext,
163173
int numChannels) {

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
180180

181181
int getNumChannels(const UniqueAVFrame& avFrame);
182182
int getNumChannels(const SharedAVCodecContext& avCodecContext);
183+
int getNumChannels(const AVCodecParameters* codecpar);
183184

184185
void setDefaultChannelLayout(
185186
UniqueAVCodecContext& avCodecContext,

src/torchcodec/_core/Metadata.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ std::optional<double> StreamMetadata::getDurationSeconds(
2929
return static_cast<double>(numFramesFromHeader.value()) /
3030
averageFpsFromHeader.value();
3131
}
32+
if (durationSecondsFromContainer.has_value()) {
33+
return durationSecondsFromContainer.value();
34+
}
3235
return std::nullopt;
3336
default:
3437
TORCH_CHECK(false, "Unknown SeekMode");
@@ -80,13 +83,13 @@ std::optional<int64_t> StreamMetadata::getNumFrames(SeekMode seekMode) const {
8083
numFramesFromContent.has_value(), "Missing numFramesFromContent");
8184
return numFramesFromContent.value();
8285
case SeekMode::approximate: {
86+
auto durationSeconds = getDurationSeconds(seekMode);
8387
if (numFramesFromHeader.has_value()) {
8488
return numFramesFromHeader.value();
8589
}
86-
if (averageFpsFromHeader.has_value() &&
87-
durationSecondsFromHeader.has_value()) {
90+
if (averageFpsFromHeader.has_value() && durationSeconds.has_value()) {
8891
return static_cast<int64_t>(
89-
averageFpsFromHeader.value() * durationSecondsFromHeader.value());
92+
averageFpsFromHeader.value() * durationSeconds.value());
9093
}
9194
return std::nullopt;
9295
}

src/torchcodec/_core/Metadata.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ enum class SeekMode { exact, approximate, custom_frame_mappings };
2323
struct StreamMetadata {
2424
// Common (video and audio) fields derived from the AVStream.
2525
int streamIndex;
26+
2627
// See this link for what various values are available:
2728
// https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
2829
AVMediaType mediaType;
30+
2931
std::optional<AVCodecID> codecId;
3032
std::optional<std::string> codecName;
3133
std::optional<double> durationSecondsFromHeader;
@@ -35,17 +37,22 @@ struct StreamMetadata {
3537
std::optional<double> averageFpsFromHeader;
3638
std::optional<double> bitRate;
3739

40+
// Used as fallback in approximate mode when stream duration is unavailable.
41+
std::optional<double> durationSecondsFromContainer;
42+
3843
// More accurate duration, obtained by scanning the file.
3944
// These presentation timestamps are in time base.
4045
std::optional<int64_t> beginStreamPtsFromContent;
4146
std::optional<int64_t> endStreamPtsFromContent;
47+
4248
// These presentation timestamps are in seconds.
4349
std::optional<double> beginStreamPtsSecondsFromContent;
4450
std::optional<double> endStreamPtsSecondsFromContent;
51+
4552
// This can be useful for index-based seeking.
4653
std::optional<int64_t> numFramesFromContent;
4754

48-
// Video-only fields derived from the AVCodecContext.
55+
// Video-only fields
4956
std::optional<int> width;
5057
std::optional<int> height;
5158
std::optional<AVRational> sampleAspectRatio;
@@ -67,13 +74,17 @@ struct ContainerMetadata {
6774
std::vector<StreamMetadata> allStreamMetadata;
6875
int numAudioStreams = 0;
6976
int numVideoStreams = 0;
77+
7078
// Note that this is the container-level duration, which is usually the max
7179
// of all stream durations available in the container.
7280
std::optional<double> durationSecondsFromHeader;
81+
7382
// Total bit rate of the container, in bit/s.
7483
std::optional<double> bitRate;
84+
7585
// If set, this is the index to the default audio stream.
7686
std::optional<int> bestAudioStreamIndex;
87+
7788
// If set, this is the index to the default video stream.
7889
std::optional<int> bestVideoStreamIndex;
7990
};

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,26 @@ void SingleStreamDecoder::initializeDecoder() {
100100
"Failed to find stream info: ",
101101
getFFMPEGErrorStringFromErrorCode(status));
102102

103+
if (formatContext_->duration > 0) {
104+
AVRational defaultTimeBase{1, AV_TIME_BASE};
105+
containerMetadata_.durationSecondsFromHeader =
106+
ptsToSeconds(formatContext_->duration, defaultTimeBase);
107+
}
108+
109+
if (formatContext_->bit_rate > 0) {
110+
containerMetadata_.bitRate = formatContext_->bit_rate;
111+
}
112+
113+
int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
114+
if (bestVideoStream >= 0) {
115+
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
116+
}
117+
118+
int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
119+
if (bestAudioStream >= 0) {
120+
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
121+
}
122+
103123
for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
104124
AVStream* avStream = formatContext_->streams[i];
105125
StreamMetadata streamMetadata;
@@ -110,8 +130,8 @@ void SingleStreamDecoder::initializeDecoder() {
110130
", does not match AVStream's index, " +
111131
std::to_string(avStream->index) + ".");
112132
streamMetadata.streamIndex = i;
113-
streamMetadata.mediaType = avStream->codecpar->codec_type;
114133
streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
134+
streamMetadata.mediaType = avStream->codecpar->codec_type;
115135
streamMetadata.bitRate = avStream->codecpar->bit_rate;
116136

117137
int64_t frameCount = avStream->nb_frames;
@@ -133,10 +153,18 @@ void SingleStreamDecoder::initializeDecoder() {
133153
if (fps > 0) {
134154
streamMetadata.averageFpsFromHeader = fps;
135155
}
156+
streamMetadata.width = avStream->codecpar->width;
157+
streamMetadata.height = avStream->codecpar->height;
158+
streamMetadata.sampleAspectRatio =
159+
avStream->codecpar->sample_aspect_ratio;
136160
containerMetadata_.numVideoStreams++;
137161
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
138162
AVSampleFormat format =
139163
static_cast<AVSampleFormat>(avStream->codecpar->format);
164+
streamMetadata.sampleRate =
165+
static_cast<int64_t>(avStream->codecpar->sample_rate);
166+
streamMetadata.numChannels =
167+
static_cast<int64_t>(getNumChannels(avStream->codecpar));
140168

141169
// If the AVSampleFormat is not recognized, we get back nullptr. We have
142170
// to make sure we don't initialize a std::string with nullptr. There's
@@ -149,27 +177,10 @@ void SingleStreamDecoder::initializeDecoder() {
149177
containerMetadata_.numAudioStreams++;
150178
}
151179

152-
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
153-
}
180+
streamMetadata.durationSecondsFromContainer =
181+
containerMetadata_.durationSecondsFromHeader;
154182

155-
if (formatContext_->duration > 0) {
156-
AVRational defaultTimeBase{1, AV_TIME_BASE};
157-
containerMetadata_.durationSecondsFromHeader =
158-
ptsToSeconds(formatContext_->duration, defaultTimeBase);
159-
}
160-
161-
if (formatContext_->bit_rate > 0) {
162-
containerMetadata_.bitRate = formatContext_->bit_rate;
163-
}
164-
165-
int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO);
166-
if (bestVideoStream >= 0) {
167-
containerMetadata_.bestVideoStreamIndex = bestVideoStream;
168-
}
169-
170-
int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO);
171-
if (bestAudioStream >= 0) {
172-
containerMetadata_.bestAudioStreamIndex = bestAudioStream;
183+
containerMetadata_.allStreamMetadata.push_back(streamMetadata);
173184
}
174185

175186
if (seekMode_ == SeekMode::exact) {
@@ -288,6 +299,14 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
288299
streamMetadata.numFramesFromContent =
289300
streamInfos_[streamIndex].allFrames.size();
290301

302+
// This ensures that we are robust in handling cases where
303+
// we are decoding in exact mode and numFrames is 0. The current metadata
304+
// validation logic assumes that these values should not be None.
305+
if (streamMetadata.numFramesFromContent.value() == 0) {
306+
streamMetadata.beginStreamPtsFromContent = 0;
307+
streamMetadata.endStreamPtsFromContent = 0;
308+
}
309+
291310
if (streamMetadata.beginStreamPtsFromContent.has_value()) {
292311
streamMetadata.beginStreamPtsSecondsFromContent = ptsToSeconds(
293312
*streamMetadata.beginStreamPtsFromContent, avStream->time_base);
@@ -516,11 +535,6 @@ void SingleStreamDecoder::addVideoStream(
516535
auto& streamInfo = streamInfos_[activeStreamIndex_];
517536
streamInfo.videoStreamOptions = videoStreamOptions;
518537

519-
streamMetadata.width = streamInfo.codecContext->width;
520-
streamMetadata.height = streamInfo.codecContext->height;
521-
streamMetadata.sampleAspectRatio =
522-
streamInfo.codecContext->sample_aspect_ratio;
523-
524538
if (seekMode_ == SeekMode::custom_frame_mappings) {
525539
TORCH_CHECK(
526540
customFrameMappings.has_value(),
@@ -566,13 +580,6 @@ void SingleStreamDecoder::addAudioStream(
566580
auto& streamInfo = streamInfos_[activeStreamIndex_];
567581
streamInfo.audioStreamOptions = audioStreamOptions;
568582

569-
auto& streamMetadata =
570-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
571-
streamMetadata.sampleRate =
572-
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
573-
streamMetadata.numChannels =
574-
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
575-
576583
// FFmpeg docs say that the decoder will try to decode natively in this
577584
// format, if it can. Docs don't say what the decoder does when it doesn't
578585
// support that format, but it looks like it does nothing, so this probably

src/torchcodec/_core/_metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class StreamMetadata:
4444
from the actual frames if a :term:`scan` was performed. Otherwise we
4545
fall back to ``duration_seconds_from_header``. If that value is also None,
4646
we instead calculate the duration from ``num_frames_from_header`` and
47-
``average_fps_from_header``.
47+
``average_fps_from_header``. If all of those are unavailable, we fall back
48+
to the container-level ``duration_seconds_from_header``.
4849
"""
4950
begin_stream_seconds: Optional[float]
5051
"""Beginning of the stream, in seconds (float). Conceptually, this

0 commit comments

Comments
 (0)