diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py
index 59c1b5e4..8a3003e1 100644
--- a/finetrainers/parallel/accelerate.py
+++ b/finetrainers/parallel/accelerate.py
@@ -191,6 +191,7 @@ def _get_mesh():
         return _get_mesh()
 
     def get_checkpointer(self, *args, **kwargs):
+        kwargs["_parallel_backend"] = self
         return AccelerateCheckpointer(self._accelerator, *args, **kwargs)
 
     @property
@@ -263,10 +264,19 @@ def __init__(
         enable: bool = True,
         _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
         _prefix: str = "finetrainers_step",
+        _parallel_backend: Optional["BaseParallelBackend"] = None,
         *args,
         **kwargs,
     ) -> None:
         self.accelerator = accelerator
+        self._parallel_backend = _parallel_backend
+
+        if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker:
+            wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id()
+            if wandb_run_id:
+                states["wandb_run_id"] = wandb_run_id
+            else:
+                states["wandb_run_id"] = None
         self.states = states
         self.checkpointing_steps = checkpointing_steps
@@ -285,10 +295,11 @@ def save_model_hook(models, weights, output_dir: str) -> None:
             assert len(models) == 1
             _callback_fn(weights[0])
+            torch.save(self.states, os.path.join(output_dir, "states.pt"))
 
         def load_model_hook(models, input_dir) -> None:
-            self.states = torch.load(os.path.join(input_dir, "states.pt"))
+            self.states = torch.load(os.path.join(input_dir, "states.pt"), weights_only=False)
 
         self.accelerator.register_save_state_pre_hook(save_model_hook)
         self.accelerator.register_load_state_pre_hook(load_model_hook)
@@ -334,6 +345,10 @@ def load(self, step: int = -1) -> bool:
 
         return True
 
+    def get_wandb_run_id_from_checkpoint(self) -> Optional[str]:
+        """Get the wandb run ID from the loaded checkpoint states."""
+        return self.states.get("wandb_run_id", None)
+
     def _should_checkpoint(self, step: int, force: bool) -> bool:
         if not self.enable:
             return False
diff --git a/finetrainers/parallel/base.py b/finetrainers/parallel/base.py
index ab04aeb7..68f84286 100644
--- a/finetrainers/parallel/base.py
+++ b/finetrainers/parallel/base.py
@@ -45,10 +45,15 @@ def get_checkpointer(self, *args, **kwargs) -> None:
         raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.")
 
     def initialize_trackers(
-        self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+        self,
+        trackers: List[str],
+        experiment_name: str,
+        config: Dict[str, Any],
+        log_dir: str,
+        resume_run_id: Optional[str] = None,
     ) -> TrackerType:
         if self.is_main_process:
-            self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)
+            self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir, resume_run_id)
         else:
             self.tracker = DummyTracker()
diff --git a/finetrainers/parallel/ptd.py b/finetrainers/parallel/ptd.py
index 2a95b1a9..45878f1a 100644
--- a/finetrainers/parallel/ptd.py
+++ b/finetrainers/parallel/ptd.py
@@ -209,7 +209,7 @@ def _get_mesh():
         return _get_mesh()
 
     def get_checkpointer(self, *args, **kwargs):
-        return PTDCheckpointer(*args, **kwargs)
+        return PTDCheckpointer(*args, **kwargs, _parallel_backend=self)
 
     @property
     def world_size(self):
@@ -309,7 +309,9 @@ def __init__(
         enable: bool = True,
         _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
         _prefix: str = "finetrainers_step",
+        _parallel_backend: Optional["BaseParallelBackend"] = None,
     ) -> None:
