@@ -20,15 +20,16 @@ def load_with_torchcodec(
20
20
21
21
.. note::
22
22
23
- This function supports the same API as ``torchaudio.load()``, and relies
24
- on TorchCodec's decoding capabilities under the hood. It is provided for
25
- convenience, but we do recommend that you port your code to natively use
26
- ``torchcodec``'s ``AudioDecoder`` class for better performance:
23
+ This function supports the same API as :func:`~torchaudio.load`, and
24
+ relies on TorchCodec's decoding capabilities under the hood. It is
25
+ provided for convenience, but we do recommend that you port your code to
26
+ natively use ``torchcodec``'s ``AudioDecoder`` class for better
27
+ performance:
27
28
https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.
28
- In TorchAudio 2.9, `` torchaudio.load()` ` will be relying on
29
- `` load_with_torchcodec` `. Note that some parameters of
30
- `` torchaudio.load()` `, like ``normalize``, ``buffer_size``, and
31
- ``backend``, are ignored by `` load_with_torchcodec` `.
29
+ In TorchAudio 2.9, :func:`~ torchaudio.load` will be relying on
30
+ :func:`~torchaudio. load_with_torchcodec`. Note that some parameters of
31
+ :func:`~ torchaudio.load`, like ``normalize``, ``buffer_size``, and
32
+ ``backend``, are ignored by :func:`~torchaudio. load_with_torchcodec`.
32
33
33
34
34
35
Args:
@@ -158,4 +159,194 @@ def load_with_torchcodec(
158
159
if not channels_first :
159
160
data = data .transpose (0 , 1 ) # [channel, time] -> [time, channel]
160
161
161
- return data , sample_rate
162
+ return data , sample_rate
163
+
164
+
165
+ def save_with_torchcodec (
166
+ uri : Union [str , os .PathLike ],
167
+ src : torch .Tensor ,
168
+ sample_rate : int ,
169
+ channels_first : bool = True ,
170
+ format : Optional [str ] = None ,
171
+ encoding : Optional [str ] = None ,
172
+ bits_per_sample : Optional [int ] = None ,
173
+ buffer_size : int = 4096 ,
174
+ backend : Optional [str ] = None ,
175
+ compression : Optional [Union [float , int ]] = None ,
176
+ ) -> None :
177
+ """Save audio data to file using TorchCodec's AudioEncoder.
178
+
179
+ .. note::
180
+
181
+ This function supports the same API as :func:`~torchaudio.save`, and
182
+ relies on TorchCodec's encoding capabilities under the hood. It is
183
+ provided for convenience, but we do recommend that you port your code to
184
+ natively use ``torchcodec``'s ``AudioEncoder`` class for better
185
+ performance:
186
+ https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder.
187
+ In TorchAudio 2.9, :func:`~torchaudio.save` will be relying on
188
+ :func:`~torchaudio.save_with_torchcodec`. Note that some parameters of
189
+ :func:`~torchaudio.save`, like ``format``, ``encoding``,
190
+ ``bits_per_sample``, ``buffer_size``, and ``backend``, are ignored by
191
+ are ignored by :func:`~torchaudio.save_with_torchcodec`.
192
+
193
+ This function provides a TorchCodec-based alternative to torchaudio.save
194
+ with the same API. TorchCodec's AudioEncoder provides efficient encoding
195
+ with FFmpeg under the hood.
196
+
197
+ Args:
198
+ uri (path-like object):
199
+ Path to save the audio file. The file extension determines the format.
200
+
201
+ src (torch.Tensor):
202
+ Audio data to save. Must be a 1D or 2D tensor with float32 values
203
+ in the range [-1, 1]. If 2D, shape should be [channel, time] when
204
+ channels_first=True, or [time, channel] when channels_first=False.
205
+
206
+ sample_rate (int):
207
+ Sample rate of the audio data.
208
+
209
+ channels_first (bool, optional):
210
+ Indicates whether the input tensor has channels as the first dimension.
211
+ If True, expects [channel, time]. If False, expects [time, channel].
212
+ Default: True.
213
+
214
+ format (str or None, optional):
215
+ Audio format hint. Not used by TorchCodec (format is determined by
216
+ file extension). A warning is issued if provided.
217
+ Default: None.
218
+
219
+ encoding (str or None, optional):
220
+ Audio encoding. Not fully supported by TorchCodec AudioEncoder.
221
+ A warning is issued if provided. Default: None.
222
+
223
+ bits_per_sample (int or None, optional):
224
+ Bits per sample. Not directly supported by TorchCodec AudioEncoder.
225
+ A warning is issued if provided. Default: None.
226
+
227
+ buffer_size (int, optional):
228
+ Not used by TorchCodec AudioEncoder. Provided for API compatibility.
229
+ A warning is issued if not default value. Default: 4096.
230
+
231
+ backend (str or None, optional):
232
+ Not used by TorchCodec AudioEncoder. Provided for API compatibility.
233
+ A warning is issued if provided. Default: None.
234
+
235
+ compression (float, int or None, optional):
236
+ Compression level or bit rate. Maps to bit_rate parameter in
237
+ TorchCodec AudioEncoder. Default: None.
238
+
239
+ Raises:
240
+ ImportError: If torchcodec is not available.
241
+ ValueError: If input parameters are invalid.
242
+ RuntimeError: If TorchCodec fails to encode the audio.
243
+
244
+ Note:
245
+ - TorchCodec AudioEncoder expects float32 samples in [-1, 1] range.
246
+ - Some parameters (format, encoding, bits_per_sample, buffer_size, backend)
247
+ are not used by TorchCodec but are provided for API compatibility.
248
+ - The output format is determined by the file extension in the uri.
249
+ - TorchCodec uses FFmpeg under the hood for encoding.
250
+ """
251
+ # Import torchcodec here to provide clear error if not available
252
+ try :
253
+ from torchcodec .encoders import AudioEncoder
254
+ except ImportError as e :
255
+ raise ImportError (
256
+ "TorchCodec is required for save_with_torchcodec. "
257
+ "Please install torchcodec to use this function."
258
+ ) from e
259
+
260
+ # Parameter validation and warnings
261
+ if format is not None :
262
+ import warnings
263
+ warnings .warn (
264
+ "The 'format' parameter is not used by TorchCodec AudioEncoder. "
265
+ "Format is determined by the file extension." ,
266
+ UserWarning ,
267
+ stacklevel = 2
268
+ )
269
+
270
+ if encoding is not None :
271
+ import warnings
272
+ warnings .warn (
273
+ "The 'encoding' parameter is not fully supported by TorchCodec AudioEncoder." ,
274
+ UserWarning ,
275
+ stacklevel = 2
276
+ )
277
+
278
+ if bits_per_sample is not None :
279
+ import warnings
280
+ warnings .warn (
281
+ "The 'bits_per_sample' parameter is not directly supported by TorchCodec AudioEncoder." ,
282
+ UserWarning ,
283
+ stacklevel = 2
284
+ )
285
+
286
+ if buffer_size != 4096 :
287
+ import warnings
288
+ warnings .warn (
289
+ "The 'buffer_size' parameter is not used by TorchCodec AudioEncoder." ,
290
+ UserWarning ,
291
+ stacklevel = 2
292
+ )
293
+
294
+ if backend is not None :
295
+ import warnings
296
+ warnings .warn (
297
+ "The 'backend' parameter is not used by TorchCodec AudioEncoder." ,
298
+ UserWarning ,
299
+ stacklevel = 2
300
+ )
301
+
302
+ # Input validation
303
+ if not isinstance (src , torch .Tensor ):
304
+ raise ValueError (f"Expected src to be a torch.Tensor, got { type (src )} " )
305
+
306
+ if src .dtype != torch .float32 :
307
+ src = src .float ()
308
+
309
+ if sample_rate <= 0 :
310
+ raise ValueError (f"sample_rate must be positive, got { sample_rate } " )
311
+
312
+ # Handle tensor shape and channels_first
313
+ if src .ndim == 1 :
314
+ # Convert to 2D: [1, time] for channels_first=True
315
+ if channels_first :
316
+ data = src .unsqueeze (0 ) # [1, time]
317
+ else :
318
+ # For channels_first=False, input is [time] -> reshape to [time, 1] -> transpose to [1, time]
319
+ data = src .unsqueeze (1 ).transpose (0 , 1 ) # [time, 1] -> [1, time]
320
+ elif src .ndim == 2 :
321
+ if channels_first :
322
+ data = src # Already [channel, time]
323
+ else :
324
+ data = src .transpose (0 , 1 ) # [time, channel] -> [channel, time]
325
+ else :
326
+ raise ValueError (f"Expected 1D or 2D tensor, got { src .ndim } D tensor" )
327
+
328
+ # Create AudioEncoder
329
+ try :
330
+ encoder = AudioEncoder (data , sample_rate = sample_rate )
331
+ except Exception as e :
332
+ raise RuntimeError (f"Failed to create AudioEncoder: { e } " ) from e
333
+
334
+ # Determine bit_rate from compression parameter
335
+ bit_rate = None
336
+ if compression is not None :
337
+ if isinstance (compression , (int , float )):
338
+ bit_rate = int (compression )
339
+ else :
340
+ import warnings
341
+ warnings .warn (
342
+ f"Unsupported compression type { type (compression )} . "
343
+ "TorchCodec AudioEncoder expects int or float for bit_rate." ,
344
+ UserWarning ,
345
+ stacklevel = 2
346
+ )
347
+
348
+ # Save to file
349
+ try :
350
+ encoder .to_file (uri , bit_rate = bit_rate )
351
+ except Exception as e :
352
+ raise RuntimeError (f"Failed to save audio to { uri } : { e } " ) from e
0 commit comments