From c51ac04fb6cdd898c7779dbe0633379b76aee56c Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 9 Jan 2025 09:02:14 +0100
Subject: [PATCH 1/2] update

---
 finetrainers/trainer.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index 81c82647..9e02fc30 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -11,7 +11,6 @@
 import torch
 import torch.backends
 import transformers
-import wandb
 from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import (
@@ -30,6 +29,8 @@
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from tqdm import tqdm
 
+import wandb
+
 from .args import _INVERSE_DTYPE_MAP, Args, validate_args
 from .constants import (
     FINETRAINERS_LOG_LEVEL,
@@ -673,6 +674,8 @@ def train(self) -> None:
 
             self.transformer.train()
             models_to_accumulate = [self.transformer]
+            epoch_loss = 0.0
+            num_loss_updates = 0
 
             for step, batch in enumerate(self.dataloader):
                 logger.debug(f"Starting step {step + 1}")
@@ -843,7 +846,10 @@
                     if should_run_validation:
                         self.validate(global_step)
 
-                logs["loss"] = loss.detach().item()
+                loss_item = loss.detach().item()
+                epoch_loss += loss_item
+                num_loss_updates += 1
+                logs["step_loss"] = loss_item
                 logs["lr"] = self.lr_scheduler.get_last_lr()[0]
                 progress_bar.set_postfix(logs)
                 accelerator.log(logs, step=global_step)
@@ -851,6 +857,9 @@
                 if global_step >= self.state.train_steps:
                     break
 
+            if num_loss_updates > 0:
+                epoch_loss /= num_loss_updates
+            accelerator.log({"epoch_loss": epoch_loss}, step=global_step)
             memory_statistics = get_memory_statistics()
             logger.info(f"Memory after epoch {epoch + 1}: {json.dumps(memory_statistics, indent=4)}")
 

From 7f259249dcdf41bc80b08eb9886e4e93803a4da5 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 9 Jan 2025 09:03:56 +0100
Subject: [PATCH 2/2] update

---
 finetrainers/trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/finetrainers/trainer.py b/finetrainers/trainer.py
index 9e02fc30..8f2ebedf 100644
--- a/finetrainers/trainer.py
+++ b/finetrainers/trainer.py
@@ -11,6 +11,7 @@
 import torch
 import torch.backends
 import transformers
+import wandb
 from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import (
@@ -29,8 +30,6 @@
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from tqdm import tqdm
 
-import wandb
-
 from .args import _INVERSE_DTYPE_MAP, Args, validate_args
 from .constants import (
     FINETRAINERS_LOG_LEVEL,
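
Note: the per-epoch average-loss logging that PATCH 1 introduces follows the pattern sketched below. This is a minimal standalone illustration, not finetrainers code; run_epoch, train_one_step, and log are hypothetical placeholders (log stands in for accelerator.log), and only the accumulation/averaging bookkeeping mirrors the diff.

    from typing import Callable, Iterable


    def run_epoch(
        batches: Iterable,
        train_one_step: Callable[..., float],
        log: Callable[..., None],
        global_step: int,
    ) -> int:
        # Accumulate per-step losses so the epoch average can be logged once at the
        # end, mirroring the epoch_loss / num_loss_updates bookkeeping in PATCH 1.
        epoch_loss = 0.0
        num_loss_updates = 0

        for batch in batches:
            loss_item = train_one_step(batch)  # stands in for loss.detach().item()
            epoch_loss += loss_item
            num_loss_updates += 1
            global_step += 1
            log({"step_loss": loss_item}, step=global_step)

        # Guard against an empty dataloader before averaging, then log once per epoch.
        if num_loss_updates > 0:
            epoch_loss /= num_loss_updates
        log({"epoch_loss": epoch_loss}, step=global_step)
        return global_step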