diff --git a/apps/sft/main.py b/apps/sft/main.py index b5ae6fc16..93ea21996 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -205,6 +205,7 @@ def train_step(self, batch) -> None: self.pbar.set_description(f"{self.current_step}|Loss: {loss}") self.optimizers.step() + self.optimizers.zero_grad() self.lr_schedulers.step() def train(self) -> None: