Skip to content

Commit 2039ff0

Browse files
author
Corey Adams
committed
Fix gradient accumulation in tf2
1 parent 038e3a9 commit 2039ff0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/utils/tensorflow2/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def train_step(self):
653653

654654
# After the accumulation, weight the gradients as needed and apply them:
655655
if self.args.gradient_accumulation != 1:
656-
gradients /= self.args.gradient_accumulation
656+
gradients = [ g / self.args.gradient_accumulation for g in gradients ]
657657

658658
self.apply_gradients(gradients)
659659

0 commit comments

Comments
 (0)