Skip to content

Commit 90e1262

Browse files
authored
Add get_all_samples() method to AudioDecoder (#594)
1 parent 09ee988 commit 90e1262

File tree

5 files changed

+36
-11
lines changed

5 files changed

+36
-11
lines changed

benchmarks/decoders/benchmark_audio_decoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def get_duration(path: Path) -> str:
7171

7272

7373
def decode_with_torchcodec(path: Path) -> None:
74-
AudioDecoder(path).get_samples_played_in_range(start_seconds=0, stop_seconds=None)
74+
AudioDecoder(path).get_all_samples()
7575

7676

7777
def decode_with_torchaudio_StreamReader(path: Path) -> None:

examples/audio_decoding.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ def play_audio(samples):
5959
# ----------------
6060
#
6161
# To get decoded samples, we just need to call the
62-
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` method,
62+
# :meth:`~torchcodec.decoders.AudioDecoder.get_all_samples` method,
6363
# which returns an :class:`~torchcodec.AudioSamples` object:
6464

65-
samples = decoder.get_samples_played_in_range()
65+
samples = decoder.get_all_samples()
6666

6767
print(samples)
6868
play_audio(samples)
@@ -79,9 +79,9 @@ def play_audio(samples):
7979
# Specifying a range
8080
# ------------------
8181
#
82-
# By default,
83-
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` decodes
84-
# the entire audio stream, but we can specify a custom range:
82+
# If we don't need all the samples, we can use
83+
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` to
84+
# decode the samples within a custom range:
8585

8686
samples = decoder.get_samples_played_in_range(start_seconds=10, stop_seconds=70)
8787

@@ -98,7 +98,7 @@ def play_audio(samples):
9898
# increased:
9999

100100
decoder = AudioDecoder(raw_audio_bytes, sample_rate=16_000)
101-
samples = decoder.get_samples_played_in_range(start_seconds=0)
101+
samples = decoder.get_all_samples()
102102

103103
print(samples)
104104
play_audio(samples)

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,36 @@ def __init__(
7575
sample_rate if sample_rate is not None else self.metadata.sample_rate
7676
)
7777

78+
def get_all_samples(self) -> AudioSamples:
79+
"""Returns all the audio samples from the source.
80+
81+
To decode samples in a specific range, use
82+
:meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range`.
83+
84+
Returns:
85+
AudioSamples: All the samples from the source.
86+
"""
87+
return self.get_samples_played_in_range()
88+
7889
def get_samples_played_in_range(
7990
self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None
8091
) -> AudioSamples:
8192
"""Returns audio samples in the given range.
8293
8394
Samples are in the half open range [start_seconds, stop_seconds).
8495
96+
To decode all the samples from beginning to end, you can call this
97+
method while leaving ``start_seconds`` and ``stop_seconds`` at their
98+
default values, or use
99+
:meth:`~torchcodec.decoders.AudioDecoder.get_all_samples` as a more
100+
convenient alias.
101+
85102
Args:
86103
start_seconds (float): Time, in seconds, of the start of the
87104
range. Default: 0.
88-
stop_seconds (float): Time, in seconds, of the end of the
89-
range. As a half open range, the end is excluded.
105+
stop_seconds (float or None): Time, in seconds, of the end of the
106+
range. As a half open range, the end is excluded. Default: None,
107+
which decodes samples until the end.
90108
91109
Returns:
92110
AudioSamples: The samples within the specified range.

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ class VideoDecoder {
239239
double startSeconds,
240240
double stopSeconds);
241241

242-
// TODO-AUDIO: Should accept sampleRate
243242
AudioFramesOutput getFramesPlayedInRangeAudio(
244243
double startSeconds,
245244
std::optional<double> stopSecondsOptional = std::nullopt);

test/decoders/test_decoders.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,7 @@ def test_negative_start(self, asset):
982982

983983
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
984984
@pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999))
985-
def test_get_all_samples(self, asset, stop_seconds):
985+
def test_get_all_samples_with_range(self, asset, stop_seconds):
986986
decoder = AudioDecoder(asset.path)
987987

988988
if stop_seconds == "duration":
@@ -998,6 +998,14 @@ def test_get_all_samples(self, asset, stop_seconds):
998998
assert samples.sample_rate == asset.sample_rate
999999
assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds
10001000

1001+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1002+
def test_get_all_samples(self, asset):
1003+
decoder = AudioDecoder(asset.path)
1004+
torch.testing.assert_close(
1005+
decoder.get_all_samples().data,
1006+
decoder.get_samples_played_in_range().data,
1007+
)
1008+
10011009
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
10021010
def test_at_frame_boundaries(self, asset):
10031011
decoder = AudioDecoder(asset.path)

0 commit comments

Comments
 (0)