Skip to content

Commit 5799599

Browse files
authored
Added encoding and bits_per_sample to soundfile's backend save() (#1303)
Cherry picked from commit b8fd5e9 (#1274)
1 parent b7e1cda commit 5799599

File tree

4 files changed

+230
-32
lines changed

4 files changed

+230
-32
lines changed

test/torchaudio_unittest/backend/soundfile/common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,26 @@ def skipIfFormatNotSupported(fmt):
3232

3333
def parameterize(*params):
3434
return parameterized.expand(list(itertools.product(*params)), name_func=name_func)
35+
36+
37+
def fetch_wav_subtype(dtype, encoding, bits_per_sample):
    """Return the subtype string the WAV save test expects for the given options.

    The expectation is looked up by ``encoding`` first and ``bits_per_sample``
    second. When both are ``None`` the answer is delegated to
    ``dtype2subtype(dtype)``.

    Args:
        dtype: Passed through to ``dtype2subtype`` for the default case.
        encoding: Encoding name (``'PCM_U'``, ``'PCM_S'``, ``'PCM_F'``,
            ``'ULAW'``, ``'ALAW'``) or ``None``.
        bits_per_sample: Requested bit depth or ``None``.

    Returns:
        The expected subtype string.

    Raises:
        ValueError: If the combination is not in the table (or maps to a
            falsy value).
    """
    # NOTE: dtype2subtype is evaluated on every call, even when the
    # (None, None) entry is not the one being looked up.
    expected_by_encoding = {
        None: {
            None: dtype2subtype(dtype),
            8: "PCM_U8",
        },
        'PCM_U': {
            None: "PCM_U8",
            8: "PCM_U8",
        },
        'PCM_S': {
            None: "PCM_32",
            16: "PCM_16",
            32: "PCM_32",
        },
        'PCM_F': {
            None: "FLOAT",
            32: "FLOAT",
            64: "DOUBLE",
        },
        'ULAW': {
            None: "ULAW",
            8: "ULAW",
        },
        'ALAW': {
            None: "ALAW",
            8: "ALAW",
        },
    }
    expected = expected_by_encoding.get(encoding, {}).get(bits_per_sample)
    if not expected:
        raise ValueError(
            f"wav does not support ({encoding}, {bits_per_sample}).")
    return expected

test/torchaudio_unittest/backend/soundfile/save_test.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
get_wav_data,
1212
load_wav,
1313
)
14-
from .common import parameterize, dtype2subtype, skipIfFormatNotSupported
14+
from .common import (
15+
fetch_wav_subtype,
16+
parameterize,
17+
skipIfFormatNotSupported,
18+
)
1519

1620
if _mod_utils.is_module_available("soundfile"):
1721
import soundfile
@@ -20,36 +24,56 @@
2024
class MockedSaveTest(PytorchTestCase):
2125
@parameterize(
2226
["float32", "int32", "int16", "uint8"], [8000, 16000], [1, 2], [False, True],
27+
[
28+
(None, None),
29+
('PCM_U', None),
30+
('PCM_U', 8),
31+
('PCM_S', None),
32+
('PCM_S', 16),
33+
('PCM_S', 32),
34+
('PCM_F', None),
35+
('PCM_F', 32),
36+
('PCM_F', 64),
37+
('ULAW', None),
38+
('ULAW', 8),
39+
('ALAW', None),
40+
('ALAW', 8),
41+
],
2342
)
2443
@patch("soundfile.write")
25-
def test_wav(self, dtype, sample_rate, num_channels, channels_first, mocked_write):
44+
def test_wav(self, dtype, sample_rate, num_channels, channels_first,
45+
enc_params, mocked_write):
2646
"""soundfile_backend.save passes correct subtype to soundfile.write when WAV"""
2747
filepath = "foo.wav"
2848
input_tensor = get_wav_data(
2949
dtype,
3050
num_channels,
3151
num_frames=3 * sample_rate,
32-
normalize=dtype == "flaot32",
52+
normalize=dtype == "float32",
3353
channels_first=channels_first,
3454
).t()
3555

56+
encoding, bits_per_sample = enc_params
3657
soundfile_backend.save(
37-
filepath, input_tensor, sample_rate, channels_first=channels_first
58+
filepath, input_tensor, sample_rate, channels_first=channels_first,
59+
encoding=encoding, bits_per_sample=bits_per_sample
3860
)
3961

