From 2d63c6d1bf4a0a34db3383990f9df6655998c6e8 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Tue, 26 Aug 2025 09:29:08 -0700 Subject: [PATCH 1/4] [ez] zero grad every step --- apps/sft/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index b5ae6fc16..533c7581c 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -199,6 +199,7 @@ def train_step(self, batch) -> None: # self.model, # self.data_parallel_size, # ) as grad_acc: + self.optimizers.zero_grad() labels = batch.pop("labels") loss = self.forward_backward(batch, labels) self.pbar.update(1) @@ -209,7 +210,6 @@ def train_step(self, batch) -> None: def train(self) -> None: dataloader = iter(self.train_dataloader) - self.optimizers.zero_grad() self.pbar = tqdm( initial=0, From 7767a09d18cdbf873652d9df551b76f8a5b25c3a Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Tue, 26 Aug 2025 09:32:02 -0700 Subject: [PATCH 2/4] zero right after optimizer step --- apps/sft/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index 533c7581c..a1a9c7419 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -199,16 +199,17 @@ def train_step(self, batch) -> None: # self.model, # self.data_parallel_size, # ) as grad_acc: - self.optimizers.zero_grad() labels = batch.pop("labels") loss = self.forward_backward(batch, labels) self.pbar.update(1) self.pbar.set_description(f"{self.current_step}|Loss: {loss}") self.optimizers.step() + self.optimizers.zero_grad() self.lr_schedulers.step() def train(self) -> None: + self.optimizers.zero_grad() dataloader = iter(self.train_dataloader) self.pbar = tqdm( From d96905aaf884a4446e542a6c4f01c6ab169da7ba Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Tue, 26 Aug 2025 09:33:09 -0700 Subject: [PATCH 3/4] reorder --- apps/sft/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index a1a9c7419..dcda8bffe 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -205,12 +205,12 @@ def train_step(self, batch) -> None: self.pbar.set_description(f"{self.current_step}|Loss: {loss}") self.optimizers.step() - self.optimizers.zero_grad() self.lr_schedulers.step() + self.optimizers.zero_grad() def train(self) -> None: - self.optimizers.zero_grad() dataloader = iter(self.train_dataloader) + self.optimizers.zero_grad() self.pbar = tqdm( initial=0, From b7dda39c3ac9f587c35a1ab47909f36f4a0e9995 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Tue, 26 Aug 2025 09:34:04 -0700 Subject: [PATCH 4/4] move right after optimizer step --- apps/sft/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/sft/main.py b/apps/sft/main.py index dcda8bffe..93ea21996 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -205,8 +205,8 @@ def train_step(self, batch) -> None: self.pbar.set_description(f"{self.current_step}|Loss: {loss}") self.optimizers.step() - self.lr_schedulers.step() self.optimizers.zero_grad() + self.lr_schedulers.step() def train(self) -> None: dataloader = iter(self.train_dataloader)