This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 9df613e (1 parent: 8d27b82)

Improve the readability of the training script. This fix replaces magic numbers with the named variable rescale_loss.

1 file changed: +3 -3 lines changed


scripts/machine_translation/train_transformer.py

Lines changed: 3 additions & 3 deletions
@@ -203,7 +203,7 @@
 test_loss_function = MaskedSoftmaxCELoss()
 test_loss_function.hybridize(static_alloc=static_alloc)

-rescale_loss = 100
+rescale_loss = 100.
 parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss)
 detokenizer = nlp.data.SacreMosesDetokenizer()

@@ -317,7 +317,7 @@ def train():
     if average_param_dict is None:
         average_param_dict = {k: v.data(ctx[0]).copy() for k, v in
                               model.collect_params().items()}
-    trainer.step(float(loss_denom) / args.batch_size / 100.0)
+    trainer.step(float(loss_denom) / args.batch_size / rescale_loss)
     param_dict = model.collect_params()
     param_dict.zero_grad()
     if step_num > average_start:

@@ -327,7 +327,7 @@ def train():
     step_loss += sum([L.asscalar() for L in Ls])
     if batch_id % grad_interval == grad_interval - 1 or\
             batch_id == len(train_data_loader) - 1:
-        log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0
+        log_avg_loss += step_loss / loss_denom * args.batch_size * rescale_loss
         loss_denom = 0
         step_loss = 0
     log_wc += src_wc + tgt_wc
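The one name now threads through three coupled sites: the loss computation, the optimizer step, and the logging line. Assuming the usual GluonNLP setup here, where ParallelTransformer divides the per-batch loss by rescale_loss before backward, trainer.step and the log line must multiply that factor back out. A minimal sketch of the pairing, with hypothetical helpers (scaled_loss, step_denominator, logged_loss) standing in for the script's actual code:

rescale_loss = 100.  # single named source of truth; a float, as in this commit

def scaled_loss(raw_loss):
    # Scale the loss down before backward (a numerical-stability convention).
    return raw_loss / rescale_loss

def step_denominator(loss_denom, batch_size):
    # MXNet Gluon's trainer.step(d) rescales gradients by 1/d, so dividing d
    # by rescale_loss multiplies the gradients back up, undoing scaled_loss.
    return float(loss_denom) / batch_size / rescale_loss

def logged_loss(step_loss, loss_denom, batch_size):
    # Reporting multiplies the factor back so logs show the true loss scale.
    return step_loss / loss_denom * batch_size * rescale_loss

With a bare 100 or 100.0 at any one of these sites, a later change to the factor would silently mis-scale gradients or logs; the named variable keeps all three consistent, and writing it as 100. presumably also guards against Python 2 integer division if the constant is ever used with integer operands.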