4062
# on Python 3.8+, call_args.kwargs is more descriptive
4163
args = mocked_write.call_args[1]
4264
assert args["file"] == filepath
4365
assert args["samplerate"] == sample_rate
44-
assert args["subtype"] == dtype2subtype(dtype)
66+
assert args["subtype"] == fetch_wav_subtype(
67+
dtype, encoding, bits_per_sample)
4568
assert args["format"] is None
4669
self.assertEqual(
4770
args["data"], input_tensor.t() if channels_first else input_tensor
4871
)
4972

5073
@patch("soundfile.write")
5174
def assert_non_wav(
52-
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write
75+
self, fmt, dtype, sample_rate, num_channels, channels_first, mocked_write,
76+
encoding=None, bits_per_sample=None,
5377
):
5478
"""soundfile_backend.save passes correct subtype and format to soundfile.write when SPHERE"""
5579
filepath = f"foo.{fmt}"
@@ -63,14 +87,14 @@ def assert_non_wav(
6387
expected_data = input_tensor.t() if channels_first else input_tensor
6488

6589
soundfile_backend.save(
66-
filepath, input_tensor, sample_rate, channels_first=channels_first
90+
filepath, input_tensor, sample_rate, channels_first,
91+
encoding=encoding, bits_per_sample=bits_per_sample,
6792
)
6893

6994
# on Python 3.8+, call_args.kwargs is more descriptive
7095
args = mocked_write.call_args[1]
7196
assert args["file"] == filepath
7297
assert args["samplerate"] == sample_rate
73-
assert args["subtype"] is None
7498
if fmt in ["sph", "nist", "nis"]:
7599
assert args["format"] == "NIST"
76100
else:
@@ -83,19 +107,36 @@ def assert_non_wav(
83107
[8000, 16000],
84108
[1, 2],
85109
[False, True],
110+
[
111+
('PCM_S', 8),
112+
('PCM_S', 16),
113+
('PCM_S', 24),
114+
('PCM_S', 32),
115+
('ULAW', 8),
116+
('ALAW', 8),
117+
('ALAW', 16),
118+
('ALAW', 24),
119+
('ALAW', 32),
120+
],
86121
)
87-
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first):
122+
def test_sph(self, fmt, dtype, sample_rate, num_channels, channels_first, enc_params):
88123
"""soundfile_backend.save passes default format and subtype (None-s) to
89124
soundfile.write when not WAV"""
90-
self.assert_non_wav(fmt, dtype, sample_rate, num_channels, channels_first)
125+
encoding, bits_per_sample = enc_params
126+
self.assert_non_wav(fmt, dtype, sample_rate, num_channels,
127+
channels_first, encoding=encoding,
128+
bits_per_sample=bits_per_sample)
91129

92130
@parameterize(
93131
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
132+
[8, 16, 24],
94133
)
95-
def test_flac(self, dtype, sample_rate, num_channels, channels_first):
134+
def test_flac(self, dtype, sample_rate, num_channels,
135+
channels_first, bits_per_sample):
96136
"""soundfile_backend.save passes default format and subtype (None-s) to
97137
soundfile.write when not WAV"""
98-
self.assert_non_wav("flac", dtype, sample_rate, num_channels, channels_first)
138+
self.assert_non_wav("flac", dtype, sample_rate, num_channels,
139+
channels_first, bits_per_sample=bits_per_sample)
99140

