Skip to content

Commit cc63cca

Browse files
authored
[scripts] support device id provided by user (#3968)
1 parent 63c732b commit cc63cca

File tree

4 files changed

+109
-112
lines changed

4 files changed

+109
-112
lines changed

egs/aishell/s10/chain/inference.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,16 @@ def main():
3030
logging.warning('No GPU detected! Use CPU for inference.')
3131
device = torch.device('cpu')
3232
else:
33-
devices = allocate_gpu_devices(1)
34-
if len(devices) != 1:
35-
logging.error('Allocate GPU failed!')
36-
sys.exit(-1)
37-
device = torch.device('cuda', devices[0][0])
33+
if args.device_ids != None and len(args.device_ids) > 0:
34+
device_id = args.device_ids[0]
35+
else:
36+
devices = allocate_gpu_devices(1)
37+
if len(devices) != 1:
38+
logging.error('Allocate GPU failed!')
39+
sys.exit(-1)
40+
device_id = devices[0][0]
41+
logging.info('device: {}'.format(device_id))
42+
device = torch.device('cuda', device_id)
3843

3944
model = get_chain_model(
4045
feat_dim=args.feat_dim,

egs/aishell/s10/chain/options.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def _set_training_args(parser):
125125
parser.add_argument('--train.ddp.multiple-machine',
126126
dest='multiple_machine',
127127
help="use ddp with multiple machines",
128-
type=_str2bool)
128+
type=_str2bool,
129+
default=False)
129130

130131

131132
parser.add_argument('--train.ddp.init-method',
@@ -184,9 +185,14 @@ def _check_args(args):
184185
if args.lda_mat_filename:
185186
assert os.path.isfile(args.lda_mat_filename)
186187

187-
# although -1 means to use CPU in `kaldi.SelectGpuDevice()`
188-
# we do NOT want to use CPU here so we require it to be >= 0
189-
# assert args.device_id >= 0
188+
if args.device_ids != None:
189+
# do NOT support assigning GPU when training with multiple machines
190+
assert args.multiple_machine == False
191+
args.device_ids = [int(i) for i in args.device_ids.split(', ')]
192+
# although -1 means to use CPU in `kaldi.SelectGpuDevice()`
193+
# we do NOT want to use CPU here so we require it to be >= 0
194+
for i in args.device_ids:
195+
assert i >= 0
190196

191197
assert args.feat_dim > 0
192198
assert args.output_dim > 0
@@ -224,11 +230,11 @@ def get_args():
224230
required=True,
225231
type=str)
226232

227-
parser.add_argument('--device-id',
228-
dest='device_id',
229-
help='GPU device id',
233+
parser.add_argument('--device-ids',
234+
dest='device_ids',
235+
help='GPU device ids',
230236
required=False,
231-
type=int)
237+
type=str)
232238

233239
parser.add_argument('--is-training',
234240
dest='is_training',

egs/aishell/s10/chain/train.py

Lines changed: 83 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -33,114 +33,94 @@
3333
from model import get_chain_model
3434
from options import get_args
3535

36-
def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
36+
def get_objf(batch, model, device, criterion, opts, den_graph, training, optimizer=None):
3737
total_objf = 0.
3838
total_weight = 0.
3939
total_frames = 0. # for display only
40+
41+
key_list, feature_list, supervision_list = batch
42+
assert len(key_list) == len(feature_list) == len(supervision_list)
43+
batch_size = len(key_list)
44+
for n in range(batch_size):
45+
feats = feature_list[n]
46+
assert feats.ndim == 3
47+
48+
# at this point, feats is [N, T, C]
49+
feats = feats.to(device)
50+
if training:
51+
nnet_output, xent_output = model(feats)
52+
else:
53+
with torch.no_grad():
54+
nnet_output, xent_output = model(feats)
4055

41-
model.eval()
56+
# at this point, nnet_output is: [N, T, C]
57+
# refer to kaldi/src/chain/chain-training.h
58+
# the output should be organized as
59+
# [all sequences for frame 0]
60+
# [all sequences for frame 1]
61+
# [etc.]
62+
nnet_output = nnet_output.permute(1, 0, 2)
63+
# at this point, nnet_output is: [T, N, C]
64+
nnet_output = nnet_output.contiguous().view(-1,
65+
nnet_output.shape[-1])
66+
67+
# at this point, xent_output is: [N, T, C]
68+
xent_output = xent_output.permute(1, 0, 2)
69+
# at this point, xent_output is: [T, N, C]
70+
xent_output = xent_output.contiguous().view(-1,
71+
xent_output.shape[-1])
72+
objf_l2_term_weight = criterion(opts, den_graph,
73+
supervision_list[n], nnet_output,
74+
xent_output)
75+
objf = objf_l2_term_weight[0]
76+
if training:
77+
optimizer.zero_grad()
78+
objf.backward()
79+
clip_grad_value_(model.parameters(), 5.0)
80+
optimizer.step()
4281

43-
for batch_idx, batch in enumerate(dataloader):
44-
key_list, feature_list, supervision_list = batch
82+
objf_l2_term_weight = objf_l2_term_weight.detach().cpu()
4583

46-
assert len(key_list) == len(feature_list) == len(supervision_list)
47-
batch_size = len(key_list)
84+
total_objf += objf_l2_term_weight[0].item()
85+
total_weight += objf_l2_term_weight[2].item()
86+
num_frames = nnet_output.shape[0]
87+
total_frames += num_frames
4888

49-
for n in range(batch_size):
50-
feats = feature_list[n]
51-
assert feats.ndim == 3
89+
return total_objf, total_weight, total_frames
5290

53-
# at this point, feats is [N, T, C]
54-
feats = feats.to(device)
5591

56-
with torch.no_grad():
57-
nnet_output, xent_output = model(feats)
92+
def get_validation_objf(dataloader, model, device, criterion, opts, den_graph):
93+
total_objf = 0.
94+
total_weight = 0.
95+
total_frames = 0. # for display only
96+
97+
model.eval()
5898

59-
# at this point, nnet_output is: [N, T, C]
60-
# refer to kaldi/src/chain/chain-training.h
61-
# the output should be organized as
62-
# [all sequences for frame 0]
63-
# [all sequences for frame 1]
64-
# [etc.]
65-
nnet_output = nnet_output.permute(1, 0, 2)
66-
# at this point, nnet_output is: [T, N, C]
67-
nnet_output = nnet_output.contiguous().view(-1,
68-
nnet_output.shape[-1])
69-
70-
# at this point, xent_output is: [N, T, C]
71-
xent_output = xent_output.permute(1, 0, 2)
72-
# at this point, xent_output is: [T, N, C]
73-
xent_output = xent_output.contiguous().view(-1,
74-
xent_output.shape[-1])
75-
objf_l2_term_weight = criterion(opts, den_graph,
76-
supervision_list[n], nnet_output,
77-
xent_output)
78-
objf = objf_l2_term_weight[0]
79-
80-
objf_l2_term_weight = objf_l2_term_weight.cpu()
81-
82-
total_objf += objf_l2_term_weight[0].item()
83-
total_weight += objf_l2_term_weight[2].item()
84-
85-
num_frames = nnet_output.shape[0]
86-
total_frames += num_frames
99+
for batch_idx, batch in enumerate(dataloader):
100+
objf, weight, frames = get_objf(
101+
batch, model, device, criterion, opts, den_graph, False)
102+
total_objf += objf
103+
total_weight += weight
104+
total_frames += frames
87105

88106
return total_objf, total_weight, total_frames
89107

90108

91109
def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
92110
criterion, current_epoch, opts, den_graph, tf_writer, rank):
93-
model.train()
94-
95111
total_objf = 0.
96112
total_weight = 0.
97113
total_frames = 0. # for display only
98114

99-
for batch_idx, batch in enumerate(dataloader):
100-
key_list, feature_list, supervision_list = batch
101-
assert len(key_list) == len(feature_list) == len(supervision_list)
102-
batch_size = len(key_list)
103-
for n in range(batch_size):
104-
feats = feature_list[n]
105-
assert feats.ndim == 3
106-
107-
# at this point, feats is [N, T, C]
108-
feats = feats.to(device)
109-
nnet_output, xent_output = model(feats)
110-
111-
# at this point, nnet_output is: [N, T, C]
112-
# refer to kaldi/src/chain/chain-training.h
113-
# the output should be organized as
114-
# [all sequences for frame 0]
115-
# [all sequences for frame 1]
116-
# [etc.]
117-
nnet_output = nnet_output.permute(1, 0, 2)
118-
# at this point, nnet_output is: [T, N, C]
119-
nnet_output = nnet_output.contiguous().view(-1,
120-
nnet_output.shape[-1])
121-
122-
# at this point, xent_output is: [N, T, C]
123-
xent_output = xent_output.permute(1, 0, 2)
124-
# at this point, xent_output is: [T, N, C]
125-
xent_output = xent_output.contiguous().view(-1,
126-
xent_output.shape[-1])
127-
objf_l2_term_weight = criterion(opts, den_graph,
128-
supervision_list[n], nnet_output,
129-
xent_output)
130-
objf = objf_l2_term_weight[0]
131-
optimizer.zero_grad()
132-
objf.backward()
133-
134-
clip_grad_value_(model.parameters(), 5.0)
135-
136-
optimizer.step()
115+
model.train()
137116

138-
objf_l2_term_weight = objf_l2_term_weight.detach().cpu()
117+
for batch_idx, batch in enumerate(dataloader):
118+
curr_batch_objf, curr_batch_weight, curr_batch_frames = get_objf(
119+
batch, model, device, criterion, opts, den_graph, True, optimizer)
139120

140-
total_objf += objf_l2_term_weight[0].item()
141-
total_weight += objf_l2_term_weight[2].item()
142-
num_frames = nnet_output.shape[0]
143-
total_frames += num_frames
121+
total_objf += curr_batch_objf
122+
total_weight += curr_batch_weight
123+
total_frames += curr_batch_frames
144124

145125
if batch_idx % 100 == 0:
146126
logging.info(
@@ -150,8 +130,8 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
150130
device.index, batch_idx, len(dataloader),
151131
float(batch_idx) / len(dataloader) * 100,
152132
total_objf / total_weight, total_frames,
153-
objf_l2_term_weight[0].item() /
154-
objf_l2_term_weight[2].item(), num_frames, current_epoch))
133+
curr_batch_objf / curr_batch_weight,
134+
curr_batch_frames, current_epoch))
155135

