-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathConv_TasNet.py
More file actions
283 lines (245 loc) · 9.16 KB
/
Conv_TasNet.py
File metadata and controls
283 lines (245 loc) · 9.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import torch
import torch.nn as nn
class GlobalLayerNorm(nn.Module):
    '''
    Global Layer Normalization (gLN) for N x C x L tensors.

    Each sample is normalized with one mean and one variance computed jointly
    over the channel and time axes, then optionally scaled and shifted per
    channel.

    dim: number of channels C; the affine parameters have shape C x 1.
    eps: value added to the variance for numerical stability.
    elementwise_affine: when True, learnable per-channel weight (initialized
        to ones) and bias (initialized to zeros) are applied after
        normalization; when False, no parameters are registered.
    '''

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            # C x 1 so they broadcast over N x C x L
            self.weight = nn.Parameter(torch.ones(self.dim, 1))
            self.bias = nn.Parameter(torch.zeros(self.dim, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        '''
        x: N x C x L tensor. Returns a tensor of the same shape.
        Raises RuntimeError if x is not 3-D.
        '''
        if x.dim() != 3:
            # nn.Module instances have no __name__ attribute; use the class name.
            raise RuntimeError("{} accept 3D tensor as input".format(
                type(self).__name__))
        # mean, var: N x 1 x 1 -- one statistic per sample (population variance)
        mean = torch.mean(x, (1, 2), keepdim=True)
        var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
        # N x C x L
        if self.elementwise_affine:
            x = self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias
        else:
            x = (x - mean) / torch.sqrt(var + self.eps)
        return x
class CumulativeLayerNorm(nn.LayerNorm):
    '''
    Channel-wise layer normalization for N x C x L tensors.

    Despite the name, this is plain nn.LayerNorm over the channel axis,
    applied independently at every time step (statistics are not
    accumulated over time).

    dim: number of channels C to normalize over.
    elementwise_affine: whether nn.LayerNorm learns a per-channel
        weight and bias.
    '''

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine)

    def forward(self, x):
        # N x C x L -> N x L x C so LayerNorm sees channels last
        channels_last = x.transpose(1, 2)
        normed = super().forward(channels_last)
        # back to N x C x L
        return normed.transpose(1, 2)
def select_norm(norm, dim):
    '''
    Build the normalization layer named by ``norm`` for ``dim`` channels.

    norm: 'gln' -> GlobalLayerNorm, 'cln' -> CumulativeLayerNorm,
          'bn' -> nn.BatchNorm1d.
    dim: number of channels to normalize.
    Raises ValueError for any other value of ``norm``.
    '''
    if norm == 'gln':
        return GlobalLayerNorm(dim, elementwise_affine=True)
    if norm == 'cln':
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == 'bn':
        return nn.BatchNorm1d(dim)
    raise ValueError(
        "Unsupported norm {!r}; expected one of 'gln', 'cln', 'bn'".format(norm))
class Conv1D(nn.Conv1d):
    '''
    nn.Conv1d that also accepts 2-D input.

    A 2-D input (N x L) gets a channel axis inserted (N x 1 x L) before the
    convolution; 3-D input (N x C x L) is passed through unchanged.
    Constructor arguments are those of nn.Conv1d.
    '''

    def forward(self, x, squeeze=False):
        '''
        x: N x L or N x C x L tensor.
        squeeze: if True, remove every size-1 axis from the output
            (note: a batch of one also loses its batch axis).
        Raises RuntimeError if x is not 2-D or 3-D.
        '''
        if x.dim() not in (2, 3):
            # nn.Module instances have no __name__ attribute; use the class name.
            raise RuntimeError("{} accept 2/3D tensor as input".format(
                type(self).__name__))
        if x.dim() == 2:
            x = torch.unsqueeze(x, 1)
        x = super().forward(x)
        if squeeze:
            x = torch.squeeze(x)
        return x
class ConvTrans1D(nn.ConvTranspose1d):
    '''
    nn.ConvTranspose1d that also accepts 2-D input.

    This module can be seen as the gradient of Conv1d with respect to its
    input (a fractionally-strided convolution, sometimes called a
    "deconvolution" although it is not an actual deconvolution).
    A 2-D input (N x L) gets a channel axis inserted (N x 1 x L) first.
    Constructor arguments are those of nn.ConvTranspose1d.
    '''

    def forward(self, x, squeeze=False):
        '''
        x: N x L or N x C x L tensor.
        squeeze: if True, remove every size-1 axis from the output
            (note: a batch of one also loses its batch axis).
        Raises RuntimeError if x is not 2-D or 3-D.
        '''
        if x.dim() not in (2, 3):
            # nn.Module instances have no __name__ attribute; use the class name.
            raise RuntimeError("{} accept 2/3D tensor as input".format(
                type(self).__name__))
        if x.dim() == 2:
            x = torch.unsqueeze(x, 1)
        x = super().forward(x)
        if squeeze:
            x = torch.squeeze(x)
        return x
class Conv1D_Block(nn.Module):
    '''
    One 1-D convolutional block of the separation network.

    Only the residual path is implemented (no skip-connection output):

        x -> 1x1 conv (in_channels -> out_channels) -> PReLU -> norm
          -> depthwise dilated conv -> PReLU -> norm
          -> 1x1 conv (out_channels -> in_channels) -> + x

    in_channels: channels of the block input and output.
    out_channels: hidden channels inside the block.
    kernel_size: depthwise convolution kernel size.
    dilation: depthwise convolution dilation.
    norm: 'gln', 'cln' or 'bn' (see select_norm).
    causal: pad both sides by dilation * (kernel_size - 1) and drop the last
        ``pad`` output frames, so each output frame depends only on current
        and past input frames.

    NOTE(review): in non-causal mode the residual shapes only match when
    dilation * (kernel_size - 1) is even (e.g. odd kernel_size).
    '''

    def __init__(self, in_channels=256, out_channels=512,
                 kernel_size=3, dilation=1, norm='gln', causal=False):
        super(Conv1D_Block, self).__init__()
        # conv 1 x 1
        self.conv1x1 = Conv1D(in_channels, out_channels, 1)
        self.PReLU_1 = nn.PReLU()
        self.norm_1 = select_norm(norm, out_channels)
        # non-causal: symmetric "same" padding;
        # causal: full receptive-field padding, trailing frames trimmed in forward
        self.pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        # depthwise convolution
        self.dwconv = Conv1D(out_channels, out_channels, kernel_size,
                             groups=out_channels, padding=self.pad,
                             dilation=dilation)
        self.PReLU_2 = nn.PReLU()
        self.norm_2 = select_norm(norm, out_channels)
        self.Sc_conv = nn.Conv1d(out_channels, in_channels, 1, bias=True)
        self.causal = causal

    def forward(self, x):
        '''
        x: N x in_channels x L tensor. Returns a tensor of the same shape.
        '''
        # N x O_C x L
        c = self.conv1x1(x)
        c = self.PReLU_1(c)
        c = self.norm_1(c)
        # causal: N x O_C x (L+pad); non-causal: N x O_C x L
        c = self.dwconv(c)
        # Trim the look-ahead frames; guard pad == 0 because c[..., :-0] is empty.
        if self.causal and self.pad > 0:
            c = c[:, :, :-self.pad]
        # N x O_C x L -- second activation + norm after the depthwise conv
        c = self.PReLU_2(c)
        c = self.norm_2(c)
        # N x I_C x L
        c = self.Sc_conv(c)
        return x + c
class ConvTasNet(nn.Module):
    '''
    ConvTasNet module

    N: number of filters in the encoder/decoder (autoencoder)
    L: length of the filters (in samples); encoder/decoder stride is L // 2
    B: number of channels in the bottleneck and the residual paths' 1 x 1-conv blocks
    H: number of channels in the convolutional blocks
    P: kernel size in the convolutional blocks
    X: number of convolutional blocks in each repeat
       (block i uses dilation 2**i)
    R: number of repeats
    norm: normalization inside the blocks ('gln', 'cln' or 'bn')
    num_spks: number of sources to separate
    activate: mask activation, one of 'relu', 'sigmoid', 'softmax'
        ('softmax' normalizes the masks across speakers)
    causal: use causal depthwise convolutions in the separation blocks
    '''

    def __init__(self,
                 N=512,
                 L=16,
                 B=128,
                 H=512,
                 P=3,
                 X=8,
                 R=3,
                 norm="gln",
                 num_spks=2,
                 activate="relu",
                 causal=False):
        super(ConvTasNet, self).__init__()
        # n x 1 x T => n x N x T'
        self.encoder = Conv1D(1, N, L, stride=L // 2, padding=0)
        # n x N x T' channel-wise layer norm of the encoder output
        self.LayerN_S = select_norm('cln', N)
        # n x N x T' => n x B x T' bottleneck 1 x 1 conv
        self.BottleN_S = Conv1D(N, B, 1)
        # Separation stack: n x B x T' => n x B x T'
        self.separation = self._Sequential_repeat(
            R, X, in_channels=B, out_channels=H, kernel_size=P, norm=norm,
            causal=causal)
        # n x B x T' => n x num_spks*N x T'
        self.gen_masks = Conv1D(B, num_spks * N, 1)
        # n x N x T' => n x 1 x T
        self.decoder = ConvTrans1D(N, 1, L, stride=L // 2)
        # mask activation; softmax runs over dim 0, the stacked-speaker axis
        active_f = {
            'relu': nn.ReLU(),
            'sigmoid': nn.Sigmoid(),
            'softmax': nn.Softmax(dim=0)
        }
        self.activation_type = activate
        self.activation = active_f[activate]
        self.num_spks = num_spks

    def _Sequential_block(self, num_blocks, **block_kwargs):
        '''
        Build one repeat: num_blocks Conv1D_Blocks with dilations 1, 2, 4, ...

        num_blocks: how many blocks in the repeat
        **block_kwargs: remaining Conv1D_Block arguments (all but dilation)
        '''
        Conv1D_Block_lists = [Conv1D_Block(
            **block_kwargs, dilation=(2 ** i)) for i in range(num_blocks)]
        return nn.Sequential(*Conv1D_Block_lists)

    def _Sequential_repeat(self, num_repeats, num_blocks, **block_kwargs):
        '''
        Stack num_repeats independent repeats built by _Sequential_block.

        num_repeats: number of repeats
        num_blocks: number of blocks in every repeat
        **block_kwargs: Conv1D_Block arguments (all but dilation)
        '''
        repeats_lists = [self._Sequential_block(
            num_blocks, **block_kwargs) for _ in range(num_repeats)]
        return nn.Sequential(*repeats_lists)

    def forward(self, x):
        '''
        Separate a mixture waveform.

        x: 1-D (T) or 2-D (n x T) tensor; a 1-D input is treated as a batch
            of one.
        Returns a list of num_spks decoded tensors, one per speaker. The
        decoder output is squeezed of every size-1 axis, so a batch of one
        yields 1-D tensors.
        Raises RuntimeError for inputs with 3 or more dimensions.
        '''
        if x.dim() >= 3:
            # nn.Module instances have no __name__ attribute; use the class name.
            raise RuntimeError(
                "{} accept 1/2D tensor as input, but got {:d}".format(
                    type(self).__name__, x.dim()))
        if x.dim() == 1:
            x = torch.unsqueeze(x, 0)
        # n x T => n x N x T'  (Conv1D inserts the channel axis)
        w = self.encoder(x)
        # n x N x T' => n x B x T'
        e = self.LayerN_S(w)
        e = self.BottleN_S(e)
        # n x B x T' => n x B x T'
        e = self.separation(e)
        # n x B x T' => n x num_spks*N x T'
        m = self.gen_masks(e)
        # tuple of num_spks tensors, each n x N x T'
        m = torch.chunk(m, chunks=self.num_spks, dim=1)
        # num_spks x n x N x T'
        m = self.activation(torch.stack(m, dim=0))
        # apply each speaker's mask to the encoder representation
        d = [w * m[i] for i in range(self.num_spks)]
        # decode each masked representation back to a waveform
        s = [self.decoder(d[i], squeeze=True) for i in range(self.num_spks)]
        return s
def check_parameters(net):
    '''
    Count the parameters of ``net`` and return the total in millions
    (a parameter count scaled by 1e-6, not a size in megabytes).
    '''
    total = 0
    for tensor in net.parameters():
        total += tensor.numel()
    return total / 10**6
def test_convtasnet():
    '''
    Smoke test: run a default ConvTasNet on one random 320-sample waveform,
    print the parameter count (in millions) and the second speaker's
    output shape.
    '''
    mixture = torch.randn(320)
    model = ConvTasNet()
    estimates = model(mixture)
    n_params = check_parameters(model)
    print(str(n_params) + ' Mb')
    print(estimates[1].shape)


if __name__ == "__main__":
    test_convtasnet()