Skip to content

Commit 00d02ac

Browse files
author
gongenlei
authored
[Bugfix] Fix ci breaking for paddle2.1.2 (#861)
* bugfix: fix ci breaking for paddle2.1.2 * bugfix: fix ci breaking for paddle2.1.2
1 parent eee6be9 commit 00d02ac

File tree

1 file changed

+3
-3
lines changed
  • examples/simultaneous_translation/stacl

1 file changed

+3
-3
lines changed

examples/simultaneous_translation/stacl/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,8 @@ def do_train(args):
118118
batch_id = 0
119119
batch_start = time.time()
120120
for input_data in train_loader:
121-
if args.max_iter and step_idx == args.max_iter:
122-
return
123121
train_reader_cost = time.time() - batch_start
124122
(src_word, trg_word, lbl_word) = input_data
125-
126123
if args.use_amp:
127124
scaler = paddle.amp.GradScaler(
128125
init_loss_scaling=args.scale_loss)
@@ -144,6 +141,9 @@ def do_train(args):
144141
optimizer.step()
145142
optimizer.clear_grad()
146143

144+
if args.max_iter and step_idx + 1 == args.max_iter:
145+
return
146+
147147
tokens_per_cards = token_num.numpy()
148148

149149
train_batch_cost = time.time() - batch_start

0 commit comments

Comments
 (0)