@@ -209,6 +209,93 @@ def load(
209209 return waveform , sample_rate
210210
211211
212+ def _get_subtype_for_wav (
213+ dtype : torch .dtype ,
214+ encoding : str ,
215+ bits_per_sample : int ):
216+ if not encoding :
217+ if not bits_per_sample :
218+ subtype = {
219+ torch .uint8 : "PCM_U8" ,
220+ torch .int16 : "PCM_16" ,
221+ torch .int32 : "PCM_32" ,
222+ torch .float32 : "FLOAT" ,
223+ torch .float64 : "DOUBLE" ,
224+ }.get (dtype )
225+ if not subtype :
226+ raise ValueError (f"Unsupported dtype for wav: { dtype } " )
227+ return subtype
228+ if bits_per_sample == 8 :
229+ return "PCM_U8"
230+ return f"PCM_{ bits_per_sample } "
231+ if encoding == "PCM_S" :
232+ if not bits_per_sample :
233+ return "PCM_32"
234+ if bits_per_sample == 8 :
235+ raise ValueError ("wav does not support 8-bit signed PCM encoding." )
236+ return f"PCM_{ bits_per_sample } "
237+ if encoding == "PCM_U" :
238+ if bits_per_sample in (None , 8 ):
239+ return "PCM_U8"
240+ raise ValueError ("wav only supports 8-bit unsigned PCM encoding." )
241+ if encoding == "PCM_F" :
242+ if bits_per_sample in (None , 32 ):
243+ return "FLOAT"
244+ if bits_per_sample == 64 :
245+ return "DOUBLE"
246+ raise ValueError ("wav only supports 32/64-bit float PCM encoding." )
247+ if encoding == "ULAW" :
248+ if bits_per_sample in (None , 8 ):
249+ return "ULAW"
250+ raise ValueError ("wav only supports 8-bit mu-law encoding." )
251+ if encoding == "ALAW" :
252+ if bits_per_sample in (None , 8 ):
253+ return "ALAW"
254+ raise ValueError ("wav only supports 8-bit a-law encoding." )
255+ raise ValueError (f"wav does not support { encoding } ." )
256+
257+
258+ def _get_subtype_for_sphere (encoding : str , bits_per_sample : int ):
259+ if encoding in (None , "PCM_S" ):
260+ return f"PCM_{ bits_per_sample } " if bits_per_sample else "PCM_32"
261+ if encoding in ("PCM_U" , "PCM_F" ):
262+ raise ValueError (f"sph does not support { encoding } encoding." )
263+ if encoding == "ULAW" :
264+ if bits_per_sample in (None , 8 ):
265+ return "ULAW"
266+ raise ValueError ("sph only supports 8-bit for mu-law encoding." )
267+ if encoding == "ALAW" :
268+ return "ALAW"
269+ raise ValueError (f"sph does not support { encoding } ." )
270+
271+
272+ def _get_subtype (
273+ dtype : torch .dtype ,
274+ format : str ,
275+ encoding : str ,
276+ bits_per_sample : int ):
277+ if format == "wav" :
278+ return _get_subtype_for_wav (dtype , encoding , bits_per_sample )
279+ if format == "flac" :
280+ if encoding :
281+ raise ValueError ("flac does not support encoding." )
282+ if not bits_per_sample :
283+ return "PCM_24"
284+ if bits_per_sample > 24 :
285+ raise ValueError ("flac does not support bits_per_sample > 24." )
286+ return "PCM_S8" if bits_per_sample == 8 else f"PCM_{ bits_per_sample } "
287+ if format in ("ogg" , "vorbis" ):
288+ if encoding or bits_per_sample :
289+ raise ValueError (
290+ "ogg/vorbis does not support encoding/bits_per_sample." )
291+ return "VORBIS"
292+ if format == "sph" :
293+ return _get_subtype_for_sphere (encoding , bits_per_sample )
294+ if format in ("nis" , "nist" ):
295+ return "PCM_16"
296+ raise ValueError (f"Unsupported format: { format } " )
297+
298+
212299@_mod_utils .requires_module ("soundfile" )
213300def save (
214301 filepath : str ,
@@ -217,6 +304,8 @@ def save(
217304 channels_first : bool = True ,
218305 compression : Optional [float ] = None ,
219306 format : Optional [str ] = None ,
307+ encoding : Optional [str ] = None ,
308+ bits_per_sample : Optional [int ] = None ,
220309):
221310 """Save audio data to file.
222311
@@ -246,9 +335,65 @@ def save(
246335 otherwise ``[time, channel]``.
247336 compression (Optional[float]): Not used.
248337 It is here only for interface compatibility reson with "sox_io" backend.
249- format (str, optional): Output audio format.
250- This is required when the output audio format cannot be infered from
251- ``filepath``, (such as file extension or ``name`` attribute of the given file object).
338+ format (str, optional): Override the audio format.
339+ When ``filepath`` argument is path-like object, audio format is
340+ inferred from file extension. If the file extension is missing or
341+ different, you can specify the correct format with this argument.
342+
343+ When ``filepath`` argument is file-like object,
344+ this argument is required.
345+
346+ Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
347+ ``"flac"`` and ``"sph"``.
348+ encoding (str, optional): Changes the encoding for supported formats.
349+ This argument is effective only for supported formats, sush as
350+ ``"wav"``, ``""flac"`` and ``"sph"``. Valid values are;
351+
352+ - ``"PCM_S"`` (signed integer Linear PCM)
353+ - ``"PCM_U"`` (unsigned integer Linear PCM)
354+ - ``"PCM_F"`` (floating point PCM)
355+ - ``"ULAW"`` (mu-law)
356+ - ``"ALAW"`` (a-law)
357+
358+ bits_per_sample (int, optional): Changes the bit depth for the
359+ supported formats.
360+ When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
361+ you can change the bit depth.
362+ Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
363+
364+ Supported formats/encodings/bit depth/compression are:
365+
366+ ``"wav"``
367+ - 32-bit floating-point PCM
368+ - 32-bit signed integer PCM
369+ - 24-bit signed integer PCM
370+ - 16-bit signed integer PCM
371+ - 8-bit unsigned integer PCM
372+ - 8-bit mu-law
373+ - 8-bit a-law
374+
375+ Note: Default encoding/bit depth is determined by the dtype of
376+ the input Tensor.
377+
378+ ``"flac"``
379+ - 8-bit
380+ - 16-bit
381+ - 24-bit (default)
382+
383+ ``"ogg"``, ``"vorbis"``
384+ - Doesn't accept changing configuration.
385+
386+ ``"sph"``
387+ - 8-bit signed integer PCM
388+ - 16-bit signed integer PCM
389+ - 24-bit signed integer PCM
390+ - 32-bit signed integer PCM (default)
391+ - 8-bit mu-law
392+ - 8-bit a-law
393+ - 16-bit a-law
394+ - 24-bit a-law
395+ - 32-bit a-law
396+
252397 """
253398 if src .ndim != 2 :
254399 raise ValueError (f"Expected 2D Tensor, got { src .ndim } D." )
@@ -260,24 +405,13 @@ def save(
260405 if hasattr (filepath , 'write' ):
261406 if format is None :
262407 raise RuntimeError ('`format` is required when saving to file object.' )
263- ext = format
408+ ext = format . lower ()
264409 else :
265410 ext = str (filepath ).split ("." )[- 1 ].lower ()
266411
267- if ext != "wav" :
268- subtype = None
269- elif src .dtype == torch .uint8 :
270- subtype = "PCM_U8"
271- elif src .dtype == torch .int16 :
272- subtype = "PCM_16"
273- elif src .dtype == torch .int32 :
274- subtype = "PCM_32"
275- elif src .dtype == torch .float32 :
276- subtype = "FLOAT"
277- elif src .dtype == torch .float64 :
278- subtype = "DOUBLE"
279- else :
280- raise ValueError (f"Unsupported dtype for WAV: { src .dtype } " )
412+ if bits_per_sample not in (None , 8 , 16 , 24 , 32 , 64 ):
413+ raise ValueError ("Invalid bits_per_sample." )
414+ subtype = _get_subtype (src .dtype , ext , encoding , bits_per_sample )
281415
282416 # sph is a extension used in TED-LIUM but soundfile does not recognize it as NIST format,
283417 # so we extend the extensions manually here
0 commit comments