Skip to content

Commit e9f19c3

Browse files
authored
Refactor backend and not rely on global variables on switching (#698)
* Refactor backend switching 1. Do not rely on global variables for backend switch So that load/save/info/load_wav functions will be torchscript-able 2. Add no_backend module to for the case there is no backend module available [bonus] This allows the whole codebase importable on systems that do not have torchaudio C++ extension nor soundfile.
1 parent 87a761d commit e9f19c3

File tree

9 files changed

+266
-161
lines changed

9 files changed

+266
-161
lines changed

test/test_backend.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import unittest
2+
3+
import torchaudio
4+
from torchaudio._internal.module_utils import is_module_available
5+
6+
7+
class BackendSwitch:
8+
"""Test set/get_audio_backend works"""
9+
backend = None
10+
backend_module = None
11+
12+
def test_switch(self):
13+
torchaudio.set_audio_backend(self.backend)
14+
if self.backend is None:
15+
assert torchaudio.get_audio_backend() is None
16+
else:
17+
assert torchaudio.get_audio_backend() == self.backend
18+
assert torchaudio.load == self.backend_module.load
19+
assert torchaudio.load_wav == self.backend_module.load_wav
20+
assert torchaudio.save == self.backend_module.save
21+
assert torchaudio.info == self.backend_module.info
22+
23+
24+
class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
25+
backend = None
26+
backend_module = torchaudio.backend.no_backend
27+
28+
29+
@unittest.skipIf(
30+
not is_module_available('torchaudio._torchaudio'),
31+
'torchaudio C++ extension not available')
32+
class TestBackendSwitch_SoX(BackendSwitch, unittest.TestCase):
33+
backend = 'sox'
34+
backend_module = torchaudio.backend.sox_backend
35+
36+
37+
@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
38+
class TestBackendSwitch_soundfile(BackendSwitch, unittest.TestCase):
39+
backend = 'soundfile'
40+
backend_module = torchaudio.backend.soundfile_backend

test/test_datasets.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ def test_librispeech(self):
2929
data[0]
3030

3131
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
32+
@common_utils.AudioBackendScope('sox')
3233
def test_commonvoice(self):
3334
data = COMMONVOICE(self.path, url="tatar")
3435
data[0]
3536

3637
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
38+
@common_utils.AudioBackendScope('sox')
3739
def test_commonvoice_diskcache(self):
3840
data = COMMONVOICE(self.path, url="tatar")
3941
data = diskcache_iterator(data)
@@ -43,6 +45,7 @@ def test_commonvoice_diskcache(self):
4345
data[0]
4446

4547
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
48+
@common_utils.AudioBackendScope('sox')
4649
def test_commonvoice_bg(self):
4750
data = COMMONVOICE(self.path, url="tatar")
4851
data = bg_iterator(data, 5)

torchaudio/__init__.py

Lines changed: 0 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
from pathlib import Path
2-
from typing import Any, Callable, Optional, Tuple, Union
3-
4-
from torch import Tensor
51
from torchaudio._internal import module_utils as _mod_utils
62
from torchaudio import (
73
compliance,
@@ -11,7 +7,6 @@
117
transforms
128
)
139
from torchaudio.backend import (
14-
_get_audio_backend_module,
1510
list_audio_backends,
1611
get_audio_backend,
1712
set_audio_backend,
@@ -57,116 +52,3 @@ def shutdown_sox():
5752
This function is deprecated. See ``torchaudio.sox_effects.shutdown_sox_effects``
5853
"""
5954
_shutdown_sox_effects()
60-
61-
62-
def load(filepath: Union[str, Path],
63-
out: Optional[Tensor] = None,
64-
normalization: Union[bool, float, Callable] = True,
65-
channels_first: bool = True,
66-
num_frames: int = 0,
67-
offset: int = 0,
68-
signalinfo: Optional[SignalInfo] = None,
69-
encodinginfo: Optional[EncodingInfo] = None,
70-
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
71-
r"""Loads an audio file from disk into a tensor
72-
73-
Args:
74-
filepath (str or Path): Path to audio file
75-
out (Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``)
76-
normalization (bool, float, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
77-
(assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
78-
If `float`, then output is divided by that number
79-
If `Callable`, then the output is passed as a parameter
80-
to the given function, then the output is divided by
81-
the result. (Default: ``True``)
82-
channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
83-
num_frames (int, optional): Number of frames to load. 0 to load everything after the offset.
84-
(Default: ``0``)
85-
offset (int, optional): Number of frames from the start of the file to begin data loading.
86-
(Default: ``0``)
87-
signalinfo (sox_signalinfo_t, optional): A sox_signalinfo_t type, which could be helpful if the
88-
audio type cannot be automatically determined. (Default: ``None``)
89-
encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the
90-
audio type cannot be automatically determined. (Default: ``None``)
91-
filetype (str, optional): A filetype or extension to be set if sox cannot determine it
92-
automatically. (Default: ``None``)
93-
94-
Returns:
95-
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
96-
of audio frames and C is the number of channels. An integer which is the sample rate of the
97-
audio (as listed in the metadata of the file)
98-
99-
Example
100-
>>> data, sample_rate = torchaudio.load('foo.mp3')
101-
>>> print(data.size())
102-
torch.Size([2, 278756])
103-
>>> print(sample_rate)
104-
44100
105-
>>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
106-
>>> print(data_vol_normalized.abs().max())
107-
1.
108-
109-
"""
110-
return _get_audio_backend_module().load(
111-
filepath,
112-
out=out,
113-
normalization=normalization,
114-
channels_first=channels_first,
115-
num_frames=num_frames,
116-
offset=offset,
117-
signalinfo=signalinfo,
118-
encodinginfo=encodinginfo,
119-
filetype=filetype,
120-
)
121-
122-
123-
def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
124-
r""" Loads a wave file. It assumes that the wav file uses 16 bit per sample that needs normalization by shifting
125-
the input right by 16 bits.
126-
127-
Args:
128-
filepath (str or Path): Path to audio file
129-
130-
Returns:
131-
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
132-
of audio frames and C is the number of channels. An integer which is the sample rate of the
133-
audio (as listed in the metadata of the file)
134-
"""
135-
kwargs['normalization'] = 1 << 16
136-
return load(filepath, **kwargs)
137-
138-
139-
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
140-
r"""Convenience function for `save_encinfo`.
141-
142-
Args:
143-
filepath (str): Path to audio file
144-
src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
145-
the number of audio frames, C is the number of channels
146-
sample_rate (int): An integer which is the sample rate of the
147-
audio (as listed in the metadata of the file)
148-
precision (int, optional): Bit precision (Default: ``16``)
149-
channels_first (bool, optional): Set channels first or length first in result. (
150-
Default: ``True``)
151-
"""
152-
153-
return _get_audio_backend_module().save(
154-
filepath, src, sample_rate, precision=precision, channels_first=channels_first
155-
)
156-
157-
158-
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
159-
r"""Gets metadata from an audio file without loading the signal.
160-
161-
Args:
162-
filepath (str): Path to audio file
163-
164-
Returns:
165-
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
166-
info as a python object. An ei (sox_encodinginfo_t) encoding info
167-
168-
Example
169-
>>> si, ei = torchaudio.info('foo.wav')
170-
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
171-
"""
172-
return _get_audio_backend_module().info(filepath)

torchaudio/backend/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from . import utils
22
from .utils import (
3-
_get_audio_backend_module,
43
list_audio_backends,
54
get_audio_backend,
65
set_audio_backend,

torchaudio/backend/common.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Optional, Tuple
1+
from typing import Any, Optional
22

33

44
class SignalInfo:
@@ -29,3 +29,115 @@ def __init__(self,
2929
self.reverse_nibbles = reverse_nibbles
3030
self.reverse_bits = reverse_bits
3131
self.opposite_endian = opposite_endian
32+
33+
34+
_LOAD_DOCSTRING = r"""Loads an audio file from disk into a tensor
35+
36+
Args:
37+
filepath: Path to audio file
38+
39+
out: An optional output tensor to use instead of creating one. (Default: ``None``)
40+
41+
normalization: Optional normalization.
42+
If boolean `True`, then output is divided by `1 << 31`.
43+
Assuming the input is signed 32-bit audio, this normalizes to `[-1, 1]`.
44+
If `float`, then output is divided by that number.
45+
If `Callable`, then the output is passed as a paramete to the given function,
46+
then the output is divided by the result. (Default: ``True``)
47+
48+
channels_first: Set channels first or length first in result. (Default: ``True``)
49+
50+
num_frames: Number of frames to load. 0 to load everything after the offset.
51+
(Default: ``0``)
52+
53+
offset: Number of frames from the start of the file to begin data loading.
54+
(Default: ``0``)
55+
56+
signalinfo: A sox_signalinfo_t type, which could be helpful if the
57+
audio type cannot be automatically determined. (Default: ``None``)
58+
59+
encodinginfo: A sox_encodinginfo_t type, which could be set if the
60+
audio type cannot be automatically determined. (Default: ``None``)
61+
62+
filetype: A filetype or extension to be set if sox cannot determine it
63+
automatically. (Default: ``None``)
64+
65+
Returns:
66+
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where
67+
L is the number of audio frames and
68+
C is the number of channels.
69+
An integer which is the sample rate of the audio (as listed in the metadata of the file)
70+
71+
Example
72+
>>> data, sample_rate = torchaudio.load('foo.mp3')
73+
>>> print(data.size())
74+
torch.Size([2, 278756])
75+
>>> print(sample_rate)
76+
44100
77+
>>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
78+
>>> print(data_vol_normalized.abs().max())
79+
1.
80+
"""
81+
82+
83+
_LOAD_WAV_DOCSTRING = r""" Loads a wave file.
84+
85+
It assumes that the wav file uses 16 bit per sample that needs normalization by
86+
shifting the input right by 16 bits.
87+
88+
Args:
89+
filepath: Path to audio file
90+
91+
Returns:
92+
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
93+
of audio frames and C is the number of channels. An integer which is the sample rate of the
94+
audio (as listed in the metadata of the file)
95+
"""
96+
97+
_SAVE_DOCSTRING = r"""Saves a Tensor on file as an audio file
98+
99+
Args:
100+
filepath: Path to audio file
101+
src: An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
102+
the number of audio frames, C is the number of channels
103+
sample_rate: An integer which is the sample rate of the
104+
audio (as listed in the metadata of the file)
105+
precision Bit precision (Default: ``16``)
106+
channels_first (bool, optional): Set channels first or length first in result. (
107+
Default: ``True``)
108+
"""
109+
110+
111+
_INFO_DOCSTRING = r"""Gets metadata from an audio file without loading the signal.
112+
113+
Args:
114+
filepath: Path to audio file
115+
116+
Returns:
117+
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
118+
info as a python object. An ei (sox_encodinginfo_t) encoding info
119+
120+
Example
121+
>>> si, ei = torchaudio.info('foo.wav')
122+
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
123+
"""
124+
125+
126+
def _impl_load(func):
127+
setattr(func, '__doc__', _LOAD_DOCSTRING)
128+
return func
129+
130+
131+
def _impl_load_wav(func):
132+
setattr(func, '__doc__', _LOAD_WAV_DOCSTRING)
133+
return func
134+
135+
136+
def _impl_save(func):
137+
setattr(func, '__doc__', _SAVE_DOCSTRING)
138+
return func
139+
140+
141+
def _impl_info(func):
142+
setattr(func, '__doc__', _INFO_DOCSTRING)
143+
return func

torchaudio/backend/no_backend.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pathlib import Path
2+
from typing import Any, Callable, Optional, Tuple, Union
3+
4+
from torch import Tensor
5+
6+
from . import common
7+
from .common import SignalInfo, EncodingInfo
8+
9+
10+
@common._impl_load
11+
def load(filepath: Union[str, Path],
12+
out: Optional[Tensor] = None,
13+
normalization: Union[bool, float, Callable] = True,
14+
channels_first: bool = True,
15+
num_frames: int = 0,
16+
offset: int = 0,
17+
signalinfo: Optional[SignalInfo] = None,
18+
encodinginfo: Optional[EncodingInfo] = None,
19+
filetype: Optional[str] = None) -> Tuple[Tensor, int]:
20+
raise RuntimeError('No audio I/O backend is available.')
21+
22+
23+
@common._impl_load_wav
24+
def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
25+
raise RuntimeError('No audio I/O backend is available.')
26+
27+
28+
@common._impl_save
29+
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
30+
raise RuntimeError('No audio I/O backend is available.')
31+
32+
33+
@common._impl_info
34+
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
35+
raise RuntimeError('No audio I/O backend is available.')

torchaudio/backend/soundfile_backend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
module_utils as _mod_utils,
99
misc_ops as _misc_ops,
1010
)
11+
from . import common
1112
from .common import SignalInfo, EncodingInfo
1213

1314
if _mod_utils.is_module_available('soundfile'):
@@ -24,6 +25,7 @@
2425

2526

2627
@_mod_utils.requires_module('soundfile')
28+
@common._impl_load
2729
def load(filepath: str,
2830
out: Optional[Tensor] = None,
2931
normalization: Optional[bool] = True,
@@ -71,6 +73,14 @@ def load(filepath: str,
7173

7274

7375
@_mod_utils.requires_module('soundfile')
76+
@common._impl_load_wav
77+
def load_wav(filepath, **kwargs):
78+
kwargs['normalization'] = 1 << 16
79+
return load(filepath, **kwargs)
80+
81+
82+
@_mod_utils.requires_module('soundfile')
83+
@common._impl_save
7484
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
7585
r"""See torchaudio.save"""
7686

@@ -104,6 +114,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
104114

105115

106116
@_mod_utils.requires_module('soundfile')
117+
@common._impl_info
107118
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
108119
r"""See torchaudio.info"""
109120

0 commit comments

Comments
 (0)