Skip to content

Commit ee05058

Browse files
committed
Add tuto
1 parent d7ec60e commit ee05058

File tree

4 files changed

+85
-7
lines changed

4 files changed

+85
-7
lines changed

examples/decoding/README.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
Decoding
22
--------
3-

examples/encoding/audio_encoding.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,78 @@
1414
"""
1515

1616
# %%
17-
print("hello")
17+
# Let's first generate some samples to be encoded. The data to be encoded could
18+
# also just come from an :class:`~torchcodec.decoders.AudioDecoder`!
19+
import torch
20+
from IPython.display import Audio as play_audio
21+
22+
23+
def make_sinewave() -> tuple[torch.Tensor, int]:
24+
freq_A = 440 # Hz
25+
sample_rate = 16000 # Hz
26+
duration_seconds = 3 # seconds
27+
t = torch.linspace(0, duration_seconds, int(sample_rate * duration_seconds), dtype=torch.float32)
28+
return torch.sin(2 * torch.pi * freq_A * t), sample_rate
29+
30+
31+
samples, sample_rate = make_sinewave()
32+
33+
print(f"Encoding samples with {samples.shape = } and {sample_rate = }")
34+
play_audio(samples, rate=sample_rate)
35+
36+
# %%
37+
# We first instantiate an :class:`~torchcodec.encoders.AudioEncoder`. We pass it
38+
# the samples to be encoded. The samples must a 2D tensors of shape
39+
# ``(num_channels, num_samples)``, or in this case, a 1D tensor where
40+
# ``num_channels`` is assumed to be 1. The values must be float values
41+
# normalized in ``[-1, 1]``: this is also what the
42+
# :class:`~torchcodec.decoders.AudioDecoder` would return.
43+
#
44+
# .. note::
45+
#
46+
# The ``sample_rate`` parameter corresponds to the sample rate of the
47+
# *input*, not the desired encoded sample rate.
48+
from torchcodec.encoders import AudioEncoder
49+
50+
encoder = AudioEncoder(samples=samples, sample_rate=sample_rate)
51+
52+
53+
# %%
54+
# :class:`~torchcodec.encoders.AudioEncoder` supports encoding samples into a
55+
# file via the :meth:`~torchcodec.encoders.AudioEncoder.to_file` method, or to
56+
# raw bytes via :meth:`~torchcodec.encoders.AudioEncoder.to_tensor`. For the
57+
# purpose of this tutorial we'll use
58+
# :meth:`~torchcodec.encoders.AudioEncoder.to_tensor`, so that we can easily
59+
# re-decode the encoded samples and check their properies. The
60+
# :meth:`~torchcodec.encoders.AudioEncoder.to_file` method works very similarly.
61+
62+
encoded_samples = encoder.to_tensor(format="mp3")
63+
print(f"{encoded_samples.shape = }, {encoded_samples.dtype = }")
64+
65+
66+
# %%
67+
# That's it!
68+
#
69+
# Now that we have our encoded data, we can decode it back, to make sure it
70+
# looks and sounds as expected:
71+
from torchcodec.decoders import AudioDecoder
72+
73+
samples_back = AudioDecoder(encoded_samples).get_all_samples()
74+
75+
print(samples_back)
76+
play_audio(samples_back.data, rate=samples_back.sample_rate)
77+
78+
# %%
79+
# The encoder supports some encoding options that allow you to change how to
80+
# data is encoded. For example, we can decide to encode our mono data (1
81+
# channel) into stereo data (2 channels):
82+
encoded_samples = encoder.to_tensor(format="wav", num_channels=2)
83+
84+
stereo_samples_back = AudioDecoder(encoded_samples).get_all_samples()
85+
86+
print(stereo_samples_back)
87+
play_audio(stereo_samples_back.data, rate=stereo_samples_back.sample_rate)
88+
89+
# %%
90+
# Check the docstring of the encoding methods to learn about the different
91+
# encoding options.

src/torchcodec/encoders/_audio_encoder.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ class AudioEncoder:
1212
1313
Args:
1414
samples (``torch.Tensor``): The samples to encode. This must be a 2D
15-
tensor of shape ``(num_channels, num_samples)``
16-
sample_rate (int): The sample rate of the **input** ``samples``.
15+
tensor of shape ``(num_channels, num_samples)``, or a 1D tensor in
16+
which case ``num_channels = 1`` is assumed. Values must be float
17+
values in ``[-1, 1]``.
18+
sample_rate (int): The sample rate of the **input** ``samples``.
1719
"""
1820

1921
def __init__(self, samples: Tensor, *, sample_rate: int):
@@ -24,8 +26,11 @@ def __init__(self, samples: Tensor, *, sample_rate: int):
2426
raise ValueError(
2527
f"Expected samples to be a Tensor, got {type(samples) = }."
2628
)
29+
if samples.ndim == 1:
30+
# make it 2D and assume 1 channel
31+
samples = samples[None, :]
2732
if samples.ndim != 2:
28-
raise ValueError(f"Expected 2D samples, got {samples.shape = }.")
33+
raise ValueError(f"Expected 1D or 2D samples, got {samples.shape = }.")
2934
if samples.dtype != torch.float32:
3035
raise ValueError(f"Expected float32 samples, got {samples.dtype = }.")
3136
if sample_rate <= 0:

test/test_encoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def decode(self, source) -> torch.Tensor:
2626
def test_bad_input(self):
2727
with pytest.raises(ValueError, match="Expected samples to be a Tensor"):
2828
AudioEncoder(samples=123, sample_rate=32_000)
29-
with pytest.raises(ValueError, match="Expected 2D samples"):
30-
AudioEncoder(samples=torch.rand(10), sample_rate=32_000)
29+
with pytest.raises(ValueError, match="Expected 1D or 2D samples"):
30+
AudioEncoder(samples=torch.rand(3, 4, 5), sample_rate=32_000)
3131
with pytest.raises(ValueError, match="Expected float32 samples"):
3232
AudioEncoder(
3333
samples=torch.rand(10, 10, dtype=torch.float64), sample_rate=32_000

0 commit comments

Comments
 (0)