@@ -62,15 +62,21 @@ class FactorizedTDNN(nn.Module):
     https://github.com/kaldi-asr/kaldi/blob/master/src/nnet3/nnet-utils.cc#L982
     '''
 
-    def __init__(self, dim, bottleneck_dim, time_stride):
+    def __init__(self, dim, bottleneck_dim, time_stride, bypass_scale=0.66):
         super().__init__()
+
         assert time_stride in [0, 1]
+        assert abs(bypass_scale) <= 1
+
+        self.bypass_scale = bypass_scale
 
         if time_stride == 0:
             kernel_size = 1
         else:
             kernel_size = 3
 
+        self.kernel_size = kernel_size
+
         # WARNING(fangjun): kaldi uses [-1, 0] for the first linear layer
         # and [0, 1] for the second affine layer;
         # We use [-1, 0, 1] for the first linear layer
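
The new bypass_scale argument turns each layer into a scaled residual block, y = bypass_scale * x + F(x); the assert abs(bypass_scale) <= 1 keeps the identity path from amplifying activations, and the 0.66 default matches the bypass scale Kaldi's TDNN-F recipes typically use. A minimal sketch of the idea, with a bare kernel-size-1 Conv1d standing in for the factorized layers this class actually builds:

    import torch
    import torch.nn as nn

    bypass_scale = 0.66                     # default added in this diff
    layer = nn.Conv1d(4, 4, kernel_size=1)  # hypothetical stand-in, preserves T

    x = torch.randn(2, 4, 10)               # [N, C, T]
    y = bypass_scale * x + layer(x)         # scaled skip connection
    assert y.shape == x.shape
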
@@ -90,6 +96,10 @@ def __init__(self, dim, bottleneck_dim, time_stride):
     def forward(self, x):
         # input x is of shape: [batch_size, feat_dim, seq_len] = [N, C, T]
         assert x.ndim == 3
+
+        # save it for skip connection
+        input_x = x
+
         x = self.conv(x)
         # at this point, x is [N, C, T]
 
@@ -109,6 +119,10 @@ def forward(self, x):
         # TODO(fangjun): implement GeneralDropoutComponent in PyTorch
 
         # at this point, x is [N, C, T]
+        if self.kernel_size == 3:
+            x = self.bypass_scale * input_x[:, :, 1:-1] + x
+        else:
+            x = self.bypass_scale * input_x + x
         return x
 
     def constraint_orthonormal(self):
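
Why the new bypass trims input_x[:, :, 1:-1] in the kernel_size == 3 branch: a Conv1d with kernel size 3 and no padding maps T input frames to T - 2 output frames, each centered on input frames 1 through T - 2, so the residual must drop one frame at each edge to stay aligned. A shape-only sketch, again with a plain Conv1d standing in for the factorized layers:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 4, 10)              # [N, C, T] with T = 10
    conv = nn.Conv1d(4, 4, kernel_size=3)  # no padding, so T shrinks by 2
    out = conv(x)
    assert out.size(2) == x.size(2) - 2

    res = 0.66 * x[:, :, 1:-1] + out       # trim one frame per edge to align
    assert res.shape == out.shape
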
@@ -176,6 +190,10 @@ def _test_factorized_tdnn():
     y = model(x)
     assert y.size(2) == T - 2
 
+    model = FactorizedTDNN(dim=C, bottleneck_dim=2, time_stride=0)
+    y = model(x)
+    assert y.size(2) == T
+
 
 if __name__ == '__main__':
     torch.manual_seed(20200130)
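
The added test pins down the time_stride=0 branch, where kernel_size is 1, the convolutions preserve T, and the bypass adds input_x untrimmed; the existing time_stride=1 test above covers the trimmed branch. Both branches can be checked together with a sketch like this (the shapes below are hypothetical, assuming the N, C, T defined earlier in _test_factorized_tdnn):

    import torch

    N, C, T = 1, 8, 10                     # hypothetical shapes for the check
    x = torch.randn(N, C, T)

    for time_stride, expected_len in [(1, T - 2), (0, T)]:
        model = FactorizedTDNN(dim=C, bottleneck_dim=2, time_stride=time_stride)
        assert model(x).size(2) == expected_len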