Skip to content

Commit 04b02b9

Browse files
authored
Refactor order of getting metadata and adding a stream (#1060)
1 parent 1ea235a commit 04b02b9

File tree

9 files changed

+108
-59
lines changed

9 files changed

+108
-59
lines changed

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ int getNumChannels(const SharedAVCodecContext& avCodecContext) {
158158
#endif
159159
}
160160

161+
int getNumChannels(const AVCodecParameters* codecpar) {
162+
TORCH_CHECK(codecpar != nullptr, "codecpar is null")
163+
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
164+
(LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
165+
return codecpar->ch_layout.nb_channels;
166+
#else
167+
return codecpar->channels;
168+
#endif
169+
}
170+
161171
void setDefaultChannelLayout(
162172
UniqueAVCodecContext& avCodecContext,
163173
int numChannels) {

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
180180

181181
int getNumChannels(const UniqueAVFrame& avFrame);
182182
int getNumChannels(const SharedAVCodecContext& avCodecContext);
183+
int getNumChannels(const AVCodecParameters* codecpar);
183184

184185
void setDefaultChannelLayout(
185186
UniqueAVCodecContext& avCodecContext,

src/torchcodec/_core/Metadata.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ enum class SeekMode { exact, approximate, custom_frame_mappings };
2323
struct StreamMetadata {
2424
// Common (video and audio) fields derived from the AVStream.
2525
int streamIndex;
26+
2627
// See this link for what various values are available:
2728
// https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
2829
AVMediaType mediaType;
30+
2931
std::optional<AVCodecID> codecId;
3032
std::optional<std::string> codecName;
3133
std::optional<double> durationSecondsFromHeader;
@@ -39,13 +41,15 @@ struct StreamMetadata {
3941
// These presentation timestamps are in time base.
4042
std::optional<int64_t> beginStreamPtsFromContent;
4143
std::optional<int64_t> endStreamPtsFromContent;
44+
4245
// These presentation timestamps are in seconds.
4346
std::optional<double> beginStreamPtsSecondsFromContent;
4447
std::optional<double> endStreamPtsSecondsFromContent;
48+
4549
// This can be useful for index-based seeking.
4650
std::optional<int64_t> numFramesFromContent;
4751

48-
// Video-only fields derived from the AVCodecContext.
52+
// Video-only fields
4953
std::optional<int> width;
5054
std::optional<int> height;
5155
std::optional<AVRational> sampleAspectRatio;
@@ -67,13 +71,17 @@ struct ContainerMetadata {
6771
std::vector<StreamMetadata> allStreamMetadata;
6872
int numAudioStreams = 0;
6973
int numVideoStreams = 0;
74+
7075
// Note that this is the container-level duration, which is usually the max
7176
// of all stream durations available in the container.
7277
std::optional<double> durationSecondsFromHeader;
78+
7379
// Total BitRate level information at the container level in bit/s
7480
std::optional<double> bitRate;
81+
7582
// If set, this is the index to the default audio stream.
7683
std::optional<int> bestAudioStreamIndex;
84+
7785
// If set, this is the index to the default video stream.
7886
std::optional<int> bestVideoStreamIndex;
7987
};

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ void SingleStreamDecoder::initializeDecoder() {
110110
", does not match AVStream's index, " +
111111
std::to_string(avStream->index) + ".");
112112
streamMetadata.streamIndex = i;
113-
streamMetadata.mediaType = avStream->codecpar->codec_type;
114113
streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
114+
streamMetadata.mediaType = avStream->codecpar->codec_type;
115115
streamMetadata.bitRate = avStream->codecpar->bit_rate;
116116

117117
int64_t frameCount = avStream->nb_frames;
@@ -133,10 +133,18 @@ void SingleStreamDecoder::initializeDecoder() {
133133
if (fps > 0) {
134134
streamMetadata.averageFpsFromHeader = fps;
135135
}
136+
streamMetadata.width = avStream->codecpar->width;
137+
streamMetadata.height = avStream->codecpar->height;
138+
streamMetadata.sampleAspectRatio =
139+
avStream->codecpar->sample_aspect_ratio;
136140
containerMetadata_.numVideoStreams++;
137141
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
138142
AVSampleFormat format =
139143
static_cast<AVSampleFormat>(avStream->codecpar->format);
144+
streamMetadata.sampleRate =
145+
static_cast<int64_t>(avStream->codecpar->sample_rate);
146+
streamMetadata.numChannels =
147+
static_cast<int64_t>(getNumChannels(avStream->codecpar));
140148

141149
// If the AVSampleFormat is not recognized, we get back nullptr. We have
142150
// to make sure we don't initialize a std::string with nullptr. There's
@@ -524,11 +532,6 @@ void SingleStreamDecoder::addVideoStream(
524532
auto& streamInfo = streamInfos_[activeStreamIndex_];
525533
streamInfo.videoStreamOptions = videoStreamOptions;
526534

527-
streamMetadata.width = streamInfo.codecContext->width;
528-
streamMetadata.height = streamInfo.codecContext->height;
529-
streamMetadata.sampleAspectRatio =
530-
streamInfo.codecContext->sample_aspect_ratio;
531-
532535
if (seekMode_ == SeekMode::custom_frame_mappings) {
533536
TORCH_CHECK(
534537
customFrameMappings.has_value(),
@@ -574,13 +577,6 @@ void SingleStreamDecoder::addAudioStream(
574577
auto& streamInfo = streamInfos_[activeStreamIndex_];
575578
streamInfo.audioStreamOptions = audioStreamOptions;
576579

577-
auto& streamMetadata =
578-
containerMetadata_.allStreamMetadata[activeStreamIndex_];
579-
streamMetadata.sampleRate =
580-
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
581-
streamMetadata.numChannels =
582-
static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
583-
584580
// FFmpeg docs say that the decoder will try to decode natively in this
585581
// format, if it can. Docs don't say what the decoder does when it doesn't
586582
// support that format, but it looks like it does nothing, so this probably

src/torchcodec/_core/custom_ops.cpp

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,34 @@ SeekMode seekModeFromString(std::string_view seekMode) {
198198
}
199199
}
200200

201+
void writeFallbackBasedMetadata(
202+
std::map<std::string, std::string>& map,
203+
const StreamMetadata& streamMetadata,
204+
SeekMode seekMode) {
205+
auto durationSeconds = streamMetadata.getDurationSeconds(seekMode);
206+
if (durationSeconds.has_value()) {
207+
map["durationSeconds"] = std::to_string(durationSeconds.value());
208+
}
209+
210+
auto numFrames = streamMetadata.getNumFrames(seekMode);
211+
if (numFrames.has_value()) {
212+
map["numFrames"] = std::to_string(numFrames.value());
213+
}
214+
215+
double beginStreamSeconds = streamMetadata.getBeginStreamSeconds(seekMode);
216+
map["beginStreamSeconds"] = std::to_string(beginStreamSeconds);
217+
218+
auto endStreamSeconds = streamMetadata.getEndStreamSeconds(seekMode);
219+
if (endStreamSeconds.has_value()) {
220+
map["endStreamSeconds"] = std::to_string(endStreamSeconds.value());
221+
}
222+
223+
auto averageFps = streamMetadata.getAverageFps(seekMode);
224+
if (averageFps.has_value()) {
225+
map["averageFps"] = std::to_string(averageFps.value());
226+
}
227+
}
228+
201229
int checkedToPositiveInt(const std::string& str) {
202230
int ret = 0;
203231
try {
@@ -917,30 +945,28 @@ std::string get_stream_json_metadata(
917945
// In approximate mode: content-based metadata does not exist for any stream.
918946
// In custom_frame_mappings: content-based metadata exists only for the active
919947
// stream.
948+
//
920949
// Our fallback logic assumes content-based metadata is available.
921950
// It is available for decoding on the active stream, but would break
922951
// when getting metadata from non-active streams.
923952
if ((seekMode != SeekMode::custom_frame_mappings) ||
924953
(seekMode == SeekMode::custom_frame_mappings &&
925954
stream_index == activeStreamIndex)) {
926-
if (streamMetadata.getDurationSeconds(seekMode).has_value()) {
927-
map["durationSeconds"] =
928-
std::to_string(streamMetadata.getDurationSeconds(seekMode).value());
929-
}
930-
if (streamMetadata.getNumFrames(seekMode).has_value()) {
931-
map["numFrames"] =
932-
std::to_string(streamMetadata.getNumFrames(seekMode).value());
933-
}
934-
map["beginStreamSeconds"] =
935-
std::to_string(streamMetadata.getBeginStreamSeconds(seekMode));
936-
if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) {
937-
map["endStreamSeconds"] =
938-
std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value());
939-
}
940-
if (streamMetadata.getAverageFps(seekMode).has_value()) {
941-
map["averageFps"] =
942-
std::to_string(streamMetadata.getAverageFps(seekMode).value());
943-
}
955+
writeFallbackBasedMetadata(map, streamMetadata, seekMode);
956+
} else if (seekMode == SeekMode::custom_frame_mappings) {
957+
// If this is not the active stream, then we don't have content-based
958+
// metadata for custom frame mappings. In that case, we want the same
959+
// behavior as we would get with approximate mode. Encoding this behavior in
960+
// the fallback logic itself is tricky and not worth it for this corner
961+
// case. So we hardcode in approximate mode.
962+
//
963+
// TODO: This hacky behavior is only necessary because the custom frame
964+
// mapping is supplied in SingleStreamDecoder::addVideoStream() rather
965+
// than in the constructor. And it's supplied to addVideoStream() and
966+
// not the constructor because we need to know the stream index. If we
967+
// can encode the relevant stream indices into custom frame mappings
968+
// itself, then we can put it in the constructor.
969+
writeFallbackBasedMetadata(map, streamMetadata, SeekMode::approximate);
944970
}
945971

946972
return mapToJson(map);

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,6 @@ def __init__(
6363
torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder")
6464
self._decoder = create_decoder(source=source, seek_mode="approximate")
6565

66-
core.add_audio_stream(
67-
self._decoder,
68-
stream_index=stream_index,
69-
sample_rate=sample_rate,
70-
num_channels=num_channels,
71-
)
72-
7366
container_metadata = core.get_container_metadata(self._decoder)
7467
self.stream_index = (
7568
container_metadata.best_audio_stream_index
@@ -81,13 +74,28 @@ def __init__(
8174
"The best audio stream is unknown and there is no specified stream. "
8275
+ ERROR_REPORTING_INSTRUCTIONS
8376
)
77+
if self.stream_index >= len(container_metadata.streams):
78+
raise ValueError(
79+
f"The stream at index {stream_index} is not a valid stream."
80+
)
81+
8482
self.metadata = container_metadata.streams[self.stream_index]
85-
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
83+
if not isinstance(self.metadata, core._metadata.AudioStreamMetadata):
84+
raise ValueError(
85+
f"The stream at index {stream_index} is not an audio stream. "
86+
)
8687

8788
self._desired_sample_rate = (
8889
sample_rate if sample_rate is not None else self.metadata.sample_rate
8990
)
9091

92+
core.add_audio_stream(
93+
self._decoder,
94+
stream_index=stream_index,
95+
sample_rate=sample_rate,
96+
num_channels=num_channels,
97+
)
98+
9199
def get_all_samples(self) -> AudioSamples:
92100
"""Returns all the audio samples from the source.
93101

src/torchcodec/decoders/_video_decoder.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,16 @@ def __init__(
141141

142142
self._decoder = create_decoder(source=source, seek_mode=seek_mode)
143143

144+
(
145+
self.metadata,
146+
self.stream_index,
147+
self._begin_stream_seconds,
148+
self._end_stream_seconds,
149+
self._num_frames,
150+
) = _get_and_validate_stream_metadata(
151+
decoder=self._decoder, stream_index=stream_index
152+
)
153+
144154
allowed_dimension_orders = ("NCHW", "NHWC")
145155
if dimension_order not in allowed_dimension_orders:
146156
raise ValueError(
@@ -157,12 +167,11 @@ def __init__(
157167
device = str(device)
158168

159169
device_variant = _get_cuda_backend()
160-
161170
transform_specs = _make_transform_specs(transforms)
162171

163172
core.add_video_stream(
164173
self._decoder,
165-
stream_index=stream_index,
174+
stream_index=self.stream_index,
166175
dimension_order=dimension_order,
167176
num_threads=num_ffmpeg_threads,
168177
device=device,
@@ -171,16 +180,6 @@ def __init__(
171180
custom_frame_mappings=custom_frame_mappings_data,
172181
)
173182

174-
(
175-
self.metadata,
176-
self.stream_index,
177-
self._begin_stream_seconds,
178-
self._end_stream_seconds,
179-
self._num_frames,
180-
) = _get_and_validate_stream_metadata(
181-
decoder=self._decoder, stream_index=stream_index
182-
)
183-
184183
def __len__(self) -> int:
185184
return self._num_frames
186185

@@ -413,8 +412,12 @@ def _get_and_validate_stream_metadata(
413412
+ ERROR_REPORTING_INSTRUCTIONS
414413
)
415414

415+
if stream_index >= len(container_metadata.streams):
416+
raise ValueError(f"The stream index {stream_index} is not a valid stream.")
417+
416418
metadata = container_metadata.streams[stream_index]
417-
assert isinstance(metadata, core._metadata.VideoStreamMetadata) # mypy
419+
if not isinstance(metadata, core._metadata.VideoStreamMetadata):
420+
raise ValueError(f"The stream at index {stream_index} is not a video stream. ")
418421

419422
if metadata.begin_stream_seconds is None:
420423
raise ValueError(

test/test_decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def test_create_fails(self, Decoder):
116116
Decoder(123)
117117

118118
# stream index that does not exist
119-
with pytest.raises(ValueError, match="No valid stream found"):
119+
with pytest.raises(ValueError, match="40 is not a valid stream"):
120120
Decoder(NASA_VIDEO.path, stream_index=40)
121121

122122
# stream index that does exist, but it's not audio or video
123-
with pytest.raises(ValueError, match="No valid stream found"):
123+
with pytest.raises(ValueError, match=r"not (a|an) (video|audio) stream"):
124124
Decoder(NASA_VIDEO.path, stream_index=2)
125125

126126
# user mistakenly forgets to specify binary reading when creating a file

test/test_metadata.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def test_get_metadata(metadata_getter):
5959
)
6060
if (seek_mode == "custom_frame_mappings") and get_ffmpeg_major_version() in (4, 5):
6161
pytest.skip(reason="ffprobe isn't accurate on ffmpeg 4 and 5")
62-
with_added_video_stream = seek_mode == "custom_frame_mappings"
6362
metadata = metadata_getter(NASA_VIDEO.path)
6463

6564
with_scan = (
@@ -99,9 +98,7 @@ def test_get_metadata(metadata_getter):
9998
assert best_video_stream_metadata.begin_stream_seconds_from_header == 0
10099
assert best_video_stream_metadata.bit_rate == 128783
101100
assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001)
102-
assert best_video_stream_metadata.pixel_aspect_ratio == (
103-
Fraction(1, 1) if with_added_video_stream else None
104-
)
101+
assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1)
105102
assert best_video_stream_metadata.codec == "h264"
106103
assert best_video_stream_metadata.num_frames_from_content == (
107104
390 if with_scan else None

0 commit comments

Comments
 (0)