Skip to content

Commit 55e9750

Browse files
committed
switch to adam and set affine to false for the input batchnorm layer.
1 parent 926f7d1 commit 55e9750

File tree

3 files changed

+10
-13
lines changed

3 files changed

+10
-13
lines changed

egs/aishell/s10/chain/model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def __init__(self,
136136
self.has_LDA = True
137137
else:
138138
logging.info('replace LDA with BatchNorm')
139-
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3)
139+
self.input_batch_norm = nn.BatchNorm1d(num_features=feat_dim * 3,
140+
affine=False)
140141
self.has_LDA = False
141142

142143
def forward(self, x):
@@ -218,9 +219,9 @@ def constrain_orthonormal(self):
218219

219220
if __name__ == '__main__':
220221
logging.basicConfig(
221-
level=logging.DEBUG,
222-
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
223-
)
222+
level=logging.DEBUG,
223+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
224+
)
224225
feat_dim = 43
225226
output_dim = 4344
226227
model = ChainModel(feat_dim=feat_dim, output_dim=output_dim)

egs/aishell/s10/chain/tdnnf_layer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,11 @@ def forward(self, x):
187187

188188
# save it for skip connection
189189
input_x = x
190-
logging.debug(f"input_x shape is {input_x.shape}")
190+
191191
x = self.linear(x)
192-
logging.debug(f"x shape after linear is {x.shape}")
193192
# at this point, x is [N, C, T]
194193

195194
x = self.affine(x)
196-
logging.debug(f"x shape after affine is {x.shape}")
197195
# at this point, x is [N, C, T]
198196

199197
x = F.relu(x)
@@ -298,7 +296,6 @@ def _test_factorized_tdnn():
298296
assert y.size(2) == math.ceil(math.ceil((T - 3)) - 3)
299297

300298

301-
302299
if __name__ == '__main__':
303300
torch.manual_seed(20200130)
304301
_test_factorized_tdnn()

egs/aishell/s10/chain/train.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,11 @@ def main():
183183
egs_left_context=args.egs_left_context,
184184
egs_right_context=args.egs_right_context)
185185

186-
optimizer = optim.SGD(model.parameters(),
187-
lr=learning_rate,
188-
momentum=0.9,
189-
weight_decay=args.l2_regularize)
186+
optimizer = optim.Adam(model.parameters(),
187+
lr=learning_rate,
188+
weight_decay=args.l2_regularize)
190189

191-
scheduler = MultiStepLR(optimizer, milestones=[1, 3, 5], gamma=0.5)
190+
scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4, 5], gamma=0.5)
192191
criterion = KaldiChainObjfFunction.apply
193192

194193
tf_writer = SummaryWriter(log_dir='{}/tensorboard'.format(args.dir))

0 commit comments

Comments
 (0)