156136
if valid_dataloader and batch_idx % 1000 == 0:
157137
total_valid_objf, total_valid_weight, total_valid_frames = get_validation_objf(
@@ -161,7 +141,6 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
161141
criterion=criterion,
162142
opts=opts,
163143
den_graph=den_graph)
164-
165144
model.train()
166145

167146
logging.info(
@@ -178,7 +157,7 @@ def train_one_epoch(dataloader, valid_dataloader, model, device, optimizer,
178157
batch_idx + current_epoch * len(dataloader))
179158
tf_writer.add_scalar(
180159
'train/current_batch_average_objf',
181-
objf_l2_term_weight[0].item() / objf_l2_term_weight[2].item(),
160+
curr_batch_objf / curr_batch_weight,
182161
batch_idx + current_epoch * len(dataloader))
183162

184163
state_dict = model.state_dict()
@@ -206,20 +185,24 @@ def main():
206185
if args.multiple_machine:
207186
# Suppose we have submitted multiple jobs with SGE (Sun Grid Engine)
208187
local_rank = int(os.environ['SGE_TASK_ID']) - 1
209-
process_job(learning_rate, local_rank)
188+
process_job(learning_rate, local_rank=local_rank)
210189
else:
211190
proc = []
191+
if args.device_ids != None:
192+
assert len(args.device_ids) >= args.world_size
212193
for i in range(args.world_size):
213-
p = Process(target=process_job, args=(learning_rate, i))
194+
device_id = None if args.device_ids == None else args.device_ids[i]
195+
p = Process(target=process_job, args=(learning_rate, device_id, i))
214196
proc.append(p)
215197
p.start()
216198
for p in proc:
217199
p.join()
218200
else:
219-
process_job(args.learning_rate)
201+
device_id = None if args.device_ids == None else args.device_ids[0]
202+
process_job(args.learning_rate, device_id)
220203

221204

222-
def process_job(learning_rate, local_rank=None):
205+
def process_job(learning_rate, device_id=None, local_rank=None):
223206
args = get_args()
224207
if local_rank != None:
225208
setup_logger('{}/train/logs/log-train-rank-{}'.format(args.dir, local_rank),
@@ -233,12 +216,13 @@ def process_job(learning_rate, local_rank=None):
233216
logging.error('No GPU detected!')
234217
sys.exit(-1)
235218

236-
devices = allocate_gpu_devices(1)
237-
if len(devices) < 1:
238-
logging.error('Allocate GPU failed!')
239-
sys.exit(-1)
219+
if device_id == None:
220+
devices = allocate_gpu_devices(1)
221+
if len(devices) < 1:
222+
logging.error('Allocate GPU failed!')
223+
sys.exit(-1)
224+
device_id = devices[0][0]
240225

241-
device_id = devices[0][0]
242226
logging.info('device: {}'.format(device_id))
243227

244228
if args.use_ddp:

egs/aishell/s10/local/run_chain.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ if [[ $stage -le 17 ]]; then
227227
use_ddp=true
228228
world_size=4
229229
use_multiple_machine=true
230+
# you can assign GPUs with --device-ids "$device_ids"
231+
# device_ids="4, 5, 6, 7"
230232
if $use_multiple_machine ; then
231233
# suppose you are using Sun GridEngine
232234
cuda_train_cmd=$(echo "$cuda_train_cmd --gpu 1 JOB=1:$world_size $dir/train/logs/job.JOB.log")

0 commit comments

Comments
 (0)