Skip to content

Commit 16052ce

Browse files
committed
add skip connection.
1 parent 27843a9 commit 16052ce

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

egs/aishell/s10/chain/tdnnf_layer.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,21 @@ class FactorizedTDNN(nn.Module):
6262
https://github.com/kaldi-asr/kaldi/blob/master/src/nnet3/nnet-utils.cc#L982
6363
'''
6464

65-
def __init__(self, dim, bottleneck_dim, time_stride):
65+
def __init__(self, dim, bottleneck_dim, time_stride, bypass_scale=0.66):
6666
super().__init__()
67+
6768
assert time_stride in [0, 1]
69+
assert abs(bypass_scale) <= 1
70+
71+
self.bypass_scale = bypass_scale
6872

6973
if time_stride == 0:
7074
kernel_size = 1
7175
else:
7276
kernel_size = 3
7377

78+
self.kernel_size = kernel_size
79+
7480
# WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer
7581
# and [0, 1] for the second affine layer;
7682
# We use [-1, 0, 1] for the first linear layer
@@ -90,6 +96,10 @@ def __init__(self, dim, bottleneck_dim, time_stride):
9096
def forward(self, x):
9197
# input x is of shape: [batch_size, feat_dim, seq_len] = [N, C, T]
9298
assert x.ndim == 3
99+
100+
# save it for skip connection
101+
input_x = x
102+
93103
x = self.conv(x)
94104
# at this point, x is [N, C, T]
95105

@@ -109,6 +119,10 @@ def forward(self, x):
109119
# TODO(fangjun): implement GeneralDropoutComponent in PyTorch
110120

111121
# at this point, x is [N, C, T]
122+
if self.kernel_size == 3:
123+
x = self.bypass_scale * input_x[:, :, 1:-1] + x
124+
else:
125+
x = self.bypass_scale * input_x + x
112126
return x
113127

114128
def constraint_orthonormal(self):
@@ -176,6 +190,10 @@ def _test_factorized_tdnn():
176190
y = model(x)
177191
assert y.size(2) == T - 2
178192

193+
model = FactorizedTDNN(dim=C, bottleneck_dim=2, time_stride=0)
194+
y = model(x)
195+
assert y.size(2) == T
196+
179197

180198
if __name__ == '__main__':
181199
torch.manual_seed(20200130)

0 commit comments

Comments (0)