Commit 1ffb881

[egs] support dropout in pytorch training (#3988)
1 parent c9a68a5 commit 1ffb881

File tree

8 files changed: +141 -70 lines changed


egs/aishell/s10/chain/model.py

Lines changed: 6 additions & 24 deletions

@@ -13,6 +13,7 @@
 from tdnnf_layer import FactorizedTDNN
 from tdnnf_layer import OrthonormalLinear
 from tdnnf_layer import PrefinalLayer
+from tdnnf_layer import TDNN
 
 
 def get_chain_model(feat_dim,
@@ -101,13 +102,8 @@ def __init__(self,
         num_layers = len(kernel_size_list)
 
         input_dim = feat_dim * 3 + ivector_dim
-        # tdnn1_affine requires [N, T, C]
-        self.tdnn1_affine = nn.Linear(in_features=input_dim,
-                                      out_features=hidden_dim)
-
-        # tdnn1_batchnorm requires [N, C, T]
-        self.tdnn1_batchnorm = nn.BatchNorm1d(num_features=hidden_dim,
-                                              affine=False)
+
+        self.tdnn1 = TDNN(input_dim=input_dim, hidden_dim=hidden_dim)
 
         tdnnfs = []
         for i in range(num_layers):
@@ -156,7 +152,7 @@ def __init__(self,
 
         self.register_forward_pre_hook(constrain_orthonormal_hook)
 
-    def forward(self, x):
+    def forward(self, x, dropout=0.):
         # input x is of shape: [batch_size, seq_len, feat_dim] = [N, T, C]
         assert x.ndim == 3
 
@@ -178,25 +174,11 @@ def forward(self, x):
 
         # at this point, x is [N, C, T]
 
-        x = x.permute(0, 2, 1)
-
-        # at this point, x is [N, T, C]
-
-        x = self.tdnn1_affine(x)
-
-        # at this point, x is [N, T, C]
-
-        x = F.relu(x)
-
-        x = x.permute(0, 2, 1)
-
-        # at this point, x is [N, C, T]
-
-        x = self.tdnn1_batchnorm(x)
+        x = self.tdnn1(x, dropout=dropout)
 
         # tdnnf requires input of shape [N, C, T]
         for i in range(len(self.tdnnfs)):
-            x = self.tdnnfs[i](x)
+            x = self.tdnnfs[i](x, dropout=dropout)
 
         # at this point, x is [N, C, T]
 
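The permute/Linear/permute/BatchNorm sequence removed above is folded into the Conv1d-based TDNN layer added in tdnnf_layer.py below. The two are interchangeable because a kernel_size=1 Conv1d over [N, C, T] computes the same per-frame affine map as an nn.Linear over [N, T, C]. A minimal sketch (not part of the commit) verifying that equivalence:

    import torch
    import torch.nn as nn

    N, C_in, C_out, T = 2, 5, 7, 11
    x = torch.randn(N, C_in, T)

    conv = nn.Conv1d(in_channels=C_in, out_channels=C_out, kernel_size=1)
    linear = nn.Linear(in_features=C_in, out_features=C_out)

    # copy the weights so both layers compute the same function
    with torch.no_grad():
        linear.weight.copy_(conv.weight.squeeze(-1))
        linear.bias.copy_(conv.bias)

    y_conv = conv(x)                                      # [N, C_out, T]
    y_lin = linear(x.permute(0, 2, 1)).permute(0, 2, 1)   # [N, C_out, T]
    assert torch.allclose(y_conv, y_lin, atol=1e-6)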

egs/aishell/s10/chain/options.py

Lines changed: 6 additions & 0 deletions

@@ -105,6 +105,12 @@ def _set_training_args(parser):
                         dest='l2_regularize',
                         help='l2 regularize',
                         type=float)
+
+    parser.add_argument('--train.dropout-schedule',
+                        dest='dropout_schedule',
+                        help='dropout schedule',
+                        type=str,
+
 
     parser.add_argument('--train.xent-regularize',
                         dest='xent_regularize',
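The schedule string is consumed in train.py via _get_dropout_proportions (imported from libs.nnet3.train.dropout_schedule). A simplified sketch of the usual nnet3 schedule semantics — comma-separated proportion[@data-fraction] points, first at fraction 0 and last at fraction 1, linearly interpolated over training progress. This is an illustration under those assumptions, not the Kaldi parser:

    # simplified sketch, not the Kaldi implementation
    def dropout_at(schedule, data_fraction):
        points = []
        for i, entry in enumerate(schedule.split(',')):
            if '@' in entry:
                prop, frac = entry.split('@')
            else:
                prop, frac = entry, (0.0 if i == 0 else 1.0)
            points.append((float(frac), float(prop)))
        # walk the piecewise-linear segments and interpolate
        for (f0, p0), (f1, p1) in zip(points, points[1:]):
            if f0 <= data_fraction <= f1:
                return p0 + (p1 - p0) * (data_fraction - f0) / (f1 - f0)
        return points[-1][1]

    print(dropout_at('0,0.5@0.5,0', 0.25))  # 0.25: ramps up to 0.5 halfway through, then back down
    print(dropout_at('0,0.5@0.5,0', 0.75))  # 0.25

In train.py, `_, dropout = _get_dropout_proportions(dropout_schedule, data_fraction)[0]` takes the first returned pair and uses its second element as the proportion applied network-wide.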

egs/aishell/s10/chain/tdnnf_layer.py

Lines changed: 56 additions & 3 deletions

@@ -51,6 +51,28 @@ def _constrain_orthonormal_internal(M):
     return M
 
 
+class SharedDimScaleDropout(nn.Module):
+    def __init__(self, dim=1):
+        '''
+        Continuous scaled dropout that is const over chosen dim (usually across time)
+        Multiplies inputs by random mask taken from Uniform([1 - 2\alpha, 1 + 2\alpha])
+        '''
+        super().__init__()
+        self.dim = dim
+        self.register_buffer('mask', torch.tensor(0.))
+
+    def forward(self, x, alpha=0.0):
+        if self.training and alpha > 0.:
+            # sample mask from uniform dist with dim of length 1 in self.dim and then repeat to match size
+            tied_mask_shape = list(x.shape)
+            tied_mask_shape[self.dim] = 1
+            repeats = [1 if i != self.dim else x.shape[self.dim]
+                       for i in range(len(x.shape))]
+            return x * self.mask.repeat(tied_mask_shape).uniform_(1 - 2*alpha, 1 + 2*alpha).repeat(repeats)
+        # expected value of dropout mask is 1 so no need to scale outputs like vanilla dropout
+        return x
+
+
 class OrthonormalLinear(nn.Module):
 
     def __init__(self, dim, bottleneck_dim, kernel_size):
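The registered mask buffer keeps the random scale on the module's device and dtype; the repeat-then-uniform_-then-repeat dance draws one scale per position outside self.dim and ties it along self.dim. A minimal standalone sketch (not from the commit, using broadcasting instead of repeat) of the same tied-mask idea:

    import torch

    torch.manual_seed(0)
    alpha = 0.2
    x = torch.ones(2, 3, 5)   # [N, C, T]
    dim = 2                   # tie the mask across time

    mask_shape = list(x.shape)
    mask_shape[dim] = 1
    scale = torch.empty(mask_shape).uniform_(1 - 2 * alpha, 1 + 2 * alpha)  # one scale per (N, C)
    y = x * scale             # broadcasting repeats the scale along dim

    # every frame of a given (n, c) row is scaled identically
    assert torch.allclose(y[:, :, 0], y[:, :, 1])
    # E[scale] = 1, so no rescaling is needed at eval time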
@@ -137,6 +159,35 @@ def forward(self, x):
         return x
 
 
+class TDNN(nn.Module):
+    '''
+    This class implements the following topology in kaldi:
+    relu-batchnorm-dropout-layer name=tdnn1 dropout-per-dim-continuous=true dim=1024
+    '''
+
+    def __init__(self, input_dim, hidden_dim):
+        super().__init__()
+        # affine conv1d requires [N, C, T]
+        self.affine = nn.Conv1d(in_channels=input_dim,
+                                out_channels=hidden_dim,
+                                kernel_size=1)
+
+        # tdnn1_batchnorm requires [N, C, T]
+        self.batchnorm = nn.BatchNorm1d(num_features=hidden_dim,
+                                        affine=False)
+
+        self.dropout = SharedDimScaleDropout(dim=2)
+
+    def forward(self, x, dropout=0.):
+        # input x is of shape: [batch_size, feat_dim, seq_len] = [N, C, T]
+        x = self.affine(x)
+        x = F.relu(x)
+        x = self.batchnorm(x)
+        x = self.dropout(x, alpha=dropout)
+        # return shape is [N, C, T]
+        return x
+
+
 class FactorizedTDNN(nn.Module):
     '''
     This class implements the following topology in kaldi:
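A hedged usage sketch of the new TDNN front end; the dimensions are invented for illustration (not the recipe's), and it assumes egs/aishell/s10/chain is on PYTHONPATH so tdnnf_layer imports:

    import torch
    from tdnnf_layer import TDNN

    tdnn = TDNN(input_dim=129, hidden_dim=64)   # invented sizes
    x = torch.randn(4, 129, 150)                # [N, C, T]

    tdnn.train()
    y = tdnn(x, dropout=0.1)   # SharedDimScaleDropout draws one scale per (N, C) pair
    assert y.shape == (4, 64, 150)

    tdnn.eval()
    y = tdnn(x, dropout=0.1)   # dropout is a no-op outside training mode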
@@ -178,7 +229,9 @@ def __init__(self,
         # batchnorm requires [N, C, T]
         self.batchnorm = nn.BatchNorm1d(num_features=dim, affine=False)
 
-    def forward(self, x):
+        self.dropout = SharedDimScaleDropout(dim=2)
+
+    def forward(self, x, dropout=0.):
         # input x is of shape: [batch_size, feat_dim, seq_len] = [N, C, T]
         assert x.ndim == 3
 
@@ -199,9 +252,9 @@ def forward(self, x):
 
         # at this point, x is [N, C, T]
 
-        # TODO(fangjun): implement GeneralDropoutComponent in PyTorch
+        x = self.dropout(x, alpha=dropout)
 
-        if self.linear.kernel_size == 3:
+        if self.linear.kernel_size > 1:
             x = self.bypass_scale * input_x[:, :, self.s:-self.s:self.s] + x
         else:
             x = self.bypass_scale * input_x[:, :, ::self.s] + x

egs/aishell/s10/chain/train.py

Lines changed: 24 additions & 10 deletions

@@ -17,6 +17,7 @@
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_value_
+from torch.utils.data.distributed import DistributedSampler
 from torch.utils.tensorboard import SummaryWriter
 
 import kaldi
@@ -30,10 +31,11 @@
 from common import setup_logger
 from device_utils import allocate_gpu_devices
 from egs_dataset import get_egs_dataloader
+from libs.nnet3.train.dropout_schedule import _get_dropout_proportions
 from model import get_chain_model
 from options import get_args
 
-def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None):
+def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None, dropout=0.):
     total_objf = 0.
     total_weight = 0.
     total_frames = 0. # for display only
@@ -48,7 +50,7 @@ def get_objf(batch, model, device, criterion, opts, den_graph, training, optimiz
         # at this point, feats is [N, T, C]
         feats = feats.to(device)
         if training:
-            nnet_output, xent_output = model(feats)
+            nnet_output, xent_output = model(feats, dropout=dropout)
         else:
             with torch.no_grad():
                 nnet_output, xent_output = model(feats)
@@ -106,17 +108,20 @@ def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
     return total_objf, total_weight, total_frames
 
 
-def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
-                    criterion, current_epoch, opts, den_graph, tf_writer, rank):
+def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer, criterion,
+                    current_epoch, num_epochs, opts, den_graph, tf_writer, rank, dropout_schedule):
     total_objf = 0.
     total_weight = 0.
     total_frames = 0. # for display only
 
     model.train()
-
     for batch_idx, batch in enumerate(dataloader):
+        data_fraction = (batch_idx + current_epoch *
+                         len(dataloader)) / (len(dataloader) * num_epochs)
+        _, dropout = _get_dropout_proportions(
+            dropout_schedule, data_fraction)[0]
         curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
-            batch, model, device, criterion, opts, den_graph, True, optimizer)
+            batch, model, device, criterion, opts, den_graph, True, optimizer, dropout=dropout)
 
         total_objf += curr_batch_objf
         total_weight += curr_batch_weight
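Note that data_fraction is the fraction of the whole training run seen so far, not the fraction within the current epoch: with num_epochs=3 and 1000 batches per epoch, batch 500 of epoch 1 gives (500 + 1 * 1000) / (1000 * 3) = 0.5. A small self-contained illustration (sizes invented) of how the fraction sweeps from 0 toward 1 across the run:

    num_epochs = 3
    batches_per_epoch = 4
    for current_epoch in range(num_epochs):
        for batch_idx in range(batches_per_epoch):
            data_fraction = (batch_idx + current_epoch *
                             batches_per_epoch) / (batches_per_epoch * num_epochs)
            print(current_epoch, batch_idx, round(data_fraction, 3))
    # epoch 0 starts at 0.0; the final batch of epoch 2 reaches 11/12 ~= 0.917

The schedule is therefore indexed by global progress, so the dropout proportion follows one continuous curve over all epochs.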
@@ -159,6 +164,11 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
                'train/current_batch_average_objf',
                curr_batch_objf / curr_batch_weight,
                batch_idx + current_epoch * len(dataloader))
+
+            tf_writer.add_scalar(
+                'train/current_dropout',
+                dropout,
+                batch_idx + current_epoch * len(dataloader))
 
            state_dict = model.state_dict()
            for key, value in state_dict.items():
@@ -205,10 +215,10 @@ def main():
 def process_job(learning_rate, device_id=None, local_rank=None):
     args = get_args()
     if local_rank != None:
-        setup_logger('{}/train/logs/log-train-rank-{}'.format(args.dir, local_rank),
+        setup_logger('{}/logs/log-train-rank-{}'.format(args.dir, local_rank),
                      args.log_level)
     else:
-        setup_logger('{}/train/logs/log-train-single-GPU'.format(args.dir), args.log_level)
+        setup_logger('{}/logs/log-train-single-GPU'.format(args.dir), args.log_level)
 
     logging.info(' '.join(sys.argv))
 
@@ -249,7 +259,6 @@ def process_job(learning_rate, device_id=None, local_rank=None):
     opts.leaky_hmm_coefficient = args.leaky_hmm_coefficient
 
     den_graph = chain.DenominatorGraph(fst=den_fst, num_pdfs=args.output_dim)
-
 
     model = get_chain_model(
         feat_dim=args.feat_dim,
@@ -325,6 +334,9 @@ def process_job(learning_rate, device_id=None, local_rank=None):
 
         if tf_writer:
             tf_writer.add_scalar('learning_rate', curr_learning_rate, epoch)
+
+        if dataloader.sampler and isinstance(dataloader.sampler, DistributedSampler):
+            dataloader.sampler.set_epoch(epoch)
 
         objf = train_one_epoch(dataloader=dataloader,
                                valid_dataloader=valid_dataloader,
@@ -333,10 +345,12 @@ def process_job(learning_rate, device_id=None, local_rank=None):
                                optimizer=optimizer,
                                criterion=criterion,
                                current_epoch=epoch,
+                               num_epochs=num_epochs,
                                opts=opts,
                                den_graph=den_graph,
                                tf_writer=tf_writer,
-                               rank=local_rank)
+                               rank=local_rank,
+                               dropout_schedule=args.dropout_schedule)
 
         if best_objf is None:
             best_objf = objf
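The added set_epoch call matters for multi-GPU runs: DistributedSampler derives its shuffling seed from the epoch number, so without it every epoch replays the same sample order on every rank. A minimal sketch of that standard PyTorch pattern, using a placeholder dataset and explicit num_replicas/rank so it runs without init_process_group:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # placeholder dataset; in the recipe this is the egs dataloader's dataset
    dataset = TensorDataset(torch.arange(100).float())
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=10, sampler=sampler)

    for epoch in range(3):
        sampler.set_epoch(epoch)   # reseeds the shuffle; omit it and every epoch repeats the same order
        for batch in loader:
            pass

The isinstance guard in the diff keeps single-GPU runs, which do not use a DistributedSampler, unaffected.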

egs/aishell/s10/cmd.sh

Lines changed: 2 additions & 2 deletions

@@ -13,6 +13,6 @@
 export train_cmd="queue.pl -q all.q --mem 4G"
 export decode_cmd="queue.pl -q all.q --mem 4G"
 export mkgraph_cmd="queue.pl -q all.q --mem 8G"
-export cuda_train_cmd="queue.pl -q v100.q --mem 4G"
-export cuda_inference_cmd="queue.pl -q v100.q --mem 4G"
+export cuda_train_cmd="queue.pl -q g.q --mem 4G"
+export cuda_inference_cmd="queue.pl -q g.q --mem 4G"
 

egs/aishell/s10/local/decode.sh

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ lattice_beam=4.0
 max_active=7000 # limit of active tokens
 max_mem=50000000 # approx. limit to memory consumption during minimization in bytes
 min_active=200
-num_threads=20
+num_threads=10
 post_decode_acwt=10 # can be used in 'chain' systems to scale acoustics by 10
 
 . ./path.sh
