Skip to content

Commit 00b0c91

Browse files
authored
Add save_with_torchcodec, modify save()'s warnings (#3975)
1 parent 800b9dc commit 00b0c91

File tree

6 files changed

+674
-208
lines changed

6 files changed

+674
-208
lines changed

docs/source/torchaudio.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ torchaudio
99

1010
- Most APIs listed below are deprecated in 2.8 and will be removed in 2.9.
1111
- The decoding and encoding capabilities of PyTorch for both audio and video
12-
are being consolidated into TorchCodec. We provide
13-
``torchaudio.load_with_torchcodec()`` as a replacement for
14-
``torchaudio.load()``.
12+
are being consolidated into TorchCodec. For convenience, we provide
13+
:func:`~torchaudio.load_with_torchcodec` as a replacement for
14+
:func:`~torchaudio.load` and :func:`~torchaudio.save_with_torchcodec` as a
15+
replacement for :func:`~torchaudio.save`, but we recommend that you port
16+
your code to native torchcodec APIs.
1517

1618
Please see https://github.com/pytorch/audio/issues/3902 for more information.
1719

@@ -30,6 +32,7 @@ it easy to handle audio data.
3032
load
3133
load_with_torchcodec
3234
save
35+
save_with_torchcodec
3336
list_audio_backends
3437

3538
.. _backend:

src/torchaudio/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,15 @@
88
info as _info,
99
list_audio_backends as _list_audio_backends,
1010
load,
11-
save as _save,
11+
save,
1212
set_audio_backend as _set_audio_backend,
1313
)
14-
from ._torchcodec import load_with_torchcodec
14+
from ._torchcodec import load_with_torchcodec, save_with_torchcodec
1515

1616
AudioMetaData = dropping_class_io_support(_AudioMetaData)
1717
get_audio_backend = dropping_io_support(_get_audio_backend)
1818
info = dropping_io_support(_info)
1919
list_audio_backends = dropping_io_support(_list_audio_backends)
20-
save = dropping_io_support(_save)
2120
set_audio_backend = dropping_io_support(_set_audio_backend)
2221

