Skip to content

Commit 4318fc5

Browse files
jimchen90Ji Chen
andauthored
Add MelResNet Block (#705)
* Add MelResNet Block * add default value * update model and test * rebase and small changes * add pad variable * update format * update reference in docstrings * add underscore name Co-authored-by: Ji Chen <[email protected]>
1 parent ab733e7 commit 4318fc5

File tree

3 files changed

+129
-1
lines changed

3 files changed

+129
-1
lines changed

test/test_models.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from torchaudio.models import Wav2Letter
2+
from torchaudio.models import Wav2Letter, _MelResNet
33

44

55
class TestWav2Letter:
@@ -29,3 +29,23 @@ def test_mfcc(self):
2929
out = model(x)
3030

3131
assert out.size() == (batch_size, num_classes, 2)
32+
33+
34+
class TestMelResNet:
35+
36+
def test_waveform(self):
37+
38+
batch_size = 2
39+
num_features = 200
40+
input_dims = 100
41+
output_dims = 128
42+
res_blocks = 10
43+
hidden_dims = 128
44+
pad = 2
45+
46+
model = _MelResNet(res_blocks, input_dims, hidden_dims, output_dims, pad)
47+
48+
x = torch.rand(batch_size, input_dims, num_features)
49+
out = model(x)
50+
51+
assert out.size() == (batch_size, output_dims, num_features - pad * 2)

torchaudio/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .wav2letter import *
2+
from ._wavernn import *

torchaudio/models/_wavernn.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
from typing import Optional
2+
3+
from torch import Tensor
4+
from torch import nn
5+
6+
__all__ = ["_ResBlock", "_MelResNet"]
7+
8+
9+
class _ResBlock(nn.Module):
10+
r"""This is a ResNet block layer. This layer is based on the paper "Deep Residual Learning
11+
for Image Recognition". Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. CVPR, 2016.
12+
It is a block used in WaveRNN. WaveRNN is based on the paper "Efficient Neural Audio Synthesis".
13+
Nal Kalchbrenner, Erich Elsen, Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart,
14+
Florian Stimberg, Aaron van den Oord, Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
15+
16+
Args:
17+
num_dims: the number of compute dimensions in the input (default=128).
18+
19+
Examples::
20+
>>> resblock = _ResBlock(num_dims=128)
21+
>>> input = torch.rand(10, 128, 512)
22+
>>> output = resblock(input)
23+
"""
24+
25+
def __init__(self, num_dims: int = 128) -> None:
26+
super().__init__()
27+
28+
self.resblock_model = nn.Sequential(
29+
nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
30+
nn.BatchNorm1d(num_dims),
31+
nn.ReLU(inplace=True),
32+
nn.Conv1d(in_channels=num_dims, out_channels=num_dims, kernel_size=1, bias=False),
33+
nn.BatchNorm1d(num_dims)
34+
)
35+
36+
def forward(self, x: Tensor) -> Tensor:
37+
r"""Pass the input through the _ResBlock layer.
38+
39+
Args:
40+
x: the input sequence to the _ResBlock layer (required).
41+
42+
Shape:
43+
- x: :math:`(N, S, T)`.
44+
- output: :math:`(N, S, T)`.
45+
where N is the batch size, S is the number of input sequence,
46+
T is the length of input sequence.
47+
"""
48+
49+
residual = x
50+
return self.resblock_model(x) + residual
51+
52+
53+
class _MelResNet(nn.Module):
54+
r"""This is a MelResNet layer based on a stack of ResBlocks. It is a block used in WaveRNN.
55+
WaveRNN is based on the paper "Efficient Neural Audio Synthesis". Nal Kalchbrenner, Erich Elsen,
56+
Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart, Florian Stimberg, Aaron van den Oord,
57+
Sander Dieleman, Koray Kavukcuoglu. arXiv:1802.08435, 2018.
58+
59+
Args:
60+
res_blocks: the number of ResBlock in stack (default=10).
61+
input_dims: the number of input sequence (default=100).
62+
hidden_dims: the number of compute dimensions (default=128).
63+
output_dims: the number of output sequence (default=128).
64+
pad: the number of kernal size (pad * 2 + 1) in the first Conv1d layer (default=2).
65+
66+
Examples::
67+
>>> melresnet = _MelResNet(res_blocks=10, input_dims=100,
68+
hidden_dims=128, output_dims=128, pad=2)
69+
>>> input = torch.rand(10, 100, 512)
70+
>>> output = melresnet(input)
71+
"""
72+
73+
def __init__(self, res_blocks: int = 10,
74+
input_dims: int = 100,
75+
hidden_dims: int = 128,
76+
output_dims: int = 128,
77+
pad: int = 2) -> None:
78+
super().__init__()
79+
80+
kernel_size = pad * 2 + 1
81+
ResBlocks = []
82+
83+
for i in range(res_blocks):
84+
ResBlocks.append(_ResBlock(hidden_dims))
85+
86+
self.melresnet_model = nn.Sequential(
87+
nn.Conv1d(in_channels=input_dims, out_channels=hidden_dims, kernel_size=kernel_size, bias=False),
88+
nn.BatchNorm1d(hidden_dims),
89+
nn.ReLU(inplace=True),
90+
*ResBlocks,
91+
nn.Conv1d(in_channels=hidden_dims, out_channels=output_dims, kernel_size=1)
92+
)
93+
94+
def forward(self, x: Tensor) -> Tensor:
95+
r"""Pass the input through the _MelResNet layer.
96+
97+
Args:
98+
x: the input sequence to the _MelResNet layer (required).
99+
100+
Shape:
101+
- x: :math:`(N, S, T)`.
102+
- output: :math:`(N, P, T-2*pad)`.
103+
where N is the batch size, S is the number of input sequence,
104+
P is the number of ouput sequence, T is the length of input sequence.
105+
"""
106+
107+
return self.melresnet_model(x)

0 commit comments

Comments
 (0)