Skip to content

Commit 7f1b5f1

Browse files
author
pytorchbot
committed
2025-03-25 nightly release (57899ee)
1 parent c3fe417 commit 7f1b5f1

File tree

12 files changed

+240
-155
lines changed

12 files changed

+240
-155
lines changed

benchmarks/decoders/benchmark_audio_decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_duration(path: Path) -> str:
7171

7272

7373
def decode_with_torchcodec(path: Path) -> None:
74-
AudioDecoder(path).get_samples_played_in_range(start_seconds=0, stop_seconds=None)
74+
AudioDecoder(path).get_all_samples()
7575

7676

7777
def decode_with_torchaudio_StreamReader(path: Path) -> None:

docs/source/_templates/dataclass.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
.. autoclass:: {{ name }}
99
:members:
1010
:undoc-members: __init__
11+
:inherited-members:

examples/audio_decoding.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ def play_audio(samples):
5959
# ----------------
6060
#
6161
# To get decoded samples, we just need to call the
62-
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` method,
62+
# :meth:`~torchcodec.decoders.AudioDecoder.get_all_samples` method,
6363
# which returns an :class:`~torchcodec.AudioSamples` object:
6464

65-
samples = decoder.get_samples_played_in_range()
65+
samples = decoder.get_all_samples()
6666

6767
print(samples)
6868
play_audio(samples)
@@ -76,13 +76,12 @@ def play_audio(samples):
7676
# all streams start exactly at 0! This is not a bug in TorchCodec, this is a
7777
# property of the file that was defined when it was encoded.
7878
#
79-
# %%
8079
# Specifying a range
8180
# ------------------
8281
#
83-
# By default,
84-
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` decodes
85-
# the entire audio stream, but we can specify a custom range:
82+
# If we don't need all the samples, we can use
83+
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` to
84+
# decode the samples within a custom range:
8685

8786
samples = decoder.get_samples_played_in_range(start_seconds=10, stop_seconds=70)
8887

@@ -99,7 +98,7 @@ def play_audio(samples):
9998
# increased:
10099

101100
decoder = AudioDecoder(raw_audio_bytes, sample_rate=16_000)
102-
samples = decoder.get_samples_played_in_range(start_seconds=0)
101+
samples = decoder.get_all_samples()
103102

104103
print(samples)
105104
play_audio(samples)

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torchcodec.decoders import _core as core
1414
from torchcodec.decoders._decoder_utils import (
1515
create_decoder,
16-
get_and_validate_stream_metadata,
16+
ERROR_REPORTING_INSTRUCTIONS,
1717
)
1818

1919

@@ -57,31 +57,54 @@ def __init__(
5757
self._decoder, stream_index=stream_index, sample_rate=sample_rate
5858
)
5959

60-
(
61-
self.metadata,
62-
self.stream_index,
63-
self._begin_stream_seconds,
64-
self._end_stream_seconds,
65-
) = get_and_validate_stream_metadata(
66-
decoder=self._decoder, stream_index=stream_index, media_type="audio"
60+
container_metadata = core.get_container_metadata(self._decoder)
61+
self.stream_index = (
62+
container_metadata.best_audio_stream_index
63+
if stream_index is None
64+
else stream_index
6765
)
66+
if self.stream_index is None:
67+
raise ValueError(
68+
"The best audio stream is unknown and there is no specified stream. "
69+
+ ERROR_REPORTING_INSTRUCTIONS
70+
)
71+
self.metadata = container_metadata.streams[self.stream_index]
6872
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
73+
6974
self._desired_sample_rate = (
7075
sample_rate if sample_rate is not None else self.metadata.sample_rate
7176
)
7277

78+
def get_all_samples(self) -> AudioSamples:
79+
"""Returns all the audio samples from the source.
80+
81+
To decode samples in a specific range, use
82+
:meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range`.
83+
84+
Returns:
85+
AudioSamples: The samples within the file.
86+
"""
87+
return self.get_samples_played_in_range()
88+
7389
def get_samples_played_in_range(
7490
self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None
7591
) -> AudioSamples:
7692
"""Returns audio samples in the given range.
7793
7894
Samples are in the half open range [start_seconds, stop_seconds).
7995
96+
To decode all the samples from beginning to end, you can call this
97+
method while leaving ``start_seconds`` and ``stop_seconds`` to their
98+
default values, or use
99+
:meth:`~torchcodec.decoders.AudioDecoder.get_all_samples` as a more
100+
convenient alias.
101+
80102
Args:
81103
start_seconds (float): Time, in seconds, of the start of the
82104
range. Default: 0.
83-
stop_seconds (float): Time, in seconds, of the end of the
84-
range. As a half open range, the end is excluded.
105+
stop_seconds (float or None): Time, in seconds, of the end of the
106+
range. As a half open range, the end is excluded. Default: None,
107+
which decodes samples until the end.
85108
86109
Returns:
87110
AudioSamples: The samples within the specified range.
@@ -90,12 +113,6 @@ def get_samples_played_in_range(
90113
raise ValueError(
91114
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
92115
)
93-
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
94-
raise ValueError(
95-
f"Invalid start seconds: {start_seconds}. "
96-
f"It must be greater than or equal to {self._begin_stream_seconds} "
97-
f"and less than or equal to {self._end_stream_seconds}."
98-
)
99116
frames, first_pts = core.get_frames_by_pts_in_range_audio(
100117
self._decoder,
101118
start_seconds=start_seconds,

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,10 @@ void VideoDecoder::initializeDecoder() {
147147
streamMetadata.durationSeconds =
148148
av_q2d(avStream->time_base) * avStream->duration;
149149
}
150+
if (avStream->start_time != AV_NOPTS_VALUE) {
151+
streamMetadata.beginStreamFromHeader =
152+
av_q2d(avStream->time_base) * avStream->start_time;
153+
}
150154

151155
if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
152156
double fps = av_q2d(avStream->r_frame_rate);
@@ -157,7 +161,15 @@ void VideoDecoder::initializeDecoder() {
157161
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
158162
AVSampleFormat format =
159163
static_cast<AVSampleFormat>(avStream->codecpar->format);
160-
streamMetadata.sampleFormat = av_get_sample_fmt_name(format);
164+
165+
// If the AVSampleFormat is not recognized, we get back nullptr. We have
166+
// to make sure we don't initialize a std::string with nullptr. There's
167+
// nothing to do on the else branch because we're already using an
168+
// optional; it'll just remain empty.
169+
const char* rawSampleFormat = av_get_sample_fmt_name(format);
170+
if (rawSampleFormat != nullptr) {
171+
streamMetadata.sampleFormat = std::string(rawSampleFormat);
172+
}
161173
containerMetadata_.numAudioStreams++;
162174
}
163175

@@ -944,8 +956,9 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
944956
TORCH_CHECK(
945957
frames.size() > 0 && firstFramePtsSeconds.has_value(),
946958
"No audio frames were decoded. ",
947-
"This should probably not happen. ",
948-
"Please report an issue on the TorchCodec repo.");
959+
"This is probably because start_seconds is too high? ",
960+
"Current value is ",
961+
startSeconds);
949962

950963
return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
951964
}

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class VideoDecoder {
5959
std::optional<AVCodecID> codecId;
6060
std::optional<std::string> codecName;
6161
std::optional<double> durationSeconds;
62+
std::optional<double> beginStreamFromHeader;
6263
std::optional<int64_t> numFrames;
6364
std::optional<int64_t> numKeyFrames;
6465
std::optional<double> averageFps;
@@ -238,7 +239,6 @@ class VideoDecoder {
238239
double startSeconds,
239240
double stopSeconds);
240241

241-
// TODO-AUDIO: Should accept sampleRate
242242
AudioFramesOutput getFramesPlayedInRangeAudio(
243243
double startSeconds,
244244
std::optional<double> stopSecondsOptional = std::nullopt);

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,10 @@ std::string get_stream_json_metadata(
473473
if (streamMetadata.numFrames.has_value()) {
474474
map["numFrames"] = std::to_string(*streamMetadata.numFrames);
475475
}
476+
if (streamMetadata.beginStreamFromHeader.has_value()) {
477+
map["beginStreamFromHeader"] =
478+
std::to_string(*streamMetadata.beginStreamFromHeader);
479+
}
476480
if (streamMetadata.minPtsSecondsFromScan.has_value()) {
477481
map["minPtsSecondsFromScan"] =
478482
std::to_string(*streamMetadata.minPtsSecondsFromScan);

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -22,37 +22,64 @@
2222
SPACES = " "
2323

2424

25-
# TODO-AUDIO: docs below are mostly for video streams, we should edit them and /
26-
# or make sure they're OK for audio streams as well. Not sure how to best handle
27-
# docs for such class hierarchy.
28-
# TODO very related, none of these common fields in this base class show up in
29-
# the docs right now.
3025
@dataclass
3126
class StreamMetadata:
3227
duration_seconds_from_header: Optional[float]
3328
"""Duration of the stream, in seconds, obtained from the header (float or
3429
None). This could be inaccurate."""
30+
begin_stream_seconds_from_header: Optional[float]
31+
"""Beginning of the stream, in seconds, obtained from the header (float or
32+
None). Usually, this is equal to 0."""
3533
bit_rate: Optional[float]
3634
"""Bit rate of the stream, in seconds (float or None)."""
35+
codec: Optional[str]
36+
"""Codec (str or None)."""
37+
stream_index: int
38+
"""Index of the stream that this metadata refers to (int)."""
39+
40+
def __repr__(self):
41+
s = self.__class__.__name__ + ":\n"
42+
for field in dataclasses.fields(self):
43+
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
44+
return s
45+
46+
47+
@dataclass
48+
class VideoStreamMetadata(StreamMetadata):
49+
"""Metadata of a single video stream."""
50+
3751
begin_stream_seconds_from_content: Optional[float]
3852
"""Beginning of the stream, in seconds (float or None).
39-
Conceptually, this corresponds to the first frame's :term:`pts`. It is
40-
computed as min(frame.pts) across all frames in the stream. Usually, this is
41-
equal to 0."""
53+
Conceptually, this corresponds to the first frame's :term:`pts`. It is only
54+
computed when a :term:`scan` is done as min(frame.pts) across all frames in
55+
the stream. Usually, this is equal to 0."""
4256
end_stream_seconds_from_content: Optional[float]
4357
"""End of the stream, in seconds (float or None).
4458
Conceptually, this corresponds to last_frame.pts + last_frame.duration. It
45-
is computed as max(frame.pts + frame.duration) across all frames in the
46-
stream. Note that no frame is played at this time value, so calling
47-
:meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with
48-
this value would result in an error. Retrieving the last frame is best done
49-
by simply indexing the :class:`~torchcodec.decoders.VideoDecoder`
50-
object with ``[-1]``.
59+
is only computed when a :term:`scan` is done as max(frame.pts +
60+
frame.duration) across all frames in the stream. Note that no frame is
61+
played at this time value, so calling
62+
:meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with this
63+
value would result in an error. Retrieving the last frame is best done by
64+
simply indexing the :class:`~torchcodec.decoders.VideoDecoder` object with
65+
``[-1]``.
5166
"""
52-
codec: Optional[str]
53-
"""Codec (str or None)."""
54-
stream_index: int
55-
"""Index of the stream within the video (int)."""
67+
width: Optional[int]
68+
"""Width of the frames (int or None)."""
69+
height: Optional[int]
70+
"""Height of the frames (int or None)."""
71+
num_frames_from_header: Optional[int]
72+
"""Number of frames, from the stream's metadata. This is potentially
73+
inaccurate. We recommend using the ``num_frames`` attribute instead.
74+
(int or None)."""
75+
num_frames_from_content: Optional[int]
76+
"""Number of frames computed by TorchCodec by scanning the stream's
77+
content (the scan doesn't involve decoding). This is more accurate
78+
than ``num_frames_from_header``. We recommend using the
79+
``num_frames`` attribute instead. (int or None)."""
80+
average_fps_from_header: Optional[float]
81+
"""Averate fps of the stream, obtained from the header (float or None).
82+
We recommend using the ``average_fps`` attribute instead."""
5683

5784
@property
5885
def duration_seconds(self) -> Optional[float]:
@@ -94,36 +121,6 @@ def end_stream_seconds(self) -> Optional[float]:
94121
else:
95122
return self.end_stream_seconds_from_content
96123

97-
def __repr__(self):
98-
# Overridden because properites are not printed by default.
99-
s = self.__class__.__name__ + ":\n"
100-
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
101-
for field in dataclasses.fields(self):
102-
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
103-
return s
104-
105-
106-
@dataclass
107-
class VideoStreamMetadata(StreamMetadata):
108-
"""Metadata of a single video stream."""
109-
110-
width: Optional[int]
111-
"""Width of the frames (int or None)."""
112-
height: Optional[int]
113-
"""Height of the frames (int or None)."""
114-
num_frames_from_header: Optional[int]
115-
"""Number of frames, from the stream's metadata. This is potentially
116-
inaccurate. We recommend using the ``num_frames`` attribute instead.
117-
(int or None)."""
118-
num_frames_from_content: Optional[int]
119-
"""Number of frames computed by TorchCodec by scanning the stream's
120-
content (the scan doesn't involve decoding). This is more accurate
121-
than ``num_frames_from_header``. We recommend using the
122-
``num_frames`` attribute instead. (int or None)."""
123-
average_fps_from_header: Optional[float]
124-
"""Averate fps of the stream, obtained from the header (float or None).
125-
We recommend using the ``average_fps`` attribute instead."""
126-
127124
@property
128125
def num_frames(self) -> Optional[int]:
129126
"""Number of frames in the stream. This corresponds to
@@ -154,6 +151,9 @@ def average_fps(self) -> Optional[float]:
154151

155152
def __repr__(self):
156153
s = super().__repr__()
154+
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
155+
s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n"
156+
s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n"
157157
s += f"{SPACES}num_frames: {self.num_frames}\n"
158158
s += f"{SPACES}average_fps: {self.average_fps}\n"
159159
return s
@@ -224,14 +224,19 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
224224
common_meta = dict(
225225
duration_seconds_from_header=stream_dict.get("durationSeconds"),
226226
bit_rate=stream_dict.get("bitRate"),
227-
begin_stream_seconds_from_content=stream_dict.get("minPtsSecondsFromScan"),
228-
end_stream_seconds_from_content=stream_dict.get("maxPtsSecondsFromScan"),
227+
begin_stream_seconds_from_header=stream_dict.get("beginStreamFromHeader"),
229228
codec=stream_dict.get("codec"),
230229
stream_index=stream_index,
231230
)
232231
if stream_dict["mediaType"] == "video":
233232
streams_metadata.append(
234233
VideoStreamMetadata(
234+
begin_stream_seconds_from_content=stream_dict.get(
235+
"minPtsSecondsFromScan"
236+
),
237+
end_stream_seconds_from_content=stream_dict.get(
238+
"maxPtsSecondsFromScan"
239+
),
235240
width=stream_dict.get("width"),
236241
height=stream_dict.get("height"),
237242
num_frames_from_header=stream_dict.get("numFrames"),

0 commit comments

Comments
 (0)