Skip to content

Commit 800b9dc

Browse files
authored
Add load_with_torchcodec, modify load()'s warnings (#3974)
1 parent 6c57850 commit 800b9dc

File tree

6 files changed

+385
-6
lines changed

6 files changed

+385
-6
lines changed

.github/scripts/unittest-linux/install.sh

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ case $GPU_ARCH_TYPE in
7474
;;
7575
esac
7676
PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${GPU_ARCH_ID}"
77-
pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}"
77+
pip install --progress-bar=off --pre torch torchcodec --index-url="${PYTORCH_WHEEL_INDEX}"
7878

7979

8080
# 2. Install torchaudio
@@ -86,6 +86,10 @@ python setup.py install
8686

8787
# 3. Install Test tools
8888
printf "* Installing test tools\n"
89+
# On this CI, for whatever reason, we're only able to install ffmpeg 4.
90+
conda install -y "ffmpeg<5"
91+
python -c "import torch; import torchaudio; import torchcodec; print(torch.__version__, torchaudio.__version__, torchcodec.__version__)"
92+
8993
NUMBA_DEV_CHANNEL=""
9094
if [[ "$(python --version)" = *3.9* || "$(python --version)" = *3.10* ]]; then
9195
# Numba isn't available for Python 3.9 and 3.10 except on the numba dev channel and building from source fails
@@ -94,7 +98,7 @@ if [[ "$(python --version)" = *3.9* || "$(python --version)" = *3.10* ]]; then
9498
fi
9599
(
96100
set -x
97-
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} sox libvorbis parameterized 'requests>=2.20' 'ffmpeg>=6,<7'
101+
conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} sox libvorbis parameterized 'requests>=2.20'
98102
pip install kaldi-io SoundFile librosa coverage pytest pytest-cov scipy expecttest unidecode inflect Pillow sentencepiece pytorch-lightning 'protobuf<4.21.0' demucs tinytag pyroomacoustics flashlight-text git+https://github.com/kpu/kenlm
99103

100104
# TODO: might be better to fix the single call to `pip install` above

docs/source/torchaudio.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ torchaudio
77
Starting with version 2.8, we are refactoring TorchAudio to transition it
88
into a maintenance phase. As a result:
99

10-
- The APIs listed below are deprecated in 2.8 and will be removed in 2.9.
10+
- Most APIs listed below are deprecated in 2.8 and will be removed in 2.9.
1111
- The decoding and encoding capabilities of PyTorch for both audio and video
12-
are being consolidated into TorchCodec.
12+
are being consolidated into TorchCodec. We provide
13+
``torchaudio.load_with_torchcodec()`` as a replacement for
14+
``torchaudio.load()``.
1315

1416
Please see https://github.com/pytorch/audio/issues/3902 for more information.
1517

@@ -26,6 +28,7 @@ it easy to handle audio data.
2628

2729
info
2830
load
31+
load_with_torchcodec
2932
save
3033
list_audio_backends
3134

src/torchaudio/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77
get_audio_backend as _get_audio_backend,
88
info as _info,
99
list_audio_backends as _list_audio_backends,
10-
load as _load,
10+
load,
1111
save as _save,
1212
set_audio_backend as _set_audio_backend,
1313
)
14+
from ._torchcodec import load_with_torchcodec
1415

1516
AudioMetaData = dropping_class_io_support(_AudioMetaData)
1617
get_audio_backend = dropping_io_support(_get_audio_backend)
1718
info = dropping_io_support(_info)
1819
list_audio_backends = dropping_io_support(_list_audio_backends)
19-
load = dropping_io_support(_load)
2020
save = dropping_io_support(_save)
2121
set_audio_backend = dropping_io_support(_set_audio_backend)
2222

