We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent eee6be9 commit 00d02acCopy full SHA for 00d02ac
examples/simultaneous_translation/stacl/train.py
@@ -118,11 +118,8 @@ def do_train(args):
118
batch_id = 0
119
batch_start = time.time()
120
for input_data in train_loader:
121
- if args.max_iter and step_idx == args.max_iter:
122
- return
123
train_reader_cost = time.time() - batch_start
124
(src_word, trg_word, lbl_word) = input_data
125
-
126
if args.use_amp:
127
scaler = paddle.amp.GradScaler(
128
init_loss_scaling=args.scale_loss)
@@ -144,6 +141,9 @@ def do_train(args):
144
141
optimizer.step()
145
142
optimizer.clear_grad()
146
143
+ if args.max_iter and step_idx + 1 == args.max_iter:
+ return
+
147
tokens_per_cards = token_num.numpy()
148
149
train_batch_cost = time.time() - batch_start
0 commit comments