100141
@parameterize(
101142
["int32", "int16"], [8000, 16000], [1, 2], [False, True],
@@ -228,7 +269,7 @@ def _test_fileobj(self, ext):
228269
found, sr = soundfile.read(fileobj, dtype='float32')
229270

230271
assert sr == sample_rate
231-
self.assertEqual(expected, found)
272+
self.assertEqual(expected, found, atol=1e-4, rtol=1e-8)
232273

233274
def test_fileobj_wav(self):
234275
"""Saving audio via file-like object works"""

torchaudio/backend/_soundfile_backend.py

Lines changed: 152 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,93 @@ def load(
209209
return waveform, sample_rate
210210

211211

212+
def _get_subtype_for_wav(
213+
dtype: torch.dtype,
214+
encoding: str,
215+
bits_per_sample: int):
216+
if not encoding:
217+
if not bits_per_sample:
218+
subtype = {
219+
torch.uint8: "PCM_U8",
220+
torch.int16: "PCM_16",
221+
torch.int32: "PCM_32",
222+
torch.float32: "FLOAT",
223+
torch.float64: "DOUBLE",
224+
}.get(dtype)
225+
if not subtype:
226+
raise ValueError(f"Unsupported dtype for wav: {dtype}")
227+
return subtype
228+
if bits_per_sample == 8:
229+
return "PCM_U8"
230+
return f"PCM_{bits_per_sample}"
231+
if encoding == "PCM_S":
232+
if not bits_per_sample:
233+
return "PCM_32"
234+
if bits_per_sample == 8:
235+
raise ValueError("wav does not support 8-bit signed PCM encoding.")
236+
return f"PCM_{bits_per_sample}"
237+
if encoding == "PCM_U":
238+
if bits_per_sample in (None, 8):
239+
return "PCM_U8"
240+
raise ValueError("wav only supports 8-bit unsigned PCM encoding.")
241+
if encoding == "PCM_F":
242+
if bits_per_sample in (None, 32):
243+
return "FLOAT"
244+
if bits_per_sample == 64:
245+
return "DOUBLE"
246+
raise ValueError("wav only supports 32/64-bit float PCM encoding.")
247+
if encoding == "ULAW":
248+
if bits_per_sample in (None, 8):
249+
return "ULAW"
250+
raise ValueError("wav only supports 8-bit mu-law encoding.")
251+
if encoding == "ALAW":
252+
if bits_per_sample in (None, 8):
253+
return "ALAW"
254+
raise ValueError("wav only supports 8-bit a-law encoding.")
255+
raise ValueError(f"wav does not support {encoding}.")
256+
257+
258+
def _get_subtype_for_sphere(encoding: str, bits_per_sample: int):
259+
if encoding in (None, "PCM_S"):
260+
return f"PCM_{bits_per_sample}" if bits_per_sample else "PCM_32"
261+
if encoding in ("PCM_U", "PCM_F"):
262+
raise ValueError(f"sph does not support {encoding} encoding.")
263+
if encoding == "ULAW":
264+
if bits_per_sample in (None, 8):
265+
return "ULAW"
266+
raise ValueError("sph only supports 8-bit for mu-law encoding.")
267+
if encoding == "ALAW":
268+
return "ALAW"
269+
raise ValueError(f"sph does not support {encoding}.")
270+
271+
272+
def _get_subtype(
        dtype: torch.dtype,
        format: str,
        encoding: str,
        bits_per_sample: int):
    """Choose the subtype string for the given output format and options.

    Args:
        dtype: dtype of the tensor being saved; only forwarded for ``"wav"``.
        format: Output format name, compared exactly (no case folding here).
        encoding: Requested encoding, or a falsy value for the default.
        bits_per_sample: Requested bit depth, or a falsy value for the default.

    Returns:
        The subtype name selected for ``format``.

    Raises:
        ValueError: If ``format`` is not handled, or the options are not
            accepted for that format.
    """
    # wav and sph have their own resolvers.
    if format == "wav":
        return _get_subtype_for_wav(dtype, encoding, bits_per_sample)
    if format == "sph":
        return _get_subtype_for_sphere(encoding, bits_per_sample)

    if format == "flac":
        if encoding:
            raise ValueError("flac does not support encoding.")
        if bits_per_sample:
            if bits_per_sample > 24:
                raise ValueError("flac does not support bits_per_sample > 24.")
            if bits_per_sample == 8:
                return "PCM_S8"
            return f"PCM_{bits_per_sample}"
        # No bit depth requested: default to 24-bit.
        return "PCM_24"

    if format in ("ogg", "vorbis"):
        if not (encoding or bits_per_sample):
            return "VORBIS"
        raise ValueError(
            "ogg/vorbis does not support encoding/bits_per_sample.")

    # nis/nist ignore encoding and bits_per_sample here.
    if format in ("nis", "nist"):
        return "PCM_16"

    raise ValueError(f"Unsupported format: {format}")
297+
298+
212299
@_mod_utils.requires_module("soundfile")
213300
def save(
214301
filepath: str,
@@ -217,6 +304,8 @@ def save(
217304
channels_first: bool = True,
218305
compression: Optional[float] = None,
219306
format: Optional[str] = None,
307+
encoding: Optional[str] = None,
308+
bits_per_sample: Optional[int] = None,
220309
):
221310
"""Save audio data to file.
222311
@@ -246,9 +335,65 @@ def save(
246335
otherwise ``[time, channel]``.
247336
compression (Optional[float]): Not used.
248337
It is here only for interface compatibility reason with "sox_io" backend.
249-
format (str, optional): Output audio format.
250-
This is required when the output audio format cannot be infered from
251-
``filepath``, (such as file extension or ``name`` attribute of the given file object).
338+
format (str, optional): Override the audio format.
339+
When ``filepath`` argument is path-like object, audio format is
340+
inferred from file extension. If the file extension is missing or
341+
different, you can specify the correct format with this argument.
342+
343+
When ``filepath`` argument is file-like object,
344+
this argument is required.
345+
346+
Valid values are ``"wav"``, ``"ogg"``, ``"vorbis"``,
347+
``"flac"`` and ``"sph"``.
348+
encoding (str, optional): Changes the encoding for supported formats.
349+
This argument is effective only for supported formats, such as
350+
``"wav"`` and ``"sph"``. Valid values are:
351+
352+
- ``"PCM_S"`` (signed integer Linear PCM)
353+
- ``"PCM_U"`` (unsigned integer Linear PCM)
354+
- ``"PCM_F"`` (floating point PCM)
355+
- ``"ULAW"`` (mu-law)
356+
- ``"ALAW"`` (a-law)
357+
358+
bits_per_sample (int, optional): Changes the bit depth for the
359+
supported formats.
360+
When ``format`` is one of ``"wav"``, ``"flac"`` or ``"sph"``,
361+
you can change the bit depth.
362+
Valid values are ``8``, ``16``, ``24``, ``32`` and ``64``.
363+
364+
Supported formats/encodings/bit depth/compression are:
365+
366+
``"wav"``
367+
- 32-bit floating-point PCM
368+
- 32-bit signed integer PCM
369+
- 24-bit signed integer PCM
370+
- 16-bit signed integer PCM
371+
- 8-bit unsigned integer PCM
372+
- 8-bit mu-law
373+
- 8-bit a-law
374+
375+
Note: Default encoding/bit depth is determined by the dtype of
376+
the input Tensor.
377+
378+
``"flac"``
379+
- 8-bit
380+
- 16-bit
381+
- 24-bit (default)
382+
383+
``"ogg"``, ``"vorbis"``
384+
- Doesn't accept changing configuration.
385+
386+
``"sph"``
387+
- 8-bit signed integer PCM
388+
- 16-bit signed integer PCM
389+
- 24-bit signed integer PCM
390+
- 32-bit signed integer PCM (default)
391+
- 8-bit mu-law
392+
- 8-bit a-law
393+
- 16-bit a-law
394+
- 24-bit a-law
395+
- 32-bit a-law
396+
252397
"""
253398
if src.ndim != 2:
254399
raise ValueError(f"Expected 2D Tensor, got {src.ndim}D.")
@@ -260,24 +405,13 @@ def save(
260405
if hasattr(filepath, 'write'):
261406
if format is None:
262407
raise RuntimeError('`format` is required when saving to file object.')
263-
ext = format
408+
ext = format.lower()
264409
else:
265410
ext = str(filepath).split(".")[-1].lower()
266411

267-
if ext != "wav":
268-
subtype = None
269-
elif src.dtype == torch.uint8:
270-
subtype = "PCM_U8"
271-
elif src.dtype == torch.int16:
272-
subtype = "PCM_16"
273-
elif src.dtype == torch.int32:
274-
subtype = "PCM_32"
275-
elif src.dtype == torch.float32:
276-
subtype = "FLOAT"
277-
elif src.dtype == torch.float64:
278-
subtype = "DOUBLE"
279-
else:
280-
raise ValueError(f"Unsupported dtype for WAV: {src.dtype}")
412+
if bits_per_sample not in (None, 8, 16, 24, 32, 64):
413+
raise ValueError("Invalid bits_per_sample.")
414+
subtype = _get_subtype(src.dtype, ext, encoding, bits_per_sample)
281415

282416
# sph is an extension used in TED-LIUM but soundfile does not recognize it as NIST format,
283417
# so we extend the extensions manually here

torchaudio/backend/sox_io_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def save(
197197
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
198198
``"amb"``, ``"flac"`` and ``"sph"``.
199199
encoding (str, optional): Changes the encoding for the supported formats.
200-
This argument is effective only for supported formats, cush as ``"wav"``, ``""amb"``
200+
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
201201
and ``"sph"``. Valid values are;
202202
203203
- ``"PCM_S"`` (signed integer Linear PCM)

0 commit comments

Comments
 (0)