@@ -45,6 +45,7 @@
4545
__all__ = [
4646
"AudioMetaData",
4747
"load",
48+
"load_with_torchcodec",
4849
"info",
4950
"save",
5051
"io",

src/torchaudio/_backend/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from functools import lru_cache
33
from typing import BinaryIO, Dict, Optional, Tuple, Type, Union
4+
import warnings
45

56
import torch
67

@@ -127,6 +128,14 @@ def load(
127128
) -> Tuple[torch.Tensor, int]:
128129
"""Load audio data from source.
129130
131+
.. warning::
132+
In 2.9, this function's implementation will be changed to use
133+
:func:`~torchaudio.load_with_torchcodec` under the hood. Some
134+
parameters like ``normalize``, ``format``, ``buffer_size``, and
135+
``backend`` will be ignored. We recommend that you port your code to
136+
rely directly on TorchCodec's decoder instead:
137+
https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.
138+
130139
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
131140
``float32`` dtype, and the shape of `[channel, time]`.
132141
@@ -201,6 +210,14 @@ def load(
201210
integer type, else ``float32`` type. If ``channels_first=True``, it has
202211
`[channel, time]` else `[time, channel]`.
203212
"""
213+
warnings.warn(
214+
"In 2.9, this function's implementation will be changed to use "
215+
"torchaudio.load_with_torchcodec` under the hood. Some "
216+
"parameters like ``normalize``, ``format``, ``buffer_size``, and "
217+
"``backend`` will be ignored. We recommend that you port your code to "
218+
"rely directly on TorchCodec's decoder instead: "
219+
"https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder."
220+
)
204221
backend = dispatcher(uri, format, backend)
205222
return backend.load(uri, frame_offset, num_frames, normalize, channels_first, format, buffer_size)
206223

src/torchaudio/_torchcodec.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""TorchCodec integration for TorchAudio."""
2+
3+
import os
4+
from typing import BinaryIO, Optional, Tuple, Union
5+
6+
import torch
7+
8+
9+
def load_with_torchcodec(
    uri: Union[BinaryIO, str, os.PathLike],
    frame_offset: int = 0,
    num_frames: int = -1,
    normalize: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
    buffer_size: int = 4096,
    backend: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Load audio data from source using TorchCodec's AudioDecoder.

    .. note::

        This function supports the same API as ``torchaudio.load()``, and relies
        on TorchCodec's decoding capabilities under the hood. It is provided for
        convenience, but we do recommend that you port your code to natively use
        ``torchcodec``'s ``AudioDecoder`` class for better performance:
        https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html
        In TorchAudio 2.9, ``torchaudio.load()`` will be relying on
        ``load_with_torchcodec``. Note that some parameters of
        ``torchaudio.load()``, like ``normalize``, ``format``, ``buffer_size``,
        and ``backend``, are ignored by ``load_with_torchcodec``.

    Args:
        uri (path-like object or file-like object):
            Source of audio data. The following types are accepted:

            * ``path-like``: File path or URL.
            * ``file-like``: Object with ``read(size: int) -> bytes`` method.

        frame_offset (int, optional):
            Number of samples to skip before start reading data.
            Values ``<= 0`` skip nothing.
        num_frames (int, optional):
            Maximum number of samples to read. ``-1`` (or any negative value)
            reads all the remaining samples, starting from ``frame_offset``.
            ``0`` yields an empty Tensor.
        normalize (bool, optional):
            TorchCodec always returns normalized float32 samples. This parameter
            is ignored and a warning is issued if set to False.
            Default: ``True``.
        channels_first (bool, optional):
            When True, the returned Tensor has dimension `[channel, time]`.
            Otherwise, the returned Tensor's dimension is `[time, channel]`.
        format (str or None, optional):
            Not used by TorchCodec AudioDecoder. Provided for API compatibility;
            a warning is issued if it is not ``None``. (Default: ``None``)
        buffer_size (int, optional):
            Not used by TorchCodec AudioDecoder. Provided for API compatibility;
            a warning is issued if it differs from the default ``4096``.
        backend (str or None, optional):
            Not used by TorchCodec AudioDecoder. Provided for API compatibility;
            a warning is issued if it is not ``None``.

    Returns:
        (torch.Tensor, int): Resulting Tensor and sample rate.
        Always returns float32 tensors. If ``channels_first=True``, shape is
        `[channel, time]`, otherwise `[time, channel]`. When ``frame_offset``
        is past the end of the audio, or ``num_frames`` is ``0``, the time
        dimension is empty.

    Raises:
        ImportError: If torchcodec is not available.
        RuntimeError: If TorchCodec fails to open or decode the audio, or if
            the sample rate cannot be determined from the metadata.

    Note:
        - TorchCodec always returns normalized float32 samples, so the ``normalize``
          parameter has no effect.
        - The ``format``, ``buffer_size`` and ``backend`` parameters are ignored.
        - The whole stream is decoded before ``frame_offset`` / ``num_frames``
          are applied, so loading a short excerpt costs as much as loading
          the full file.
        - Not all audio formats supported by torchaudio backends may be supported
          by TorchCodec.
    """
    import warnings

    # Import torchcodec here to provide a clear error if it is not available.
    try:
        from torchcodec.decoders import AudioDecoder
    except ImportError as e:
        raise ImportError(
            "TorchCodec is required for load_with_torchcodec. "
            "Please install torchcodec to use this function."
        ) from e

    # Parameters kept only for signature compatibility with torchaudio.load():
    # warn (in this fixed order) whenever one is set to a non-default value.
    ignored_parameters = (
        (
            not normalize,
            "TorchCodec AudioDecoder always returns normalized float32 samples. "
            "The 'normalize=False' parameter is ignored.",
        ),
        (
            buffer_size != 4096,
            "The 'buffer_size' parameter is not used by TorchCodec AudioDecoder.",
        ),
        (
            backend is not None,
            "The 'backend' parameter is not used by TorchCodec AudioDecoder.",
        ),
        (
            format is not None,
            "The 'format' parameter is not supported by TorchCodec AudioDecoder.",
        ),
    )
    for is_set, message in ignored_parameters:
        if is_set:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(message, UserWarning, stacklevel=2)

    # Create AudioDecoder
    try:
        decoder = AudioDecoder(uri)
    except Exception as e:
        raise RuntimeError(f"Failed to create AudioDecoder for {uri}: {e}") from e

    # Get sample rate from metadata
    sample_rate = decoder.metadata.sample_rate
    if sample_rate is None:
        raise RuntimeError("Unable to determine sample rate from audio metadata")

    # Decode the entire file first, then subsample manually.
    # This is the simplest approach since torchcodec uses time-based indexing.
    try:
        audio_samples = decoder.get_all_samples()
    except Exception as e:
        raise RuntimeError(f"Failed to decode audio samples: {e}") from e

    # TorchCodec returns data in [channel, time] format by default.
    data = audio_samples.data

    # Apply frame_offset and num_frames (which are actually sample offsets).
    # Slicing past the end, or with num_frames == 0, naturally produces an
    # empty time dimension, so no dedicated empty-tensor branch is required.
    if frame_offset > 0:
        data = data[:, frame_offset:]
    if num_frames >= 0:
        data = data[:, :num_frames]

    if not channels_first:
        data = data.transpose(0, 1)  # [channel, time] -> [time, channel]

    return data, sample_rate

0 commit comments

Comments
 (0)