+        self._parallel_backend = _parallel_backend
         self.states = states
         self.states.update(
             {
@@ -319,7 +321,12 @@ def __init__(
             }
         )
         self.states.update(schedulers.get_lr_scheduler_state())
-
+        if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker:
+            wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id()
+            if wandb_run_id:
+                self.states["wandb_run_id"] = wandb_run_id
+            else:
+                self.states["wandb_run_id"] = None
         self.checkpointing_steps = checkpointing_steps
         self.checkpointing_limit = checkpointing_limit
         self.output_dir = pathlib.Path(output_dir)
@@ -385,6 +392,10 @@ def load(self, step: int = -1) -> bool:
 
         return True
 
+    def get_wandb_run_id_from_checkpoint(self) -> Optional[str]:
+        """Get the wandb run ID from the loaded checkpoint states."""
+        return self.states.get("wandb_run_id", None)
+
     def _should_checkpoint(self, step: int, force: bool) -> bool:
         if not self.enable:
             return False
diff --git a/finetrainers/trackers.py b/finetrainers/trackers.py
index 68a53c5a..5cb4ec8a 100644
--- a/finetrainers/trackers.py
+++ b/finetrainers/trackers.py
@@ -37,6 +37,10 @@ def log(self, metrics: Dict[str, Any], step: int) -> None:
     def finish(self) -> None:
         pass
 
+    def get_wandb_run_id(self) -> Optional[str]:
+        r"""Get the wandb run ID if available."""
+        return None
+
 
 class DummyTracker(BaseTracker):
     def __init__(self):
@@ -52,7 +56,13 @@ def finish(self) -> None:
 class WandbTracker(BaseTracker):
     r"""Logger implementation for Weights & Biases."""
 
-    def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(
+        self,
+        experiment_name: str,
+        log_dir: str,
+        config: Optional[Dict[str, Any]] = None,
+        resume_run_id: Optional[str] = None,
+    ) -> None:
         super().__init__()
 
         import wandb
@@ -62,7 +72,11 @@ def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str
         # WandB does not create a directory if it does not exist and instead starts using the system temp directory.
         pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
 
-        self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
+        if resume_run_id is not None:
+            logger.info(f"Resuming WandB run with ID: {resume_run_id}")
+            self.run = wandb.init(project=experiment_name, dir=log_dir, config=config, id=resume_run_id, resume="must")
+        else:
+            self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
         logger.info("WandB logging enabled")
 
     def log(self, metrics: Dict[str, Any], step: int) -> None:
@@ -73,6 +87,15 @@ def log(self, metrics: Dict[str, Any], step: int) -> None:
     def finish(self) -> None:
         self.run.finish()
 
+    @property
+    def run_id(self) -> Optional[str]:
+        """Return the current wandb run ID for checkpointing purposes."""
+        return self.run.id if self.run is not None else None
+
+    def get_wandb_run_id(self) -> Optional[str]:
+        """Return the wandb run ID if this tracker supports it."""
+        return self.run_id
+
 
 class SequentialTracker(BaseTracker):
     r"""Sequential tracker that logs to multiple trackers in sequence."""
@@ -106,6 +129,14 @@ def finish(self) -> None:
         for tracker in self.trackers:
             tracker.finish()
 
+    def get_wandb_run_id(self) -> Optional[str]:
+        """Return the wandb run ID from the first WandB tracker in the sequence."""
+        for tracker in self.trackers:
+            run_id = tracker.get_wandb_run_id()
+            if run_id is not None:
+                return run_id
+        return None
+
 
 class Trackers(str, Enum):
     r"""Enum for supported trackers."""
@@ -118,7 +149,11 @@
 def initialize_trackers(
-    trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+    trackers: List[str],
+    experiment_name: str,
+    config: Dict[str, Any],
+    log_dir: str,
+    resume_run_id: Optional[str] = None,
 ) -> Union[BaseTracker, SequentialTracker]:
     r"""Initialize loggers based on the provided configuration."""
 
@@ -135,7 +170,7 @@ def initialize_trackers(
         if tracker_name == Trackers.NONE:
             tracker = BaseTracker()
         elif tracker_name == Trackers.WANDB:
-            tracker = WandbTracker(experiment_name, log_dir, config)
+            tracker = WandbTracker(experiment_name, log_dir, config, resume_run_id=resume_run_id)
         tracker_instances.append(tracker)
 
     tracker = SequentialTracker(tracker_instances)
diff --git a/finetrainers/trainer/base.py b/finetrainers/trainer/base.py
index 445fc89e..ddbf0d20 100644
--- a/finetrainers/trainer/base.py
+++ b/finetrainers/trainer/base.py
@@ -1,7 +1,7 @@
 import contextlib
 import functools
 import os
-from typing import Callable, List, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 import torch.backends
@@ -116,12 +116,16 @@ def _init_logging(self) -> None:
         logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
         logger.info("Initialized FineTrainers")
 
-    def _init_trackers(self) -> None:
+    def _init_trackers(self, resume_run_id: Optional[str] = None) -> None:
         # TODO(aryan): handle multiple trackers
         trackers = [self.args.report_to]
         experiment_name = self.args.tracker_name or "finetrainers-experiment"
         self.state.parallel_backend.initialize_trackers(
-            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
+            trackers,
+            experiment_name=experiment_name,
+            config=self._get_training_info(),
+            log_dir=self.args.logging_dir,
+            resume_run_id=resume_run_id,
         )
 
     def _init_config_options(self) -> None:
diff --git a/finetrainers/trainer/control_trainer/trainer.py b/finetrainers/trainer/control_trainer/trainer.py
index 576e17a0..12dd2af5 100644
--- a/finetrainers/trainer/control_trainer/trainer.py
+++ b/finetrainers/trainer/control_trainer/trainer.py
@@ -262,7 +262,8 @@ def _prepare_for_training(self) -> None:
         # 3. Initialize trackers, directories and repositories
         self._init_logging()
-        self._init_trackers()
+        if self.args.resume_from_checkpoint is None:
+            self._init_trackers()
         self._init_directories_and_repositories()
 
     def _prepare_dataset(self) -> None:
@@ -373,6 +374,9 @@ def save_model_hook(state_dict: Dict[str, Any]) -> None:
             resume_from_checkpoint = -1
         if resume_from_checkpoint is not None:
             self.checkpointer.load(resume_from_checkpoint)
+            # Extract wandb run ID from loaded checkpoint and initialize trackers with it
+            wandb_run_id = self.checkpointer.get_wandb_run_id_from_checkpoint()
+            self._init_trackers(resume_run_id=wandb_run_id)
 
     def _train(self) -> None:
         logger.info("Starting training")
diff --git a/finetrainers/trainer/sft_trainer/trainer.py b/finetrainers/trainer/sft_trainer/trainer.py
index 78954596..d6350ee8 100644
--- a/finetrainers/trainer/sft_trainer/trainer.py
+++ b/finetrainers/trainer/sft_trainer/trainer.py
@@ -230,7 +230,8 @@ def _prepare_for_training(self) -> None:
         # 3. Initialize trackers, directories and repositories
         self._init_logging()
-        self._init_trackers()
+        if self.args.resume_from_checkpoint is None:
+            self._init_trackers()
         self._init_directories_and_repositories()
 
     def _prepare_dataset(self) -> None:
@@ -324,6 +325,9 @@ def save_model_hook(state_dict: Dict[str, Any]) -> None:
             resume_from_checkpoint = -1
         if resume_from_checkpoint is not None:
             self.checkpointer.load(resume_from_checkpoint)
+            # Extract wandb run ID from loaded checkpoint and initialize trackers with it
+            wandb_run_id = self.checkpointer.get_wandb_run_id_from_checkpoint()
+            self._init_trackers(resume_run_id=wandb_run_id)
 
     def _train(self) -> None:
         logger.info("Starting training")
diff --git a/tests/test_trackers.py b/tests/test_trackers.py
index c9fee180..2694c748 100644
--- a/tests/test_trackers.py
+++ b/tests/test_trackers.py
@@ -4,12 +4,18 @@
 import tempfile
 import unittest
 
+import torch
 from diffusers.utils.testing_utils import CaptureLogger
 
+from finetrainers import BaseArgs, SFTTrainer, TrainingType
 from finetrainers.trackers import WandbTracker
+from tests.trainer import SFTTrainerFastTestsMixin
+
+from .models.cogview4.base_specification import DummyCogView4ModelSpecification  # noqa
 
 
 os.environ["WANDB_MODE"] = "offline"
+os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO"
 
 
 class WandbFastTests(unittest.TestCase):
@@ -24,3 +30,72 @@ def test_wandb_logdir(self):
         self.assertTrue(pathlib.Path(tempdir).exists())
         self.assertTrue("WandB logging enabled" in cap_log.out)
+
+
+class SFTTrainerLoRAWandbResumeTests(SFTTrainerFastTestsMixin, unittest.TestCase):
+    model_specification_cls = DummyCogView4ModelSpecification
+
+    def get_args(self) -> BaseArgs:
+        args = self.get_base_args()
+        args.checkpointing_steps = 5
+        args.training_type = TrainingType.LORA
+        args.rank = 4
+        args.lora_alpha = 4
+        args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+        return args
+
+    def test_wandb_session_resumption_with_checkpoint(self):
+        """
+        Test the core issue: the wandb session should be continued when resuming from a checkpoint.
+
+        Steps:
+        1. Start training for 6 steps (with checkpointing every 5 steps)
+        2. Verify a checkpoint is created at step 5
+        3. Resume training from the checkpoint at step 5 for additional steps
+        4. Verify that the same wandb session ID is maintained
+        """
+        for parallel_backend in ("ptd", "accelerate"):
+            # Phase 1: Initial training run (6 steps, checkpoint at step 5)
+            args_phase1 = self.get_args()
+            args_phase1.parallel_backend = parallel_backend
+            args_phase1.train_steps = 6  # Train for 6 steps (will checkpoint at step 5)
+
+            model_specification_1 = self.model_specification_cls()
+            trainer_phase1 = SFTTrainer(args_phase1, model_specification_1)
+            trainer_phase1.run()
+
+            # Verify checkpoint was created at step 5
+            checkpoint_dir = pathlib.Path(self.tmpdir.name) / "finetrainers_step_5"
+            self.assertTrue(checkpoint_dir.exists(), f"Checkpoint should exist at {checkpoint_dir}")
+
+            # Extract the wandb run ID from the first training run
+            # This should be stored in the checkpoint
+            original_wandb_run_id = trainer_phase1.checkpointer.get_wandb_run_id_from_checkpoint()
+            self.assertIsNotNone(original_wandb_run_id, "WandB run ID should be saved in checkpoint")
+
+            del trainer_phase1
+            # Reinitialize process group for resumed training
+            if parallel_backend != "ptd" and not torch.distributed.is_initialized():
+                torch.distributed.init_process_group(backend="nccl")  # or 'gloo' for CPU
+
+            # Phase 2: Resume training from the checkpoint
+            args_phase2 = self.get_args()
+            args_phase2.parallel_backend = parallel_backend
+            args_phase2.resume_from_checkpoint = 5
+
+            model_specification_2 = self.model_specification_cls()
+            trainer_phase2 = SFTTrainer(args_phase2, model_specification_2)
+            trainer_phase2.run()
+
+            # Verify that the resumed training uses the same wandb run ID
+            resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id()
+
+            self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID")
+            self.assertEqual(
+                original_wandb_run_id,
+                resumed_wandb_run_id,
+                f"WandB run ID should be the same after resumption. "
+                f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}",
+            )
+
+            del trainer_phase2
diff --git a/tests/trainer/__init__.py b/tests/trainer/__init__.py
index e69de29b..5ce37eec 100644
--- a/tests/trainer/__init__.py
+++ b/tests/trainer/__init__.py
@@ -0,0 +1 @@
+from .test_sft_trainer import SFTTrainerFastTestsMixin