Skip to content

Commit 02f1fe4

Browse files
committed
Add get_all_samples() method
1 parent abc9b10 commit 02f1fe4

File tree

4 files changed

+23
-8
lines changed

4 files changed

+23
-8
lines changed

examples/audio_decoding.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ def play_audio(samples):
5959
# ----------------
6060
#
6161
# To get decoded samples, we just need to call the
62-
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` method,
62+
# :meth:`~torchcodec.decoders.AudioDecoder.get_all_samples` method,
6363
# which returns an :class:`~torchcodec.AudioSamples` object:
6464

65-
samples = decoder.get_samples_played_in_range()
65+
samples = decoder.get_all_samples()
6666

6767
print(samples)
6868
play_audio(samples)
@@ -80,9 +80,9 @@ def play_audio(samples):
8080
# Specifying a range
8181
# ------------------
8282
#
83-
# By default,
84-
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` decodes
85-
# the entire audio stream, but we can specify a custom range:
83+
# If we don't need all the samples, we can use
84+
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` to
85+
# decode the samples within a custom range:
8686

8787
samples = decoder.get_samples_played_in_range(start_seconds=10, stop_seconds=70)
8888

@@ -99,7 +99,7 @@ def play_audio(samples):
9999
# increased:
100100

101101
decoder = AudioDecoder(raw_audio_bytes, sample_rate=16_000)
102-
samples = decoder.get_samples_played_in_range(start_seconds=0)
102+
samples = decoder.get_all_samples()
103103

104104
print(samples)
105105
play_audio(samples)

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@ def __init__(
7575
sample_rate if sample_rate is not None else self.metadata.sample_rate
7676
)
7777

78+
def get_all_samples(self) -> AudioSamples:
79+
"""Returns all the audio samples from the source.
80+
81+
Returns:
82+
AudioSamples: The samples within the file.
83+
"""
84+
return self.get_samples_played_in_range()
85+
7886
def get_samples_played_in_range(
7987
self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None
8088
) -> AudioSamples:

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ class VideoDecoder {
239239
double startSeconds,
240240
double stopSeconds);
241241

242-
// TODO-AUDIO: Should accept sampleRate
243242
AudioFramesOutput getFramesPlayedInRangeAudio(
244243
double startSeconds,
245244
std::optional<double> stopSecondsOptional = std::nullopt);

test/decoders/test_decoders.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,7 @@ def test_negative_start(self, asset):
982982

983983
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
984984
@pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999))
985-
def test_get_all_samples(self, asset, stop_seconds):
985+
def test_get_all_samples_with_range(self, asset, stop_seconds):
986986
decoder = AudioDecoder(asset.path)
987987

988988
if stop_seconds == "duration":
@@ -998,6 +998,14 @@ def test_get_all_samples(self, asset, stop_seconds):
998998
assert samples.sample_rate == asset.sample_rate
999999
assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds
10001000

1001+
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
1002+
def test_get_all_samples(self, asset):
1003+
decoder = AudioDecoder(asset.path)
1004+
torch.testing.assert_close(
1005+
decoder.get_all_samples().data,
1006+
decoder.get_samples_played_in_range().data,
1007+
)
1008+
10011009
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
10021010
def test_at_frame_boundaries(self, asset):
10031011
decoder = AudioDecoder(asset.path)

0 commit comments

Comments
 (0)