Skip to content

Commit 6854020

Browse files
authored
Apply codec-based data augmentation (#1200)
1 parent 4a3d203 commit 6854020

File tree

3 files changed

+107
-1
lines changed

3 files changed

+107
-1
lines changed

test/torchaudio_unittest/functional/functional_cpu_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
import itertools
99

1010
from torchaudio_unittest import common_utils
11+
from torchaudio_unittest.common_utils import (
12+
TorchaudioTestCase,
13+
skipIfNoExtension,
14+
)
15+
from torchaudio_unittest.backend.sox_io.common import name_func
16+
1117
from .functional_impl import Lfilter, Spectrogram
1218

1319

@@ -53,6 +59,7 @@ def test_warning(self):
5359

5460
class TestComputeDeltas(common_utils.TorchaudioTestCase):
5561
"""Test suite for correctness of compute_deltas"""
62+
5663
def test_one_channel(self):
5764
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
5865
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
@@ -211,3 +218,48 @@ def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
211218

212219
assert mask_specgrams.size() == specgrams.size()
213220
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
221+
222+
223+
@skipIfNoExtension
class TestApplyCodec(TorchaudioTestCase):
    """Smoke tests for ``F.apply_codec`` over several formats and compression settings."""

    backend = "sox_io"

    def _smoke_test(self, format, compression, check_num_frames):
        """
        Verify that ``apply_codec`` does not exhibit abnormal behaviors.

        A seeded random two-channel waveform (three seconds' worth of frames at
        8000 Hz) is passed through ``F.apply_codec``. The result must keep the
        input dtype and channel count; the frame count is compared as well only
        when ``check_num_frames`` is true.
        """
        # Fix the seed before drawing the input so every run sees the same waveform.
        torch.random.manual_seed(42)
        sample_rate = 8000
        num_channels = 2
        num_frames = 3 * sample_rate
        waveform = torch.rand(num_channels, num_frames)

        augmented = F.apply_codec(
            waveform,
            sample_rate,
            format,
            channels_first=True,
            compression=compression,
        )

        assert augmented.dtype == waveform.dtype
        assert augmented.size(0) == num_channels
        if check_num_frames:
            assert augmented.size(1) == num_frames

    def test_wave(self):
        # WAV is the only case here that also checks the number of frames.
        self._smoke_test("wav", compression=None, check_num_frames=True)

    @parameterized.expand(
        [(level,) for level in (96, 128, 160, 192, 224, 256, 320)],
        name_func=name_func,
    )
    def test_mp3(self, compression):
        self._smoke_test("mp3", compression, check_num_frames=False)

    @parameterized.expand(
        [(level,) for level in (0, 1, 2, 3, 4, 5, 6, 7, 8)],
        name_func=name_func,
    )
    def test_flac(self, compression):
        self._smoke_test("flac", compression, check_num_frames=False)

    @parameterized.expand(
        [(level,) for level in (-1, 0, 1, 2, 3, 3.6, 5, 10)],
        name_func=name_func,
    )
    def test_vorbis(self, compression):
        self._smoke_test("vorbis", compression, check_num_frames=False)

torchaudio/functional/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
sliding_window_cmn,
1919
spectrogram,
2020
spectral_centroid,
21+
apply_codec,
2122
)
2223
from .filtering import (
2324
allpass_biquad,
@@ -84,4 +85,5 @@
8485
'riaa_biquad',
8586
'treble_biquad',
8687
'vad',
88+
'apply_codec'
8789
]

torchaudio/functional/functional.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
# -*- coding: utf-8 -*-
22

3+
import io
34
import math
4-
from typing import Optional, Tuple
55
import warnings
6+
from typing import Optional, Tuple
67

78
import torch
89
from torch import Tensor
10+
from torchaudio._internal import (
11+
module_utils as _mod_utils,
12+
)
13+
import torchaudio
914

1015
__all__ = [
1116
"spectrogram",
@@ -29,6 +34,7 @@
2934
'mask_along_axis_iid',
3035
'sliding_window_cmn',
3136
"spectral_centroid",
37+
"apply_codec",
3238
]
3339

3440

@@ -994,6 +1000,52 @@ def spectral_centroid(
9941000
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
9951001

9961002

1003+
@_mod_utils.requires_module('torchaudio._torchaudio')
def apply_codec(
        waveform: Tensor,
        sample_rate: int,
        format: str,
        channels_first: bool = True,
        compression: Optional[float] = None,
        encoding: Optional[str] = None,
        bits_per_sample: Optional[int] = None,
) -> Tensor:
    r"""
    Apply codecs as a form of augmentation.

    The waveform is saved into an in-memory buffer using the requested
    ``format`` and then read back through the sox effects chain, so the
    returned Tensor reflects whatever the codec did to the signal.

    Args:
        waveform (Tensor): Audio data. Must be 2 dimensional. See also ``channels_first``.
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool):
            When True, both the input and output Tensor have dimension ``[channel, time]``.
            Otherwise, they have dimension ``[time, channel]``.
        compression (float, optional): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str, optional): Changes the encoding for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        bits_per_sample (int, optional): Changes the bit depth for the supported formats.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.

    Returns:
        torch.Tensor: Resulting Tensor.
        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
    """
    # Named ``buffer`` rather than ``bytes`` so the builtin type is not shadowed.
    buffer = io.BytesIO()
    # Optional settings are forwarded by keyword so that a change in the order
    # of ``save``'s parameters cannot silently mis-assign them.
    torchaudio.backend.sox_io_backend.save(
        buffer,
        waveform,
        sample_rate,
        channels_first=channels_first,
        compression=compression,
        format=format,
        encoding=encoding,
        bits_per_sample=bits_per_sample,
    )
    # Rewind so that decoding starts at the first encoded byte.
    buffer.seek(0)
    # The ``rate`` effect is given the caller's ``sample_rate``; the sample rate
    # returned alongside the Tensor is discarded.
    augmented, _ = torchaudio.sox_effects.sox_effects.apply_effects_file(
        buffer,
        effects=[["rate", f"{sample_rate}"]],
        channels_first=channels_first,
        format=format,
    )
    return augmented
1047+
1048+
9971049
def compute_kaldi_pitch(
9981050
waveform: torch.Tensor,
9991051
sample_rate: float,

0 commit comments

Comments
 (0)