2322
from . import ( # noqa: F401
@@ -46,6 +45,7 @@
4645
"AudioMetaData",
4746
"load",
4847
"load_with_torchcodec",
48+
"save_with_torchcodec",
4949
"info",
5050
"save",
5151
"io",

src/torchaudio/_backend/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,14 @@ def save(
252252
):
253253
"""Save audio data to file.
254254
255+
.. warning::
256+
In 2.9, this function's implementation will be changed to use
257+
:func:`~torchaudio.save_with_torchcodec` under the hood. Some
258+
parameters like ``format``, ``encoding``, ``bits_per_sample``, ``buffer_size``, and
259+
``backend`` will be ignored. We recommend that you port your code to
260+
rely directly on TorchCodec's encoder instead:
261+
https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder
262+
255263
Note:
256264
The formats this function can handle depend on the availability of backends.
257265
Please use the following functions to fetch the supported formats.
@@ -326,6 +334,14 @@ def save(
326334
Refer to http://sox.sourceforge.net/soxformat.html for more details.
327335
328336
"""
337+
warnings.warn(
338+
"In 2.9, this function's implementation will be changed to use "
339+
"torchaudio.save_with_torchcodec` under the hood. Some "
340+
"parameters like format, encoding, bits_per_sample, buffer_size, and "
341+
"``backend`` will be ignored. We recommend that you port your code to "
342+
"rely directly on TorchCodec's encoder instead: "
343+
"https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder"
344+
)
329345
backend = dispatcher(uri, format, backend)
330346
return backend.save(
331347
uri, src, sample_rate, channels_first, format, encoding, bits_per_sample, buffer_size, compression

src/torchaudio/_torchcodec.py

Lines changed: 200 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,16 @@ def load_with_torchcodec(
2020
2121
.. note::
2222
23-
This function supports the same API as ``torchaudio.load()``, and relies
24-
on TorchCodec's decoding capabilities under the hood. It is provided for
25-
convenience, but we do recommend that you port your code to natively use
26-
``torchcodec``'s ``AudioDecoder`` class for better performance:
23+
This function supports the same API as :func:`~torchaudio.load`, and
24+
relies on TorchCodec's decoding capabilities under the hood. It is
25+
provided for convenience, but we do recommend that you port your code to
26+
natively use ``torchcodec``'s ``AudioDecoder`` class for better
27+
performance:
2728
https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.
28-
In TorchAudio 2.9, ``torchaudio.load()`` will be relying on
29-
``load_with_torchcodec``. Note that some parameters of
30-
``torchaudio.load()``, like ``normalize``, ``buffer_size``, and
31-
``backend``, are ignored by ``load_with_torchcodec``.
29+
In TorchAudio 2.9, :func:`~torchaudio.load` will be relying on
30+
:func:`~torchaudio.load_with_torchcodec`. Note that some parameters of
31+
:func:`~torchaudio.load`, like ``normalize``, ``buffer_size``, and
32+
``backend``, are ignored by :func:`~torchaudio.load_with_torchcodec`.
3233
3334
3435
Args:
@@ -158,4 +159,194 @@ def load_with_torchcodec(
158159
if not channels_first:
159160
data = data.transpose(0, 1) # [channel, time] -> [time, channel]
160161

161-
return data, sample_rate
162+
return data, sample_rate
163+
164+
165+
def save_with_torchcodec(
166+
uri: Union[str, os.PathLike],
167+
src: torch.Tensor,
168+
sample_rate: int,
169+
channels_first: bool = True,
170+
format: Optional[str] = None,
171+
encoding: Optional[str] = None,
172+
bits_per_sample: Optional[int] = None,
173+
buffer_size: int = 4096,
174+
backend: Optional[str] = None,
175+
compression: Optional[Union[float, int]] = None,
176+
) -> None:
177+
"""Save audio data to file using TorchCodec's AudioEncoder.
178+
179+
.. note::
180+
181+
This function supports the same API as :func:`~torchaudio.save`, and
182+
relies on TorchCodec's encoding capabilities under the hood. It is
183+
provided for convenience, but we do recommend that you port your code to
184+
natively use ``torchcodec``'s ``AudioEncoder`` class for better
185+
performance:
186+
https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder.
187+
In TorchAudio 2.9, :func:`~torchaudio.save` will be relying on
188+
:func:`~torchaudio.save_with_torchcodec`. Note that some parameters of
189+
:func:`~torchaudio.save`, like ``format``, ``encoding``,
190+
``bits_per_sample``, ``buffer_size``, and ``backend``, are ignored by
191+
are ignored by :func:`~torchaudio.save_with_torchcodec`.
192+
193+
This function provides a TorchCodec-based alternative to torchaudio.save
194+
with the same API. TorchCodec's AudioEncoder provides efficient encoding
195+
with FFmpeg under the hood.
196+
197+
Args:
198+
uri (path-like object):
199+
Path to save the audio file. The file extension determines the format.
200+
201+
src (torch.Tensor):
202+
Audio data to save. Must be a 1D or 2D tensor with float32 values
203+
in the range [-1, 1]. If 2D, shape should be [channel, time] when
204+
channels_first=True, or [time, channel] when channels_first=False.
205+
206+
sample_rate (int):
207+
Sample rate of the audio data.
208+
209+
channels_first (bool, optional):
210+
Indicates whether the input tensor has channels as the first dimension.
211+
If True, expects [channel, time]. If False, expects [time, channel].
212+
Default: True.
213+
214+
format (str or None, optional):
215+
Audio format hint. Not used by TorchCodec (format is determined by
216+
file extension). A warning is issued if provided.
217+
Default: None.
218+
219+
encoding (str or None, optional):
220+
Audio encoding. Not fully supported by TorchCodec AudioEncoder.
221+
A warning is issued if provided. Default: None.
222+
223+
bits_per_sample (int or None, optional):
224+
Bits per sample. Not directly supported by TorchCodec AudioEncoder.
225+
A warning is issued if provided. Default: None.
226+
227+
buffer_size (int, optional):
228+
Not used by TorchCodec AudioEncoder. Provided for API compatibility.
229+
A warning is issued if not default value. Default: 4096.
230+
231+
backend (str or None, optional):
232+
Not used by TorchCodec AudioEncoder. Provided for API compatibility.
233+
A warning is issued if provided. Default: None.
234+
235+
compression (float, int or None, optional):
236+
Compression level or bit rate. Maps to bit_rate parameter in
237+
TorchCodec AudioEncoder. Default: None.
238+
239+
Raises:
240+
ImportError: If torchcodec is not available.
241+
ValueError: If input parameters are invalid.
242+
RuntimeError: If TorchCodec fails to encode the audio.
243+
244+
Note:
245+
- TorchCodec AudioEncoder expects float32 samples in [-1, 1] range.
246+
- Some parameters (format, encoding, bits_per_sample, buffer_size, backend)
247+
are not used by TorchCodec but are provided for API compatibility.
248+
- The output format is determined by the file extension in the uri.
249+
- TorchCodec uses FFmpeg under the hood for encoding.
250+
"""
251+
# Import torchcodec here to provide clear error if not available
252+
try:
253+
from torchcodec.encoders import AudioEncoder
254+
except ImportError as e:
255+
raise ImportError(
256+
"TorchCodec is required for save_with_torchcodec. "
257+
"Please install torchcodec to use this function."
258+
) from e
259+
260+
# Parameter validation and warnings
261+
if format is not None:
262+
import warnings
263+
warnings.warn(
264+
"The 'format' parameter is not used by TorchCodec AudioEncoder. "
265+
"Format is determined by the file extension.",
266+
UserWarning,
267+
stacklevel=2
268+
)
269+
270+
if encoding is not None:
271+
import warnings
272+
warnings.warn(
273+
"The 'encoding' parameter is not fully supported by TorchCodec AudioEncoder.",
274+
UserWarning,
275+
stacklevel=2
276+
)
277+
278+
if bits_per_sample is not None:
279+
import warnings
280+
warnings.warn(
281+
"The 'bits_per_sample' parameter is not directly supported by TorchCodec AudioEncoder.",
282+
UserWarning,
283+
stacklevel=2
284+
)
285+
286+
if buffer_size != 4096:
287+
import warnings
288+
warnings.warn(
289+
"The 'buffer_size' parameter is not used by TorchCodec AudioEncoder.",
290+
UserWarning,
291+
stacklevel=2
292+
)
293+
294+
if backend is not None:
295+
import warnings
296+
warnings.warn(
297+
"The 'backend' parameter is not used by TorchCodec AudioEncoder.",
298+
UserWarning,
299+
stacklevel=2
300+
)
301+
302+
# Input validation
303+
if not isinstance(src, torch.Tensor):
304+
raise ValueError(f"Expected src to be a torch.Tensor, got {type(src)}")
305+
306+
if src.dtype != torch.float32:
307+
src = src.float()
308+
309+
if sample_rate <= 0:
310+
raise ValueError(f"sample_rate must be positive, got {sample_rate}")
311+
312+
# Handle tensor shape and channels_first
313+
if src.ndim == 1:
314+
# Convert to 2D: [1, time] for channels_first=True
315+
if channels_first:
316+
data = src.unsqueeze(0) # [1, time]
317+
else:
318+
# For channels_first=False, input is [time] -> reshape to [time, 1] -> transpose to [1, time]
319+
data = src.unsqueeze(1).transpose(0, 1) # [time, 1] -> [1, time]
320+
elif src.ndim == 2:
321+
if channels_first:
322+
data = src # Already [channel, time]
323+
else:
324+
data = src.transpose(0, 1) # [time, channel] -> [channel, time]
325+
else:
326+
raise ValueError(f"Expected 1D or 2D tensor, got {src.ndim}D tensor")
327+
328+
# Create AudioEncoder
329+
try:
330+
encoder = AudioEncoder(data, sample_rate=sample_rate)
331+
except Exception as e:
332+
raise RuntimeError(f"Failed to create AudioEncoder: {e}") from e
333+
334+
# Determine bit_rate from compression parameter
335+
bit_rate = None
336+
if compression is not None:
337+
if isinstance(compression, (int, float)):
338+
bit_rate = int(compression)
339+
else:
340+
import warnings
341+
warnings.warn(
342+
f"Unsupported compression type {type(compression)}. "
343+
"TorchCodec AudioEncoder expects int or float for bit_rate.",
344+
UserWarning,
345+
stacklevel=2
346+
)
347+
348+
# Save to file
349+
try:
350+
encoder.to_file(uri, bit_rate=bit_rate)
351+
except Exception as e:
352+
raise RuntimeError(f"Failed to save audio to {uri}: {e}") from e

0 commit comments

Comments
 (0)