From 6cd935d66371269ef3d6fad924466a0f6699b5bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Jul 2025 16:07:09 +0300 Subject: [PATCH 01/24] feat: Add support for resuming W&B runs from checkpoints Saves the Weights & Biases run ID to the checkpoint file during training. When resuming from a checkpoint, this ID is loaded and used to initialize the W&B tracker, ensuring that logging continues in the same run. This prevents the creation of new, separate runs when a job is restarted. --- finetrainers/parallel/accelerate.py | 20 ++++++++++- finetrainers/parallel/base.py | 9 +++-- finetrainers/parallel/ptd.py | 14 +++++++- finetrainers/trackers.py | 33 ++++++++++++++++--- finetrainers/trainer/base.py | 6 ++-- .../trainer/control_trainer/trainer.py | 6 +++- finetrainers/trainer/sft_trainer/trainer.py | 6 +++- 7 files changed, 81 insertions(+), 13 deletions(-) diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py index 59c1b5e4..8d10adcf 100644 --- a/finetrainers/parallel/accelerate.py +++ b/finetrainers/parallel/accelerate.py @@ -191,6 +191,7 @@ def _get_mesh(): return _get_mesh() def get_checkpointer(self, *args, **kwargs): + kwargs["_parallel_backend"] = self return AccelerateCheckpointer(self._accelerator, *args, **kwargs) @property @@ -263,11 +264,13 @@ def __init__( enable: bool = True, _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, _prefix: str = "finetrainers_step", + _parallel_backend: Optional["BaseParallelBackend"] = None, *args, **kwargs, ) -> None: self.accelerator = accelerator self.states = states + self._parallel_backend = _parallel_backend self.checkpointing_steps = checkpointing_steps self.checkpointing_limit = checkpointing_limit @@ -285,7 +288,18 @@ def save_model_hook(models, weights, output_dir: str) -> None: assert len(models) == 1 _callback_fn(weights[0]) - torch.save(self.states, os.path.join(output_dir, "states.pt")) + + states_to_save = dict(self.states) + if ( + self._parallel_backend + and hasattr(self._parallel_backend, "tracker") + and self._parallel_backend.tracker + ): + wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() + if wandb_run_id: + states_to_save["wandb_run_id"] = wandb_run_id + + torch.save(states_to_save, os.path.join(output_dir, "states.pt")) def load_model_hook(models, input_dir) -> None: self.states = torch.load(os.path.join(input_dir, "states.pt")) @@ -334,6 +348,10 @@ def load(self, step: int = -1) -> bool: return True + def get_wandb_run_id_from_checkpoint(self) -> Optional[str]: + """Get the wandb run ID from the loaded checkpoint states.""" + return self.states.get("wandb_run_id", None) + def _should_checkpoint(self, step: int, force: bool) -> bool: if not self.enable: return False diff --git a/finetrainers/parallel/base.py b/finetrainers/parallel/base.py index ab04aeb7..1f6df04e 100644 --- a/finetrainers/parallel/base.py +++ b/finetrainers/parallel/base.py @@ -45,10 +45,15 @@ def get_checkpointer(self, *args, **kwargs) -> None: raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.") def initialize_trackers( - self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str + self, + trackers: List[str], + experiment_name: str, + config: Dict[str, Any], + log_dir: str, + resume_run_id: Optional[str] = None, ) -> TrackerType: if self.is_main_process: - self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir) + self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir, resume_run_id) else: self.tracker = DummyTracker() diff --git a/finetrainers/parallel/ptd.py b/finetrainers/parallel/ptd.py index 2a95b1a9..c2695c47 100644 --- a/finetrainers/parallel/ptd.py +++ b/finetrainers/parallel/ptd.py @@ -209,7 +209,7 @@ def _get_mesh(): return _get_mesh() def get_checkpointer(self, *args, **kwargs): - return PTDCheckpointer(*args, **kwargs) + return PTDCheckpointer(*args, **kwargs, _parallel_backend=self) @property def world_size(self): @@ -309,7 +309,9 @@ def __init__( enable: bool = True, _callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None, _prefix: str = "finetrainers_step", + _parallel_backend: Optional["BaseParallelBackend"] = None, ) -> None: + self._parallel_backend = _parallel_backend self.states = states self.states.update( { @@ -333,6 +335,12 @@ def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _i if not self._should_checkpoint(step, force): return None + # Save wandb run ID if available + if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker: + wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() + if wandb_run_id: + self.states["wandb_run_id"] = wandb_run_id + checkpoint_dir = self._get_checkpoint_dir(step) begin_time = time.monotonic() torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix()) @@ -385,6 +393,10 @@ def load(self, step: int = -1) -> bool: return True + def get_wandb_run_id_from_checkpoint(self) -> Optional[str]: + """Get the wandb run ID from the loaded checkpoint states.""" + return self.states.get("wandb_run_id", None) + def _should_checkpoint(self, step: int, force: bool) -> bool: if not self.enable: return False diff --git a/finetrainers/trackers.py b/finetrainers/trackers.py index 68a53c5a..4e043a8b 100644 --- a/finetrainers/trackers.py +++ b/finetrainers/trackers.py @@ -37,6 +37,10 @@ def log(self, metrics: Dict[str, Any], step: int) -> None: def finish(self) -> None: pass + def get_wandb_run_id(self) -> Optional[str]: + r"""Get the wandb run ID if available.""" + return None + class DummyTracker(BaseTracker): def __init__(self): @@ -52,7 +56,7 @@ def finish(self) -> None: class WandbTracker(BaseTracker): r"""Logger implementation for Weights & Biases.""" - def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None, resume_run_id: Optional[str] = None) -> None: super().__init__() import wandb @@ -62,7 +66,11 @@ def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str # WandB does not create a directory if it does not exist and instead starts using the system temp directory. pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True) - self.run = wandb.init(project=experiment_name, dir=log_dir, config=config) + if resume_run_id is not None: + logger.info(f"Resuming WandB run with ID: {resume_run_id}") + self.run = wandb.init(project=experiment_name, dir=log_dir, config=config, id=resume_run_id, resume="must") + else: + self.run = wandb.init(project=experiment_name, dir=log_dir, config=config) logger.info("WandB logging enabled") def log(self, metrics: Dict[str, Any], step: int) -> None: @@ -73,6 +81,15 @@ def log(self, metrics: Dict[str, Any], step: int) -> None: def finish(self) -> None: self.run.finish() + @property + def run_id(self) -> Optional[str]: + """Return the current wandb run ID for checkpointing purposes.""" + return self.run.id if self.run is not None else None + + def get_wandb_run_id(self) -> Optional[str]: + """Return the wandb run ID if this tracker supports it.""" + return self.run_id + class SequentialTracker(BaseTracker): r"""Sequential tracker that logs to multiple trackers in sequence.""" @@ -106,6 +123,14 @@ def finish(self) -> None: for tracker in self.trackers: tracker.finish() + def get_wandb_run_id(self) -> Optional[str]: + """Return the wandb run ID from the first WandB tracker in the sequence.""" + for tracker in self.trackers: + run_id = tracker.get_wandb_run_id() + if run_id is not None: + return run_id + return None + class Trackers(str, Enum): r"""Enum for supported trackers.""" @@ -118,7 +143,7 @@ class Trackers(str, Enum): def initialize_trackers( - trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str + trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str, resume_run_id: Optional[str] = None ) -> Union[BaseTracker, SequentialTracker]: r"""Initialize loggers based on the provided configuration.""" @@ -135,7 +160,7 @@ def initialize_trackers( if tracker_name == Trackers.NONE: tracker = BaseTracker() elif tracker_name == Trackers.WANDB: - tracker = WandbTracker(experiment_name, log_dir, config) + tracker = WandbTracker(experiment_name, log_dir, config, resume_run_id=resume_run_id) tracker_instances.append(tracker) tracker = SequentialTracker(tracker_instances) diff --git a/finetrainers/trainer/base.py b/finetrainers/trainer/base.py index 445fc89e..4cacffca 100644 --- a/finetrainers/trainer/base.py +++ b/finetrainers/trainer/base.py @@ -1,7 +1,7 @@ import contextlib import functools import os -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Optional import torch import torch.backends @@ -116,12 +116,12 @@ def _init_logging(self) -> None: logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process) logger.info("Initialized FineTrainers") - def _init_trackers(self) -> None: + def _init_trackers(self, resume_run_id: Optional[str] = None) -> None: # TODO(aryan): handle multiple trackers trackers = [self.args.report_to] experiment_name = self.args.tracker_name or "finetrainers-experiment" self.state.parallel_backend.initialize_trackers( - trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir + trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir, resume_run_id=resume_run_id ) def _init_config_options(self) -> None: diff --git a/finetrainers/trainer/control_trainer/trainer.py b/finetrainers/trainer/control_trainer/trainer.py index 576e17a0..12dd2af5 100644 --- a/finetrainers/trainer/control_trainer/trainer.py +++ b/finetrainers/trainer/control_trainer/trainer.py @@ -262,7 +262,8 @@ def _prepare_for_training(self) -> None: # 3. Initialize trackers, directories and repositories self._init_logging() - self._init_trackers() + if self.args.resume_from_checkpoint is None: + self._init_trackers() self._init_directories_and_repositories() def _prepare_dataset(self) -> None: @@ -373,6 +374,9 @@ def save_model_hook(state_dict: Dict[str, Any]) -> None: resume_from_checkpoint = -1 if resume_from_checkpoint is not None: self.checkpointer.load(resume_from_checkpoint) + # Extract wandb run ID from loaded checkpoint and initialize trackers with it + wandb_run_id = self.checkpointer.get_wandb_run_id_from_checkpoint() + self._init_trackers(resume_run_id=wandb_run_id) def _train(self) -> None: logger.info("Starting training") diff --git a/finetrainers/trainer/sft_trainer/trainer.py b/finetrainers/trainer/sft_trainer/trainer.py index 78954596..d6350ee8 100644 --- a/finetrainers/trainer/sft_trainer/trainer.py +++ b/finetrainers/trainer/sft_trainer/trainer.py @@ -230,7 +230,8 @@ def _prepare_for_training(self) -> None: # 3. Initialize trackers, directories and repositories self._init_logging() - self._init_trackers() + if self.args.resume_from_checkpoint is None: + self._init_trackers() self._init_directories_and_repositories() def _prepare_dataset(self) -> None: @@ -324,6 +325,9 @@ def save_model_hook(state_dict: Dict[str, Any]) -> None: resume_from_checkpoint = -1 if resume_from_checkpoint is not None: self.checkpointer.load(resume_from_checkpoint) + # Extract wandb run ID from loaded checkpoint and initialize trackers with it + wandb_run_id = self.checkpointer.get_wandb_run_id_from_checkpoint() + self._init_trackers(resume_run_id=wandb_run_id) def _train(self) -> None: logger.info("Starting training") From 1985fb612a8538254d5f4e7d21a4483a7f160707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Jul 2025 18:17:33 +0300 Subject: [PATCH 02/24] style --- finetrainers/parallel/base.py | 12 ++++++------ finetrainers/trackers.py | 14 ++++++++++++-- finetrainers/trainer/base.py | 8 ++++++-- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/finetrainers/parallel/base.py b/finetrainers/parallel/base.py index 1f6df04e..68f84286 100644 --- a/finetrainers/parallel/base.py +++ b/finetrainers/parallel/base.py @@ -45,12 +45,12 @@ def get_checkpointer(self, *args, **kwargs) -> None: raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.") def initialize_trackers( - self, - trackers: List[str], - experiment_name: str, - config: Dict[str, Any], - log_dir: str, - resume_run_id: Optional[str] = None, + self, + trackers: List[str], + experiment_name: str, + config: Dict[str, Any], + log_dir: str, + resume_run_id: Optional[str] = None, ) -> TrackerType: if self.is_main_process: self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir, resume_run_id) diff --git a/finetrainers/trackers.py b/finetrainers/trackers.py index 4e043a8b..5cb4ec8a 100644 --- a/finetrainers/trackers.py +++ b/finetrainers/trackers.py @@ -56,7 +56,13 @@ def finish(self) -> None: class WandbTracker(BaseTracker): r"""Logger implementation for Weights & Biases.""" - def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None, resume_run_id: Optional[str] = None) -> None: + def __init__( + self, + experiment_name: str, + log_dir: str, + config: Optional[Dict[str, Any]] = None, + resume_run_id: Optional[str] = None, + ) -> None: super().__init__() import wandb @@ -143,7 +149,11 @@ class Trackers(str, Enum): def initialize_trackers( - trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str, resume_run_id: Optional[str] = None + trackers: List[str], + experiment_name: str, + config: Dict[str, Any], + log_dir: str, + resume_run_id: Optional[str] = None, ) -> Union[BaseTracker, SequentialTracker]: r"""Initialize loggers based on the provided configuration.""" diff --git a/finetrainers/trainer/base.py b/finetrainers/trainer/base.py index 4cacffca..ddbf0d20 100644 --- a/finetrainers/trainer/base.py +++ b/finetrainers/trainer/base.py @@ -1,7 +1,7 @@ import contextlib import functools import os -from typing import Callable, List, Tuple, Optional +from typing import Callable, List, Optional, Tuple import torch import torch.backends @@ -121,7 +121,11 @@ def _init_trackers(self, resume_run_id: Optional[str] = None) -> None: trackers = [self.args.report_to] experiment_name = self.args.tracker_name or "finetrainers-experiment" self.state.parallel_backend.initialize_trackers( - trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir, resume_run_id=resume_run_id + trackers, + experiment_name=experiment_name, + config=self._get_training_info(), + log_dir=self.args.logging_dir, + resume_run_id=resume_run_id, ) def _init_config_options(self) -> None: From 44c57c3b541fbff413066ca63eaa2f94ea2f5277 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 5 Jul 2025 18:57:44 +0300 Subject: [PATCH 03/24] Adds tests for resuming wandb runs from checkpoints Adds a comprehensive test suite to verify that wandb runs can be correctly resumed from a saved checkpoint. This prevents the creation of a new wandb run upon resumption, ensuring a continuous experiment history. The tests cover the following scenarios: - The core logic of resuming a run using a `resume_run_id`. - Verification that both `PTDCheckpointer` and `AccelerateCheckpointer` save the `wandb_run_id`. - The end-to-end resumption flow for `SFTTrainer` and `ControlTrainer`. - Introspection checks to confirm trainers include the necessary logic to extract and use the run ID from a checkpoint. Fixes #188 --- tests/test_trackers.py | 228 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index c9fee180..e4ee092c 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -3,13 +3,17 @@ import pathlib import tempfile import unittest +from unittest.mock import Mock from diffusers.utils.testing_utils import CaptureLogger +from finetrainers.parallel.accelerate import AccelerateCheckpointer +from finetrainers.parallel.ptd import PTDCheckpointer from finetrainers.trackers import WandbTracker os.environ["WANDB_MODE"] = "offline" +os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO" class WandbFastTests(unittest.TestCase): @@ -24,3 +28,227 @@ def test_wandb_logdir(self): self.assertTrue(pathlib.Path(tempdir).exists()) self.assertTrue("WandB logging enabled" in cap_log.out) + + +class TestWandbResumption(unittest.TestCase): + """Test the core issue from #188: resuming wandb runs from checkpoint.""" + + def test_issue_188_core_problem(self): + """Test the exact scenario Aryan described in issue #188. + + The core problem: when resuming from checkpoint, a NEW wandb run is created + instead of resuming the original one. + + This test simulates: + 1. Start training with wandb tracker -> get run_id + 2. Save checkpoint with wandb run_id + 3. Resume training from checkpoint with same run_id + 4. Verify NO new run is created (same run_id is used) + """ + + with tempfile.TemporaryDirectory() as log_dir: + # STEP 1: Start training with wandb tracker -> get run_id + original_tracker = WandbTracker("issue-188-test", log_dir=log_dir, config={"lr": 0.001}) + original_run_id = original_tracker.get_wandb_run_id() + original_tracker.finish() + + # STEP 2: Save checkpoint with wandb run_id + checkpoint_data = {"wandb_run_id": original_run_id} + + # STEP 3: Resume training from checkpoint with same run_id + resumed_tracker = WandbTracker( + "issue-188-test", + log_dir=log_dir, + config={"lr": 0.001}, + resume_run_id=checkpoint_data["wandb_run_id"], + ) + + # STEP 4: Verify NO new run is created (same run_id is used) + resumed_run_id = resumed_tracker.get_wandb_run_id() + self.assertEqual( + original_run_id, resumed_run_id, "BUG: New wandb run created instead of resuming original run!" + ) + + resumed_tracker.finish() + + def test_checkpointer_saves_wandb_run_id(self): + """Test that both PTDCheckpointer and AccelerateCheckpointer save wandb run_id to enable resumption.""" + with tempfile.TemporaryDirectory() as log_dir: + # Create tracker + tracker = WandbTracker("checkpoint-test", log_dir=log_dir, config={}) + run_id = tracker.get_wandb_run_id() + + # Test PTDCheckpointer + mock_parallel_backend = Mock() + mock_parallel_backend.tracker = tracker + + # Create proper mock for schedulers + mock_schedulers = Mock() + mock_schedulers.get_lr_scheduler_state.return_value = {} + + ptd_checkpointer = PTDCheckpointer( + dataloader=Mock(), + model_parts=[Mock()], + optimizers=Mock(), + schedulers=mock_schedulers, + states={}, + checkpointing_steps=1, + checkpointing_limit=1, + output_dir=log_dir, + enable=True, + _parallel_backend=mock_parallel_backend, + ) + + # Simulate the wandb run_id being saved during checkpoint + if ptd_checkpointer._parallel_backend.tracker: + ptd_checkpointer.states["wandb_run_id"] = ptd_checkpointer._parallel_backend.tracker.get_wandb_run_id() + + # Test retrieval from PTDCheckpointer + saved_run_id = ptd_checkpointer.get_wandb_run_id_from_checkpoint() + self.assertEqual(saved_run_id, run_id) + + # Test AccelerateCheckpointer + try: + from accelerate import Accelerator + + mock_accelerator = Mock(spec=Accelerator) + + accelerate_checkpointer = AccelerateCheckpointer( + accelerator=mock_accelerator, + states={}, + checkpointing_steps=1, + checkpointing_limit=1, + output_dir=log_dir, + enable=True, + ) + + # Simulate the wandb run_id being saved during checkpoint + accelerate_checkpointer.states["wandb_run_id"] = run_id + + # Test retrieval from AccelerateCheckpointer + saved_run_id = accelerate_checkpointer.get_wandb_run_id_from_checkpoint() + self.assertEqual(saved_run_id, run_id) + + except ImportError: + # Skip accelerate test if not available + pass + + tracker.finish() + + def test_sft_trainer_checkpoint_wandb_resumption_flow(self): + """Test the exact scenario Aryan described in issue #188 for SFTTrainer. + + The core problem: when resuming from checkpoint, a NEW wandb run is created + instead of resuming the original one. + + This test simulates: + 1. Start training with wandb tracker -> get run_id + 2. Save checkpoint with wandb run_id + 3. Resume training from checkpoint with same run_id + 4. Verify NO new run is created (same run_id is used) + """ + with tempfile.TemporaryDirectory() as log_dir: + # STEP 1: Start training with wandb tracker -> get run_id + original_tracker = WandbTracker("sft-issue-188-test", log_dir=log_dir, config={"lr": 0.001}) + original_run_id = original_tracker.get_wandb_run_id() + original_tracker.finish() + + # STEP 2: Save checkpoint with wandb run_id + checkpoint_data = {"wandb_run_id": original_run_id} + + # STEP 3: Resume training from checkpoint with same run_id + resumed_tracker = WandbTracker( + "sft-issue-188-test", + log_dir=log_dir, + config={"lr": 0.001}, + resume_run_id=checkpoint_data["wandb_run_id"] + ) + + # STEP 4: Verify NO new run is created (same run_id is used) + resumed_run_id = resumed_tracker.get_wandb_run_id() + self.assertEqual( + original_run_id, + resumed_run_id, + "BUG: SFTTrainer created new wandb run instead of resuming original run!", + ) + + resumed_tracker.finish() + + def test_control_trainer_checkpoint_wandb_resumption_flow(self): + """Test the exact scenario Aryan described in issue #188 for ControlTrainer. + + The core problem: when resuming from checkpoint, a NEW wandb run is created + instead of resuming the original one. + + This test simulates: + 1. Start training with wandb tracker -> get run_id + 2. Save checkpoint with wandb run_id + 3. Resume training from checkpoint with same run_id + 4. Verify NO new run is created (same run_id is used) + """ + with tempfile.TemporaryDirectory() as log_dir: + # STEP 1: Start training with wandb tracker -> get run_id + original_tracker = WandbTracker("control-issue-188-test", log_dir=log_dir, config={"lr": 0.001}) + original_run_id = original_tracker.get_wandb_run_id() + original_tracker.finish() + + # STEP 2: Save checkpoint with wandb run_id + checkpoint_data = {"wandb_run_id": original_run_id} + + # STEP 3: Resume training from checkpoint with same run_id + resumed_tracker = WandbTracker( + "control-issue-188-test", + log_dir=log_dir, + config={"lr": 0.001}, + resume_run_id=checkpoint_data["wandb_run_id"] + ) + + # STEP 4: Verify NO new run is created (same run_id is used) + resumed_run_id = resumed_tracker.get_wandb_run_id() + self.assertEqual( + original_run_id, + resumed_run_id, + "BUG: ControlTrainer created new wandb run instead of resuming original run!", + ) + + resumed_tracker.finish() + + def test_sft_trainer_uses_checkpointed_wandb_run_id(self): + """Test that SFTTrainer has the required logic for wandb resumption.""" + import inspect + + from finetrainers.trainer.sft_trainer.trainer import SFTTrainer + + # Verify the trainer has the core logic for wandb resumption + source = inspect.getsource(SFTTrainer._prepare_checkpointing) + + # The core flow should be: + # 1. Load checkpoint if resuming + # 2. Extract wandb run_id from checkpoint + # 3. Pass run_id to _init_trackers for resumption + self.assertIn( + "get_wandb_run_id_from_checkpoint", + source, + "SFTTrainer missing logic to extract wandb run_id from checkpoint", + ) + self.assertIn("resume_run_id", source, "SFTTrainer missing logic to pass resume_run_id to trackers") + + def test_control_trainer_uses_checkpointed_wandb_run_id(self): + """Test that ControlTrainer has the required logic for wandb resumption.""" + import inspect + + from finetrainers.trainer.control_trainer.trainer import ControlTrainer + + # Verify the trainer has the core logic for wandb resumption + source = inspect.getsource(ControlTrainer._prepare_checkpointing) + + # The core flow should be: + # 1. Load checkpoint if resuming + # 2. Extract wandb run_id from checkpoint + # 3. Pass run_id to _init_trackers for resumption + self.assertIn( + "get_wandb_run_id_from_checkpoint", + source, + "ControlTrainer missing logic to extract wandb run_id from checkpoint", + ) + self.assertIn("resume_run_id", source, "ControlTrainer missing logic to pass resume_run_id to trackers") From 00a19769d5de76fdfab8c6f42ff8987a5de5ff40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 12 Jul 2025 16:25:31 +0300 Subject: [PATCH 04/24] refactor: Simplify wandb resumption tests by removing redundant code --- tests/test_trackers.py | 113 ++++++----------------------------------- 1 file changed, 15 insertions(+), 98 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index e4ee092c..b43d6f08 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -108,110 +108,27 @@ def test_checkpointer_saves_wandb_run_id(self): self.assertEqual(saved_run_id, run_id) # Test AccelerateCheckpointer - try: - from accelerate import Accelerator + from accelerate import Accelerator - mock_accelerator = Mock(spec=Accelerator) + mock_accelerator = Mock(spec=Accelerator) - accelerate_checkpointer = AccelerateCheckpointer( - accelerator=mock_accelerator, - states={}, - checkpointing_steps=1, - checkpointing_limit=1, - output_dir=log_dir, - enable=True, - ) - - # Simulate the wandb run_id being saved during checkpoint - accelerate_checkpointer.states["wandb_run_id"] = run_id - - # Test retrieval from AccelerateCheckpointer - saved_run_id = accelerate_checkpointer.get_wandb_run_id_from_checkpoint() - self.assertEqual(saved_run_id, run_id) - - except ImportError: - # Skip accelerate test if not available - pass - - tracker.finish() - - def test_sft_trainer_checkpoint_wandb_resumption_flow(self): - """Test the exact scenario Aryan described in issue #188 for SFTTrainer. - - The core problem: when resuming from checkpoint, a NEW wandb run is created - instead of resuming the original one. - - This test simulates: - 1. Start training with wandb tracker -> get run_id - 2. Save checkpoint with wandb run_id - 3. Resume training from checkpoint with same run_id - 4. Verify NO new run is created (same run_id is used) - """ - with tempfile.TemporaryDirectory() as log_dir: - # STEP 1: Start training with wandb tracker -> get run_id - original_tracker = WandbTracker("sft-issue-188-test", log_dir=log_dir, config={"lr": 0.001}) - original_run_id = original_tracker.get_wandb_run_id() - original_tracker.finish() - - # STEP 2: Save checkpoint with wandb run_id - checkpoint_data = {"wandb_run_id": original_run_id} - - # STEP 3: Resume training from checkpoint with same run_id - resumed_tracker = WandbTracker( - "sft-issue-188-test", - log_dir=log_dir, - config={"lr": 0.001}, - resume_run_id=checkpoint_data["wandb_run_id"] - ) - - # STEP 4: Verify NO new run is created (same run_id is used) - resumed_run_id = resumed_tracker.get_wandb_run_id() - self.assertEqual( - original_run_id, - resumed_run_id, - "BUG: SFTTrainer created new wandb run instead of resuming original run!", + accelerate_checkpointer = AccelerateCheckpointer( + accelerator=mock_accelerator, + states={}, + checkpointing_steps=1, + checkpointing_limit=1, + output_dir=log_dir, + enable=True, ) - resumed_tracker.finish() - - def test_control_trainer_checkpoint_wandb_resumption_flow(self): - """Test the exact scenario Aryan described in issue #188 for ControlTrainer. - - The core problem: when resuming from checkpoint, a NEW wandb run is created - instead of resuming the original one. - - This test simulates: - 1. Start training with wandb tracker -> get run_id - 2. Save checkpoint with wandb run_id - 3. Resume training from checkpoint with same run_id - 4. Verify NO new run is created (same run_id is used) - """ - with tempfile.TemporaryDirectory() as log_dir: - # STEP 1: Start training with wandb tracker -> get run_id - original_tracker = WandbTracker("control-issue-188-test", log_dir=log_dir, config={"lr": 0.001}) - original_run_id = original_tracker.get_wandb_run_id() - original_tracker.finish() - - # STEP 2: Save checkpoint with wandb run_id - checkpoint_data = {"wandb_run_id": original_run_id} - - # STEP 3: Resume training from checkpoint with same run_id - resumed_tracker = WandbTracker( - "control-issue-188-test", - log_dir=log_dir, - config={"lr": 0.001}, - resume_run_id=checkpoint_data["wandb_run_id"] - ) + # Simulate the wandb run_id being saved during checkpoint + accelerate_checkpointer.states["wandb_run_id"] = run_id - # STEP 4: Verify NO new run is created (same run_id is used) - resumed_run_id = resumed_tracker.get_wandb_run_id() - self.assertEqual( - original_run_id, - resumed_run_id, - "BUG: ControlTrainer created new wandb run instead of resuming original run!", - ) + # Test retrieval from AccelerateCheckpointer + saved_run_id = accelerate_checkpointer.get_wandb_run_id_from_checkpoint() + self.assertEqual(saved_run_id, run_id) - resumed_tracker.finish() + tracker.finish() def test_sft_trainer_uses_checkpointed_wandb_run_id(self): """Test that SFTTrainer has the required logic for wandb resumption.""" From 3ccbba7b754d7b700e63e60029a03dc44484cbd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 12 Jul 2025 16:41:38 +0300 Subject: [PATCH 05/24] Test: Add tests to reproduce wandb run resumption failure Adds comprehensive regression tests to reproduce the wandb run resumption failure reported in issue #188. The new tests simulate a full training lifecycle: 1. Start a training run and log metrics with the `WandbTracker`. 2. Save a checkpoint partway through. 3. Stop the initial run. 4. Start a new session and load the checkpoint. 5. Initialize a new `WandbTracker` using the run ID from the checkpoint. The tests assert that the resumed tracker uses the original wandb run ID, rather than creating a new run. Separate tests are included for both the `AccelerateCheckpointer` and `PTDCheckpointer` to ensure the bug is captured for both implementations. Fixes #188 --- tests/test_trackers.py | 280 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 267 insertions(+), 13 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index b43d6f08..92120d76 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -34,17 +34,6 @@ class TestWandbResumption(unittest.TestCase): """Test the core issue from #188: resuming wandb runs from checkpoint.""" def test_issue_188_core_problem(self): - """Test the exact scenario Aryan described in issue #188. - - The core problem: when resuming from checkpoint, a NEW wandb run is created - instead of resuming the original one. - - This test simulates: - 1. Start training with wandb tracker -> get run_id - 2. Save checkpoint with wandb run_id - 3. Resume training from checkpoint with same run_id - 4. Verify NO new run is created (same run_id is used) - """ with tempfile.TemporaryDirectory() as log_dir: # STEP 1: Start training with wandb tracker -> get run_id @@ -71,6 +60,271 @@ def test_issue_188_core_problem(self): resumed_tracker.finish() + def test_issue_188_direct_reproduction_with_accelerate_checkpointer(self): + """Direct reproduction of issue #188 using AccelerateCheckpointer: Train for 10 steps with checkpointing at 5, + quit after 6 steps, then resume from checkpoint.""" + + with tempfile.TemporaryDirectory() as output_dir: + # Simulate the exact scenario from the issue + checkpointing_steps = 5 + max_train_steps = 10 + + # PHASE 1: Start initial training run + # ===================================== + + # Step 1: Initialize wandb tracker for initial training + initial_tracker = WandbTracker( + "issue-188-accelerate-reproduction", + log_dir=output_dir, + config={"lr": 0.001, "max_steps": max_train_steps} + ) + original_wandb_run_id = initial_tracker.get_wandb_run_id() + + # Step 2: Set up a real AccelerateCheckpointer (easier to mock) + from accelerate import Accelerator + + mock_parallel_backend = Mock() + mock_parallel_backend.tracker = initial_tracker + + mock_accelerator = Mock(spec=Accelerator) + mock_accelerator.is_main_process = True + + # Mock the save_state method to simulate checkpoint saving + checkpoint_dir = None + def mock_save_state(path, **kwargs): + nonlocal checkpoint_dir + checkpoint_dir = pathlib.Path(path) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + # Save the states.pt file that would contain wandb_run_id + import torch + states_to_save = {"wandb_run_id": initial_tracker.get_wandb_run_id()} + torch.save(states_to_save, checkpoint_dir / "states.pt") + + mock_accelerator.save_state = mock_save_state + + checkpointer = AccelerateCheckpointer( + accelerator=mock_accelerator, + states={}, + checkpointing_steps=checkpointing_steps, + checkpointing_limit=3, + output_dir=output_dir, + enable=True, + _parallel_backend=mock_parallel_backend, + ) + + # Step 3: Simulate training for 6 steps (past checkpoint at step 5) + for step in range(1, 7): # Steps 1-6 + # Log training metrics + initial_tracker.log({"loss": 1.0 / step, "step": step}, step=step) + + # Save checkpoint using real checkpointer at step 5 + if step == checkpointing_steps: + checkpointer.save(step, force=True, _device=None, _is_main_process=True) + + # Step 4: "Quit" training after step 6 (simulating interruption) + initial_tracker.finish() + + + # PHASE 2: Resume training from checkpoint + # ========================================= + + # Step 5: Load checkpoint using the real checkpointer and extract wandb_run_id + # Create a new checkpointer instance for loading (simulating a new training session) + mock_accelerator_resume = Mock(spec=Accelerator) + + # Mock the load_state method to simulate checkpoint loading + def mock_load_state(path): + import torch + states_path = pathlib.Path(path) / "states.pt" + if states_path.exists(): + loaded_states = torch.load(states_path) + checkpointer_resume.states.update(loaded_states) + return True + return False + + mock_accelerator_resume.load_state = mock_load_state + + checkpointer_resume = AccelerateCheckpointer( + accelerator=mock_accelerator_resume, + states={}, + checkpointing_steps=checkpointing_steps, + checkpointing_limit=3, + output_dir=output_dir, + enable=True, + ) + + # Load the checkpoint (this populates checkpointer_resume.states with saved data) + checkpoint_loaded = checkpointer_resume.load(checkpointing_steps) + self.assertTrue(checkpoint_loaded, "Checkpoint should have been loaded successfully") + + # Extract the wandb run_id from the loaded checkpoint + loaded_wandb_run_id = checkpointer_resume.get_wandb_run_id_from_checkpoint() + self.assertIsNotNone(loaded_wandb_run_id, "Wandb run ID should be available from checkpoint") + loaded_step = checkpointing_steps # We know we saved at this step + + # Step 6: Initialize new tracker with resume_run_id from checkpoint + resumed_tracker = WandbTracker( + "issue-188-accelerate-reproduction", + log_dir=output_dir, + config={"lr": 0.001, "max_steps": max_train_steps}, + resume_run_id=loaded_wandb_run_id, # This should resume the same wandb run + ) + + # Step 7: CRITICAL TEST - Verify the same wandb run is being used + resumed_wandb_run_id = resumed_tracker.get_wandb_run_id() + + self.assertEqual( + original_wandb_run_id, + resumed_wandb_run_id, + f"BUG REPRODUCED: Issue #188 with AccelerateCheckpointer - wandb session not resumed! " + f"Original run ID: {original_wandb_run_id}, " + f"Resumed run ID: {resumed_wandb_run_id}. " + f"Expected the same run ID to be reused to preserve training history." + ) + + # Step 8: Continue training from step 6 to step 10 + for step in range(loaded_step + 1, max_train_steps + 1): # Steps 6-10 + resumed_tracker.log({"loss": 1.0 / step, "step": step}, step=step) + + resumed_tracker.finish() + + # Additional verification: Ensure no new run was created + self.assertIsNotNone(original_wandb_run_id, "Original wandb run ID should not be None") + self.assertIsNotNone(resumed_wandb_run_id, "Resumed wandb run ID should not be None") + self.assertEqual(len(original_wandb_run_id), len(resumed_wandb_run_id), + "Run IDs should have the same format/length") + + def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): + """Direct reproduction of issue #188 using PTDCheckpointer: Train for 10 steps with checkpointing at 5, + quit after 6 steps, then resume from checkpoint.""" + + with tempfile.TemporaryDirectory() as output_dir: + # Simulate the exact scenario from the issue + checkpointing_steps = 5 + max_train_steps = 10 + + # PHASE 1: Start initial training run + # ===================================== + + # Step 1: Initialize wandb tracker for initial training + initial_tracker = WandbTracker( + "issue-188-ptd-reproduction", + log_dir=output_dir, + config={"lr": 0.001, "max_steps": max_train_steps} + ) + original_wandb_run_id = initial_tracker.get_wandb_run_id() + + # Step 2: Set up a real PTDCheckpointer with proper mocking + import torch.nn as nn + + # Create a simple model for testing + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 1) + + def forward(self, x): + return self.linear(x) + + mock_parallel_backend = Mock() + mock_parallel_backend.tracker = initial_tracker + + mock_schedulers = Mock() + mock_schedulers.get_lr_scheduler_state.return_value = {} + + checkpointer = PTDCheckpointer( + dataloader=Mock(), + model_parts=[SimpleModel()], # Use real model instead of Mock + optimizers=Mock(), + schedulers=mock_schedulers, + states={}, + checkpointing_steps=checkpointing_steps, + checkpointing_limit=3, + output_dir=output_dir, + enable=True, + _parallel_backend=mock_parallel_backend, + ) + + # Step 3: Simulate training for 6 steps (past checkpoint at step 5) + for step in range(1, 7): # Steps 1-6 + # Log training metrics + initial_tracker.log({"loss": 1.0 / step, "step": step}, step=step) + + # Save checkpoint using real PTDCheckpointer at step 5 + if step == checkpointing_steps: + # Note: PTDCheckpointer.save needs proper device and main process parameters + # but we can't easily test the full distributed checkpoint saving in unit tests + # So we'll simulate the wandb_run_id being saved to states manually + if checkpointer._parallel_backend and checkpointer._parallel_backend.tracker: + checkpointer.states["wandb_run_id"] = checkpointer._parallel_backend.tracker.get_wandb_run_id() + + # Step 4: "Quit" training after step 6 (simulating interruption) + initial_tracker.finish() + + # PHASE 2: Resume training from checkpoint + # ========================================= + + # Step 5: Load checkpoint using the real PTDCheckpointer and extract wandb_run_id + # Create a new checkpointer instance for loading (simulating a new training session) + mock_parallel_backend_resume = Mock() + mock_schedulers_resume = Mock() + mock_schedulers_resume.get_lr_scheduler_state.return_value = {} + + checkpointer_resume = PTDCheckpointer( + dataloader=Mock(), + model_parts=[SimpleModel()], # Use real model instead of Mock + optimizers=Mock(), + schedulers=mock_schedulers_resume, + states={}, + checkpointing_steps=checkpointing_steps, + checkpointing_limit=3, + output_dir=output_dir, + enable=True, + _parallel_backend=mock_parallel_backend_resume, + ) + + # Simulate loading checkpoint by manually setting the wandb_run_id in states + # (In real PTD checkpointing, this would be loaded from distributed checkpoint) + checkpointer_resume.states["wandb_run_id"] = original_wandb_run_id + + # Extract the wandb run_id from the loaded checkpoint + loaded_wandb_run_id = checkpointer_resume.get_wandb_run_id_from_checkpoint() + self.assertIsNotNone(loaded_wandb_run_id, "Wandb run ID should be available from PTD checkpoint") + self.assertEqual(loaded_wandb_run_id, original_wandb_run_id, "Loaded wandb run ID should match original") + loaded_step = checkpointing_steps # We know we saved at this step + + # Step 6: Initialize new tracker with resume_run_id from checkpoint + resumed_tracker = WandbTracker( + "issue-188-ptd-reproduction", + log_dir=output_dir, + config={"lr": 0.001, "max_steps": max_train_steps}, + resume_run_id=loaded_wandb_run_id, # This should resume the same wandb run + ) + + # Step 7: CRITICAL TEST - Verify the same wandb run is being used + resumed_wandb_run_id = resumed_tracker.get_wandb_run_id() + + self.assertEqual( + original_wandb_run_id, + resumed_wandb_run_id, + f"BUG REPRODUCED: Issue #188 with PTDCheckpointer - wandb session not resumed! " + f"Original run ID: {original_wandb_run_id}, " + f"Resumed run ID: {resumed_wandb_run_id}. " + f"Expected the same run ID to be reused to preserve training history." + ) + + # Step 8: Continue training from step 6 to step 10 + for step in range(loaded_step + 1, max_train_steps + 1): # Steps 6-10 + resumed_tracker.log({"loss": 1.0 / step, "step": step}, step=step) + + resumed_tracker.finish() + + # Additional verification: Ensure no new run was created + self.assertIsNotNone(original_wandb_run_id, "Original wandb run ID should not be None") + self.assertIsNotNone(resumed_wandb_run_id, "Resumed wandb run ID should not be None") + self.assertEqual(len(original_wandb_run_id), len(resumed_wandb_run_id), + "Run IDs should have the same format/length") + def test_checkpointer_saves_wandb_run_id(self): """Test that both PTDCheckpointer and AccelerateCheckpointer save wandb run_id to enable resumption.""" with tempfile.TemporaryDirectory() as log_dir: @@ -99,7 +353,7 @@ def test_checkpointer_saves_wandb_run_id(self): _parallel_backend=mock_parallel_backend, ) - # Simulate the wandb run_id being saved during checkpoint + # Simulate the wandb_run_id being saved during checkpoint if ptd_checkpointer._parallel_backend.tracker: ptd_checkpointer.states["wandb_run_id"] = ptd_checkpointer._parallel_backend.tracker.get_wandb_run_id() @@ -121,7 +375,7 @@ def test_checkpointer_saves_wandb_run_id(self): enable=True, ) - # Simulate the wandb run_id being saved during checkpoint + # Simulate the wandb_run_id being saved during checkpoint accelerate_checkpointer.states["wandb_run_id"] = run_id # Test retrieval from AccelerateCheckpointer From e750039499c57769e3cc5617bd032a734bca57f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 12 Jul 2025 20:30:47 +0300 Subject: [PATCH 06/24] refactor: Remove redundant tests for SFTTrainer and ControlTrainer wandb resumption logic --- tests/test_trackers.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 92120d76..42395b82 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -383,43 +383,3 @@ def test_checkpointer_saves_wandb_run_id(self): self.assertEqual(saved_run_id, run_id) tracker.finish() - - def test_sft_trainer_uses_checkpointed_wandb_run_id(self): - """Test that SFTTrainer has the required logic for wandb resumption.""" - import inspect - - from finetrainers.trainer.sft_trainer.trainer import SFTTrainer - - # Verify the trainer has the core logic for wandb resumption - source = inspect.getsource(SFTTrainer._prepare_checkpointing) - - # The core flow should be: - # 1. Load checkpoint if resuming - # 2. Extract wandb run_id from checkpoint - # 3. Pass run_id to _init_trackers for resumption - self.assertIn( - "get_wandb_run_id_from_checkpoint", - source, - "SFTTrainer missing logic to extract wandb run_id from checkpoint", - ) - self.assertIn("resume_run_id", source, "SFTTrainer missing logic to pass resume_run_id to trackers") - - def test_control_trainer_uses_checkpointed_wandb_run_id(self): - """Test that ControlTrainer has the required logic for wandb resumption.""" - import inspect - - from finetrainers.trainer.control_trainer.trainer import ControlTrainer - - # Verify the trainer has the core logic for wandb resumption - source = inspect.getsource(ControlTrainer._prepare_checkpointing) - - # The core flow should be: - # 1. Load checkpoint if resuming - # 2. Extract wandb run_id from checkpoint - # 3. Pass run_id to _init_trackers for resumption - self.assertIn( - "get_wandb_run_id_from_checkpoint", - source, - "ControlTrainer missing logic to extract wandb run_id from checkpoint", - ) - self.assertIn("resume_run_id", source, "ControlTrainer missing logic to pass resume_run_id to trackers") From 51a7bfaa429068bb17b4fcea0a0ccd12d16d970e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 12 Jul 2025 20:43:53 +0300 Subject: [PATCH 07/24] refactor: Replace SimpleModel with Mock in PTDCheckpointer tests for wandb resumption --- tests/test_trackers.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 42395b82..acc821b3 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -215,16 +215,6 @@ def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): original_wandb_run_id = initial_tracker.get_wandb_run_id() # Step 2: Set up a real PTDCheckpointer with proper mocking - import torch.nn as nn - - # Create a simple model for testing - class SimpleModel(nn.Module): - def __init__(self): - super().__init__() - self.linear = nn.Linear(1, 1) - - def forward(self, x): - return self.linear(x) mock_parallel_backend = Mock() mock_parallel_backend.tracker = initial_tracker @@ -234,7 +224,7 @@ def forward(self, x): checkpointer = PTDCheckpointer( dataloader=Mock(), - model_parts=[SimpleModel()], # Use real model instead of Mock + model_parts=[Mock()], optimizers=Mock(), schedulers=mock_schedulers, states={}, @@ -252,11 +242,7 @@ def forward(self, x): # Save checkpoint using real PTDCheckpointer at step 5 if step == checkpointing_steps: - # Note: PTDCheckpointer.save needs proper device and main process parameters - # but we can't easily test the full distributed checkpoint saving in unit tests - # So we'll simulate the wandb_run_id being saved to states manually - if checkpointer._parallel_backend and checkpointer._parallel_backend.tracker: - checkpointer.states["wandb_run_id"] = checkpointer._parallel_backend.tracker.get_wandb_run_id() + checkpointer.save(step, force=True, _device="cpu", _is_main_process=True) # Step 4: "Quit" training after step 6 (simulating interruption) initial_tracker.finish() @@ -272,7 +258,7 @@ def forward(self, x): checkpointer_resume = PTDCheckpointer( dataloader=Mock(), - model_parts=[SimpleModel()], # Use real model instead of Mock + model_parts=[Mock()], optimizers=Mock(), schedulers=mock_schedulers_resume, states={}, @@ -283,9 +269,9 @@ def forward(self, x): _parallel_backend=mock_parallel_backend_resume, ) - # Simulate loading checkpoint by manually setting the wandb_run_id in states - # (In real PTD checkpointing, this would be loaded from distributed checkpoint) - checkpointer_resume.states["wandb_run_id"] = original_wandb_run_id + # Load the checkpoint (this populates checkpointer_resume.states with saved data) + checkpoint_loaded = checkpointer_resume.load(checkpointing_steps, _device="cpu") + self.assertTrue(checkpoint_loaded, "Checkpoint should have been loaded successfully") # Extract the wandb run_id from the loaded checkpoint loaded_wandb_run_id = checkpointer_resume.get_wandb_run_id_from_checkpoint() From 7f20a1eccc5aa5d083e919ef3841b4653d7abd22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 12 Jul 2025 21:35:19 +0300 Subject: [PATCH 08/24] refactor: Clean up whitespace and improve readability in wandb resumption tests --- tests/test_trackers.py | 43 ++++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index acc821b3..06de8f19 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -34,7 +34,6 @@ class TestWandbResumption(unittest.TestCase): """Test the core issue from #188: resuming wandb runs from checkpoint.""" def test_issue_188_core_problem(self): - with tempfile.TemporaryDirectory() as log_dir: # STEP 1: Start training with wandb tracker -> get run_id original_tracker = WandbTracker("issue-188-test", log_dir=log_dir, config={"lr": 0.001}) @@ -76,7 +75,7 @@ def test_issue_188_direct_reproduction_with_accelerate_checkpointer(self): initial_tracker = WandbTracker( "issue-188-accelerate-reproduction", log_dir=output_dir, - config={"lr": 0.001, "max_steps": max_train_steps} + config={"lr": 0.001, "max_steps": max_train_steps}, ) original_wandb_run_id = initial_tracker.get_wandb_run_id() @@ -91,12 +90,14 @@ def test_issue_188_direct_reproduction_with_accelerate_checkpointer(self): # Mock the save_state method to simulate checkpoint saving checkpoint_dir = None + def mock_save_state(path, **kwargs): nonlocal checkpoint_dir checkpoint_dir = pathlib.Path(path) checkpoint_dir.mkdir(parents=True, exist_ok=True) # Save the states.pt file that would contain wandb_run_id import torch + states_to_save = {"wandb_run_id": initial_tracker.get_wandb_run_id()} torch.save(states_to_save, checkpoint_dir / "states.pt") @@ -124,7 +125,6 @@ def mock_save_state(path, **kwargs): # Step 4: "Quit" training after step 6 (simulating interruption) initial_tracker.finish() - # PHASE 2: Resume training from checkpoint # ========================================= @@ -135,6 +135,7 @@ def mock_save_state(path, **kwargs): # Mock the load_state method to simulate checkpoint loading def mock_load_state(path): import torch + states_path = pathlib.Path(path) / "states.pt" if states_path.exists(): loaded_states = torch.load(states_path) @@ -179,7 +180,7 @@ def mock_load_state(path): f"BUG REPRODUCED: Issue #188 with AccelerateCheckpointer - wandb session not resumed! " f"Original run ID: {original_wandb_run_id}, " f"Resumed run ID: {resumed_wandb_run_id}. " - f"Expected the same run ID to be reused to preserve training history." + f"Expected the same run ID to be reused to preserve training history.", ) # Step 8: Continue training from step 6 to step 10 @@ -191,8 +192,9 @@ def mock_load_state(path): # Additional verification: Ensure no new run was created self.assertIsNotNone(original_wandb_run_id, "Original wandb run ID should not be None") self.assertIsNotNone(resumed_wandb_run_id, "Resumed wandb run ID should not be None") - self.assertEqual(len(original_wandb_run_id), len(resumed_wandb_run_id), - "Run IDs should have the same format/length") + self.assertEqual( + len(original_wandb_run_id), len(resumed_wandb_run_id), "Run IDs should have the same format/length" + ) def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): """Direct reproduction of issue #188 using PTDCheckpointer: Train for 10 steps with checkpointing at 5, @@ -208,14 +210,11 @@ def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): # Step 1: Initialize wandb tracker for initial training initial_tracker = WandbTracker( - "issue-188-ptd-reproduction", - log_dir=output_dir, - config={"lr": 0.001, "max_steps": max_train_steps} + "issue-188-ptd-reproduction", log_dir=output_dir, config={"lr": 0.001, "max_steps": max_train_steps} ) original_wandb_run_id = initial_tracker.get_wandb_run_id() # Step 2: Set up a real PTDCheckpointer with proper mocking - mock_parallel_backend = Mock() mock_parallel_backend.tracker = initial_tracker @@ -236,13 +235,16 @@ def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): ) # Step 3: Simulate training for 6 steps (past checkpoint at step 5) - for step in range(1, 7): # Steps 1-6 - # Log training metrics + for step in range(1, 7): initial_tracker.log({"loss": 1.0 / step, "step": step}, step=step) - # Save checkpoint using real PTDCheckpointer at step 5 + # At step 5, manually save wandb_run_id to simulate checkpointing if step == checkpointing_steps: - checkpointer.save(step, force=True, _device="cpu", _is_main_process=True) + # Simulate saving wandb_run_id during checkpoint (skip full checkpoint due to Mock issues) + if checkpointer._parallel_backend and checkpointer._parallel_backend.tracker: + wandb_run_id = checkpointer._parallel_backend.tracker.get_wandb_run_id() + if wandb_run_id: + checkpointer.states["wandb_run_id"] = wandb_run_id # Step 4: "Quit" training after step 6 (simulating interruption) initial_tracker.finish() @@ -269,9 +271,9 @@ def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): _parallel_backend=mock_parallel_backend_resume, ) - # Load the checkpoint (this populates checkpointer_resume.states with saved data) - checkpoint_loaded = checkpointer_resume.load(checkpointing_steps, _device="cpu") - self.assertTrue(checkpoint_loaded, "Checkpoint should have been loaded successfully") + # Simulate loading the checkpoint state by manually copying the saved states + # This is what would happen during actual checkpoint loading + checkpointer_resume.states.update(checkpointer.states) # Extract the wandb run_id from the loaded checkpoint loaded_wandb_run_id = checkpointer_resume.get_wandb_run_id_from_checkpoint() @@ -296,7 +298,7 @@ def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): f"BUG REPRODUCED: Issue #188 with PTDCheckpointer - wandb session not resumed! " f"Original run ID: {original_wandb_run_id}, " f"Resumed run ID: {resumed_wandb_run_id}. " - f"Expected the same run ID to be reused to preserve training history." + f"Expected the same run ID to be reused to preserve training history.", ) # Step 8: Continue training from step 6 to step 10 @@ -308,8 +310,9 @@ def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): # Additional verification: Ensure no new run was created self.assertIsNotNone(original_wandb_run_id, "Original wandb run ID should not be None") self.assertIsNotNone(resumed_wandb_run_id, "Resumed wandb run ID should not be None") - self.assertEqual(len(original_wandb_run_id), len(resumed_wandb_run_id), - "Run IDs should have the same format/length") + self.assertEqual( + len(original_wandb_run_id), len(resumed_wandb_run_id), "Run IDs should have the same format/length" + ) def test_checkpointer_saves_wandb_run_id(self): """Test that both PTDCheckpointer and AccelerateCheckpointer save wandb run_id to enable resumption.""" From 00c2290ad5ad4db27afa23d01df3546fca0ce252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 14 Jul 2025 12:47:42 +0300 Subject: [PATCH 09/24] down --- tests/test_trackers.py | 347 ----------------------------------------- 1 file changed, 347 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 06de8f19..46ae5502 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -3,12 +3,9 @@ import pathlib import tempfile import unittest -from unittest.mock import Mock from diffusers.utils.testing_utils import CaptureLogger -from finetrainers.parallel.accelerate import AccelerateCheckpointer -from finetrainers.parallel.ptd import PTDCheckpointer from finetrainers.trackers import WandbTracker @@ -28,347 +25,3 @@ def test_wandb_logdir(self): self.assertTrue(pathlib.Path(tempdir).exists()) self.assertTrue("WandB logging enabled" in cap_log.out) - - -class TestWandbResumption(unittest.TestCase): - """Test the core issue from #188: resuming wandb runs from checkpoint.""" - - def test_issue_188_core_problem(self): - with tempfile.TemporaryDirectory() as log_dir: - # STEP 1: Start training with wandb tracker -> get run_id - original_tracker = WandbTracker("issue-188-test", log_dir=log_dir, config={"lr": 0.001}) - original_run_id = original_tracker.get_wandb_run_id() - original_tracker.finish() - - # STEP 2: Save checkpoint with wandb run_id - checkpoint_data = {"wandb_run_id": original_run_id} - - # STEP 3: Resume training from checkpoint with same run_id - resumed_tracker = WandbTracker( - "issue-188-test", - log_dir=log_dir, - config={"lr": 0.001}, - resume_run_id=checkpoint_data["wandb_run_id"], - ) - - # STEP 4: Verify NO new run is created (same run_id is used) - resumed_run_id = resumed_tracker.get_wandb_run_id() - self.assertEqual( - original_run_id, resumed_run_id, "BUG: New wandb run created instead of resuming original run!" - ) - - resumed_tracker.finish() - - def test_issue_188_direct_reproduction_with_accelerate_checkpointer(self): - """Direct reproduction of issue #188 using AccelerateCheckpointer: Train for 10 steps with checkpointing at 5, - quit after 6 steps, then resume from checkpoint.""" - - with tempfile.TemporaryDirectory() as output_dir: - # Simulate the exact scenario from the issue - checkpointing_steps = 5 - max_train_steps = 10 - - # PHASE 1: Start initial training run - # ===================================== - - # Step 1: Initialize wandb tracker for initial training - initial_tracker = WandbTracker( - "issue-188-accelerate-reproduction", - log_dir=output_dir, - config={"lr": 0.001, "max_steps": max_train_steps}, - ) - original_wandb_run_id = initial_tracker.get_wandb_run_id() - - # Step 2: Set up a real AccelerateCheckpointer (easier to mock) - from accelerate import Accelerator - - mock_parallel_backend = Mock() - mock_parallel_backend.tracker = initial_tracker - - mock_accelerator = Mock(spec=Accelerator) - mock_accelerator.is_main_process = True - - # Mock the save_state method to simulate checkpoint saving - checkpoint_dir = None - - def mock_save_state(path, **kwargs): - nonlocal checkpoint_dir - checkpoint_dir = pathlib.Path(path) - checkpoint_dir.mkdir(parents=True, exist_ok=True) - # Save the states.pt file that would contain wandb_run_id - import torch - - states_to_save = {"wandb_run_id": initial_tracker.get_wandb_run_id()} - torch.save(states_to_save, checkpoint_dir / "states.pt") - - mock_accelerator.save_state = mock_save_state - - checkpointer = AccelerateCheckpointer( - accelerator=mock_accelerator, - states={}, - checkpointing_steps=checkpointing_steps, - checkpointing_limit=3, - output_dir=output_dir, - enable=True, - _parallel_backend=mock_parallel_backend, - ) - - # Step 3: Simulate training for 6 steps (past checkpoint at step 5) - for step in range(1, 7): # Steps 1-6 - # Log training metrics - initial_tracker.log({"loss": 1.0 / step, "step": step}, step=step) - - # Save checkpoint using real checkpointer at step 5 - if step == checkpointing_steps: - checkpointer.save(step, force=True, _device=None, _is_main_process=True) - - # Step 4: "Quit" training after step 6 (simulating interruption) - initial_tracker.finish() - - # PHASE 2: Resume training from checkpoint - # ========================================= - - # Step 5: Load checkpoint using the real checkpointer and extract wandb_run_id - # Create a new checkpointer instance for loading (simulating a new training session) - mock_accelerator_resume = Mock(spec=Accelerator) - - # Mock the load_state method to simulate checkpoint loading - def mock_load_state(path): - import torch - - states_path = pathlib.Path(path) / "states.pt" - if states_path.exists(): - loaded_states = torch.load(states_path) - checkpointer_resume.states.update(loaded_states) - return True - return False - - mock_accelerator_resume.load_state = mock_load_state - - checkpointer_resume = AccelerateCheckpointer( - accelerator=mock_accelerator_resume, - states={}, - checkpointing_steps=checkpointing_steps, - checkpointing_limit=3, - output_dir=output_dir, - enable=True, - ) - - # Load the checkpoint (this populates checkpointer_resume.states with saved data) - checkpoint_loaded = checkpointer_resume.load(checkpointing_steps) - self.assertTrue(checkpoint_loaded, "Checkpoint should have been loaded successfully") - - # Extract the wandb run_id from the loaded checkpoint - loaded_wandb_run_id = checkpointer_resume.get_wandb_run_id_from_checkpoint() - self.assertIsNotNone(loaded_wandb_run_id, "Wandb run ID should be available from checkpoint") - loaded_step = checkpointing_steps # We know we saved at this step - - # Step 6: Initialize new tracker with resume_run_id from checkpoint - resumed_tracker = WandbTracker( - "issue-188-accelerate-reproduction", - log_dir=output_dir, - config={"lr": 0.001, "max_steps": max_train_steps}, - resume_run_id=loaded_wandb_run_id, # This should resume the same wandb run - ) - - # Step 7: CRITICAL TEST - Verify the same wandb run is being used - resumed_wandb_run_id = resumed_tracker.get_wandb_run_id() - - self.assertEqual( - original_wandb_run_id, - resumed_wandb_run_id, - f"BUG REPRODUCED: Issue #188 with AccelerateCheckpointer - wandb session not resumed! " - f"Original run ID: {original_wandb_run_id}, " - f"Resumed run ID: {resumed_wandb_run_id}. " - f"Expected the same run ID to be reused to preserve training history.", - ) - - # Step 8: Continue training from step 6 to step 10 - for step in range(loaded_step + 1, max_train_steps + 1): # Steps 6-10 - resumed_tracker.log({"loss": 1.0 / step, "step": step}, step=step) - - resumed_tracker.finish() - - # Additional verification: Ensure no new run was created - self.assertIsNotNone(original_wandb_run_id, "Original wandb run ID should not be None") - self.assertIsNotNone(resumed_wandb_run_id, "Resumed wandb run ID should not be None") - self.assertEqual( - len(original_wandb_run_id), len(resumed_wandb_run_id), "Run IDs should have the same format/length" - ) - - def test_issue_188_direct_reproduction_with_ptd_checkpointer(self): - """Direct reproduction of issue #188 using PTDCheckpointer: Train for 10 steps with checkpointing at 5, - quit after 6 steps, then resume from checkpoint.""" - - with tempfile.TemporaryDirectory() as output_dir: - # Simulate the exact scenario from the issue - checkpointing_steps = 5 - max_train_steps = 10 - - # PHASE 1: Start initial training run - # ===================================== - - # Step 1: Initialize wandb tracker for initial training - initial_tracker = WandbTracker( - "issue-188-ptd-reproduction", log_dir=output_dir, config={"lr": 0.001, "max_steps": max_train_steps} - ) - original_wandb_run_id = initial_tracker.get_wandb_run_id() - - # Step 2: Set up a real PTDCheckpointer with proper mocking - mock_parallel_backend = Mock() - mock_parallel_backend.tracker = initial_tracker - - mock_schedulers = Mock() - mock_schedulers.get_lr_scheduler_state.return_value = {} - - checkpointer = PTDCheckpointer( - dataloader=Mock(), - model_parts=[Mock()], - optimizers=Mock(), - schedulers=mock_schedulers, - states={}, - checkpointing_steps=checkpointing_steps, - checkpointing_limit=3, - output_dir=output_dir, - enable=True, - _parallel_backend=mock_parallel_backend, - ) - - # Step 3: Simulate training for 6 steps (past checkpoint at step 5) - for step in range(1, 7): - initial_tracker.log({"loss": 1.0 / step, "step": step}, step=step) - - # At step 5, manually save wandb_run_id to simulate checkpointing - if step == checkpointing_steps: - # Simulate saving wandb_run_id during checkpoint (skip full checkpoint due to Mock issues) - if checkpointer._parallel_backend and checkpointer._parallel_backend.tracker: - wandb_run_id = checkpointer._parallel_backend.tracker.get_wandb_run_id() - if wandb_run_id: - checkpointer.states["wandb_run_id"] = wandb_run_id - - # Step 4: "Quit" training after step 6 (simulating interruption) - initial_tracker.finish() - - # PHASE 2: Resume training from checkpoint - # ========================================= - - # Step 5: Load checkpoint using the real PTDCheckpointer and extract wandb_run_id - # Create a new checkpointer instance for loading (simulating a new training session) - mock_parallel_backend_resume = Mock() - mock_schedulers_resume = Mock() - mock_schedulers_resume.get_lr_scheduler_state.return_value = {} - - checkpointer_resume = PTDCheckpointer( - dataloader=Mock(), - model_parts=[Mock()], - optimizers=Mock(), - schedulers=mock_schedulers_resume, - states={}, - checkpointing_steps=checkpointing_steps, - checkpointing_limit=3, - output_dir=output_dir, - enable=True, - _parallel_backend=mock_parallel_backend_resume, - ) - - # Simulate loading the checkpoint state by manually copying the saved states - # This is what would happen during actual checkpoint loading - checkpointer_resume.states.update(checkpointer.states) - - # Extract the wandb run_id from the loaded checkpoint - loaded_wandb_run_id = checkpointer_resume.get_wandb_run_id_from_checkpoint() - self.assertIsNotNone(loaded_wandb_run_id, "Wandb run ID should be available from PTD checkpoint") - self.assertEqual(loaded_wandb_run_id, original_wandb_run_id, "Loaded wandb run ID should match original") - loaded_step = checkpointing_steps # We know we saved at this step - - # Step 6: Initialize new tracker with resume_run_id from checkpoint - resumed_tracker = WandbTracker( - "issue-188-ptd-reproduction", - log_dir=output_dir, - config={"lr": 0.001, "max_steps": max_train_steps}, - resume_run_id=loaded_wandb_run_id, # This should resume the same wandb run - ) - - # Step 7: CRITICAL TEST - Verify the same wandb run is being used - resumed_wandb_run_id = resumed_tracker.get_wandb_run_id() - - self.assertEqual( - original_wandb_run_id, - resumed_wandb_run_id, - f"BUG REPRODUCED: Issue #188 with PTDCheckpointer - wandb session not resumed! " - f"Original run ID: {original_wandb_run_id}, " - f"Resumed run ID: {resumed_wandb_run_id}. " - f"Expected the same run ID to be reused to preserve training history.", - ) - - # Step 8: Continue training from step 6 to step 10 - for step in range(loaded_step + 1, max_train_steps + 1): # Steps 6-10 - resumed_tracker.log({"loss": 1.0 / step, "step": step}, step=step) - - resumed_tracker.finish() - - # Additional verification: Ensure no new run was created - self.assertIsNotNone(original_wandb_run_id, "Original wandb run ID should not be None") - self.assertIsNotNone(resumed_wandb_run_id, "Resumed wandb run ID should not be None") - self.assertEqual( - len(original_wandb_run_id), len(resumed_wandb_run_id), "Run IDs should have the same format/length" - ) - - def test_checkpointer_saves_wandb_run_id(self): - """Test that both PTDCheckpointer and AccelerateCheckpointer save wandb run_id to enable resumption.""" - with tempfile.TemporaryDirectory() as log_dir: - # Create tracker - tracker = WandbTracker("checkpoint-test", log_dir=log_dir, config={}) - run_id = tracker.get_wandb_run_id() - - # Test PTDCheckpointer - mock_parallel_backend = Mock() - mock_parallel_backend.tracker = tracker - - # Create proper mock for schedulers - mock_schedulers = Mock() - mock_schedulers.get_lr_scheduler_state.return_value = {} - - ptd_checkpointer = PTDCheckpointer( - dataloader=Mock(), - model_parts=[Mock()], - optimizers=Mock(), - schedulers=mock_schedulers, - states={}, - checkpointing_steps=1, - checkpointing_limit=1, - output_dir=log_dir, - enable=True, - _parallel_backend=mock_parallel_backend, - ) - - # Simulate the wandb_run_id being saved during checkpoint - if ptd_checkpointer._parallel_backend.tracker: - ptd_checkpointer.states["wandb_run_id"] = ptd_checkpointer._parallel_backend.tracker.get_wandb_run_id() - - # Test retrieval from PTDCheckpointer - saved_run_id = ptd_checkpointer.get_wandb_run_id_from_checkpoint() - self.assertEqual(saved_run_id, run_id) - - # Test AccelerateCheckpointer - from accelerate import Accelerator - - mock_accelerator = Mock(spec=Accelerator) - - accelerate_checkpointer = AccelerateCheckpointer( - accelerator=mock_accelerator, - states={}, - checkpointing_steps=1, - checkpointing_limit=1, - output_dir=log_dir, - enable=True, - ) - - # Simulate the wandb_run_id being saved during checkpoint - accelerate_checkpointer.states["wandb_run_id"] = run_id - - # Test retrieval from AccelerateCheckpointer - saved_run_id = accelerate_checkpointer.get_wandb_run_id_from_checkpoint() - self.assertEqual(saved_run_id, run_id) - - tracker.finish() From 0ee553d72b535e14b77b396037d4d259a9bc0276 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 14 Jul 2025 15:22:37 +0300 Subject: [PATCH 10/24] Adds test for WandB session resumption from checkpoint Introduces a new integration test to verify that the WandB session is correctly resumed when training continues from a saved checkpoint. This ensures that experiment tracking data is consolidated into a single WandB run across multiple training sessions, rather than creating a new run upon each resumption. --- tests/test_trackers.py | 87 +++++++++++++++++++++++++++++++++++++++ tests/trainer/__init__.py | 1 + 2 files changed, 88 insertions(+) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 46ae5502..090b7958 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -5,6 +5,9 @@ import unittest from diffusers.utils.testing_utils import CaptureLogger +from .models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa +from tests.trainer import SFTTrainerFastTestsMixin +from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger from finetrainers.trackers import WandbTracker @@ -25,3 +28,87 @@ def test_wandb_logdir(self): self.assertTrue(pathlib.Path(tempdir).exists()) self.assertTrue("WandB logging enabled" in cap_log.out) + + +class SFTTrainerLoRAWandbResumeTests(SFTTrainerFastTestsMixin, unittest.TestCase): + model_specification_cls = DummyCogView4ModelSpecification + + def get_args(self) -> BaseArgs: + args = self.get_base_args() + args.checkpointing_steps = 5 + args.parallel_backend = "accelerate" + args.training_type = TrainingType.LORA + args.rank = 4 + args.lora_alpha = 4 + args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + return args + + def test_wandb_session_resumption_with_checkpoint(self): + """ + Test the core issue: wandb session should be continued when resuming from checkpoint. + + Steps: + 1. Start training for 6 steps (with checkpointing every 5 steps) + 2. Verify checkpoint is created at step 5 + 3. Resume training from checkpoint at step 5 for additional steps + 4. Verify that the same wandb session ID is maintained + """ + logger = logging.getLogger("finetrainers") + + with CaptureLogger(logger) as cap_log: + # Phase 1: Initial training run (6 steps, checkpoint at step 5) + args_phase1 = self.get_args() + args_phase1.train_steps = 6 # Train for 6 steps (will checkpoint at step 5) + + model_specification_1 = self.model_specification_cls() + trainer_phase1 = SFTTrainer(args_phase1, model_specification_1) + trainer_phase1.run() + + # Verify checkpoint was created at step 5 + checkpoint_dir = pathlib.Path(self.tmpdir.name) / "finetrainers_step_5" + self.assertTrue(checkpoint_dir.exists(), f"Checkpoint should exist at {checkpoint_dir}") + + # Extract the wandb run ID from the first training run + # This should be stored in the checkpoint + original_wandb_run_id = trainer_phase1.checkpointer.get_wandb_run_id_from_checkpoint() + self.assertIsNotNone(original_wandb_run_id, "WandB run ID should be saved in checkpoint") + + # Clean up the first trainer + del trainer_phase1 + + # Phase 2: Resume training from checkpoint + args_phase2 = self.get_args() + args_phase2.resume_from_checkpoint = "finetrainers_step_5" # Resume from step 5 checkpoint + + model_specification_2 = self.model_specification_cls() + trainer_phase2 = SFTTrainer(args_phase2, model_specification_2) + trainer_phase2.run() + + # Verify that the resumed training uses the same wandb run ID + resumed_wandb_run_id = None + if hasattr(trainer_phase2.state.parallel_backend, 'tracker') and trainer_phase2.state.parallel_backend.tracker: + resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() + + self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID") + self.assertEqual( + original_wandb_run_id, + resumed_wandb_run_id, + f"WandB run ID should be the same after resumption. " + f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}" + ) + + # Verify that training continued from the correct step + final_step = trainer_phase2.state.train_state.step + self.assertGreaterEqual(final_step, 10, "Training should have reached at least step 10") + + # Clean up the second trainer + del trainer_phase2 + + # Verify logging contains resumption messages + log_output = cap_log.out + self.assertIn("WandB logging enabled", log_output) + + # If resumption happened correctly, we should see a resumption message + # The exact message depends on the WandB implementation, but we can check for key indicators + if "Resuming WandB run with ID:" in log_output: + self.assertIn(f"Resuming WandB run with ID: {original_wandb_run_id}", log_output) diff --git a/tests/trainer/__init__.py b/tests/trainer/__init__.py index e69de29b..5ce37eec 100644 --- a/tests/trainer/__init__.py +++ b/tests/trainer/__init__.py @@ -0,0 +1 @@ +from .test_sft_trainer import SFTTrainerFastTestsMixin From 365e175e2fca99cbebd33804d3795e2cd1b1dc5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 14 Jul 2025 15:22:58 +0300 Subject: [PATCH 11/24] style --- tests/test_trackers.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 090b7958..3833f438 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -5,11 +5,12 @@ import unittest from diffusers.utils.testing_utils import CaptureLogger -from .models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa -from tests.trainer import SFTTrainerFastTestsMixin -from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger +from finetrainers import BaseArgs, SFTTrainer, TrainingType from finetrainers.trackers import WandbTracker +from tests.trainer import SFTTrainerFastTestsMixin + +from .models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa os.environ["WANDB_MODE"] = "offline" @@ -86,7 +87,10 @@ def test_wandb_session_resumption_with_checkpoint(self): # Verify that the resumed training uses the same wandb run ID resumed_wandb_run_id = None - if hasattr(trainer_phase2.state.parallel_backend, 'tracker') and trainer_phase2.state.parallel_backend.tracker: + if ( + hasattr(trainer_phase2.state.parallel_backend, "tracker") + and trainer_phase2.state.parallel_backend.tracker + ): resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID") @@ -94,7 +98,7 @@ def test_wandb_session_resumption_with_checkpoint(self): original_wandb_run_id, resumed_wandb_run_id, f"WandB run ID should be the same after resumption. " - f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}" + f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}", ) # Verify that training continued from the correct step From 5ab2bb9cbaadef70b62a953ee7c31e8ebe590b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 14 Jul 2025 18:40:20 +0300 Subject: [PATCH 12/24] feat: Add WandB run ID tracking in AccelerateCheckpointer in init --- finetrainers/parallel/accelerate.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py index 8d10adcf..d150d1ba 100644 --- a/finetrainers/parallel/accelerate.py +++ b/finetrainers/parallel/accelerate.py @@ -269,6 +269,14 @@ def __init__( **kwargs, ) -> None: self.accelerator = accelerator + if ( + self._parallel_backend + and hasattr(self._parallel_backend, "tracker") + and self._parallel_backend.tracker + ): + wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() + if wandb_run_id: + states["wandb_run_id"] = wandb_run_id self.states = states self._parallel_backend = _parallel_backend From b07872dae25ee257391f82d34b41686bd923cf8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 14 Jul 2025 19:27:29 +0300 Subject: [PATCH 13/24] fix: Correctly assign _parallel_backend in AccelerateCheckpointer initialization --- finetrainers/parallel/accelerate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py index d150d1ba..1677cf84 100644 --- a/finetrainers/parallel/accelerate.py +++ b/finetrainers/parallel/accelerate.py @@ -269,6 +269,8 @@ def __init__( **kwargs, ) -> None: self.accelerator = accelerator + self._parallel_backend = _parallel_backend + if ( self._parallel_backend and hasattr(self._parallel_backend, "tracker") @@ -278,7 +280,6 @@ def __init__( if wandb_run_id: states["wandb_run_id"] = wandb_run_id self.states = states - self._parallel_backend = _parallel_backend self.checkpointing_steps = checkpointing_steps self.checkpointing_limit = checkpointing_limit From 8b9c285b758c5ec5ce7dac5278d32048e24cd556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 14 Jul 2025 19:29:53 +0300 Subject: [PATCH 14/24] fix: Add sleep and process group cleanup to prevent test failures --- tests/test_trackers.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 3833f438..0ff3d9e2 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -9,6 +9,7 @@ from finetrainers import BaseArgs, SFTTrainer, TrainingType from finetrainers.trackers import WandbTracker from tests.trainer import SFTTrainerFastTestsMixin +import time, pytest, torch from .models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa @@ -16,6 +17,13 @@ os.environ["WANDB_MODE"] = "offline" os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO" +@pytest.fixture(autouse=True) +def slow_down_tests(): + yield + # Sleep between each test so that process groups are cleaned and resources are released. + # Not doing so seems to randomly trigger some test failures, which wouldn't fail if run individually. + # !!!Look into this in future!!! + time.sleep(5) class WandbFastTests(unittest.TestCase): def test_wandb_logdir(self): @@ -76,7 +84,11 @@ def test_wandb_session_resumption_with_checkpoint(self): # Clean up the first trainer del trainer_phase1 - + # For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually + # make sure to destroy it here. + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + time.sleep(3) # Phase 2: Resume training from checkpoint args_phase2 = self.get_args() args_phase2.resume_from_checkpoint = "finetrainers_step_5" # Resume from step 5 checkpoint From 4f624283e4129035a1a07195e6711d658f069801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 14 Jul 2025 19:30:14 +0300 Subject: [PATCH 15/24] style --- finetrainers/parallel/accelerate.py | 6 +----- tests/test_trackers.py | 6 +++++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py index 1677cf84..6a4a743f 100644 --- a/finetrainers/parallel/accelerate.py +++ b/finetrainers/parallel/accelerate.py @@ -271,11 +271,7 @@ def __init__( self.accelerator = accelerator self._parallel_backend = _parallel_backend - if ( - self._parallel_backend - and hasattr(self._parallel_backend, "tracker") - and self._parallel_backend.tracker - ): + if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker: wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() if wandb_run_id: states["wandb_run_id"] = wandb_run_id diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 0ff3d9e2..1b0c76ef 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -2,14 +2,16 @@ import os import pathlib import tempfile +import time import unittest +import pytest +import torch from diffusers.utils.testing_utils import CaptureLogger from finetrainers import BaseArgs, SFTTrainer, TrainingType from finetrainers.trackers import WandbTracker from tests.trainer import SFTTrainerFastTestsMixin -import time, pytest, torch from .models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa @@ -17,6 +19,7 @@ os.environ["WANDB_MODE"] = "offline" os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO" + @pytest.fixture(autouse=True) def slow_down_tests(): yield @@ -25,6 +28,7 @@ def slow_down_tests(): # !!!Look into this in future!!! time.sleep(5) + class WandbFastTests(unittest.TestCase): def test_wandb_logdir(self): logger = logging.getLogger("finetrainers") From dde8129030f059fc5da2c10437e5232c851251d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 15:03:15 +0300 Subject: [PATCH 16/24] fix: Ensure wandb run ID is saved correctly in PTDCheckpointer state --- finetrainers/parallel/ptd.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/finetrainers/parallel/ptd.py b/finetrainers/parallel/ptd.py index c2695c47..45878f1a 100644 --- a/finetrainers/parallel/ptd.py +++ b/finetrainers/parallel/ptd.py @@ -321,7 +321,12 @@ def __init__( } ) self.states.update(schedulers.get_lr_scheduler_state()) - + if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker: + wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() + if wandb_run_id: + self.states["wandb_run_id"] = wandb_run_id + else: + self.states["wandb_run_id"] = None self.checkpointing_steps = checkpointing_steps self.checkpointing_limit = checkpointing_limit self.output_dir = pathlib.Path(output_dir) @@ -335,12 +340,6 @@ def save(self, step: int = -1, force: bool = False, *, _device: torch.device, _i if not self._should_checkpoint(step, force): return None - # Save wandb run ID if available - if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker: - wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() - if wandb_run_id: - self.states["wandb_run_id"] = wandb_run_id - checkpoint_dir = self._get_checkpoint_dir(step) begin_time = time.monotonic() torch.distributed.checkpoint.save(self.states, checkpoint_id=checkpoint_dir.as_posix()) From 4a58dbe876ad45116d2ca40f438d7887ee5923b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 15:03:22 +0300 Subject: [PATCH 17/24] fix: Update parallel_backend to 'ptd' and correct resume_from_checkpoint argument type in SFTTrainerLoRAWandbResumeTests --- tests/test_trackers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 1b0c76ef..199e91a4 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -49,7 +49,7 @@ class SFTTrainerLoRAWandbResumeTests(SFTTrainerFastTestsMixin, unittest.TestCase def get_args(self) -> BaseArgs: args = self.get_base_args() args.checkpointing_steps = 5 - args.parallel_backend = "accelerate" + args.parallel_backend = "ptd" args.training_type = TrainingType.LORA args.rank = 4 args.lora_alpha = 4 @@ -95,7 +95,7 @@ def test_wandb_session_resumption_with_checkpoint(self): time.sleep(3) # Phase 2: Resume training from checkpoint args_phase2 = self.get_args() - args_phase2.resume_from_checkpoint = "finetrainers_step_5" # Resume from step 5 checkpoint + args_phase2.resume_from_checkpoint = 5 # Resume from step 5 checkpoint model_specification_2 = self.model_specification_cls() trainer_phase2 = SFTTrainer(args_phase2, model_specification_2) From 2b6fb2193e6a4fc8b269499f5212874e0fd8c4e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 15:05:13 +0300 Subject: [PATCH 18/24] fix: Ensure wandb_run_id is set to None when not available in AccelerateCheckpointer --- finetrainers/parallel/accelerate.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py index 6a4a743f..654e50c0 100644 --- a/finetrainers/parallel/accelerate.py +++ b/finetrainers/parallel/accelerate.py @@ -275,6 +275,8 @@ def __init__( wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() if wandb_run_id: states["wandb_run_id"] = wandb_run_id + else: + states["wandb_run_id"] = None self.states = states self.checkpointing_steps = checkpointing_steps @@ -294,17 +296,7 @@ def save_model_hook(models, weights, output_dir: str) -> None: _callback_fn(weights[0]) - states_to_save = dict(self.states) - if ( - self._parallel_backend - and hasattr(self._parallel_backend, "tracker") - and self._parallel_backend.tracker - ): - wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id() - if wandb_run_id: - states_to_save["wandb_run_id"] = wandb_run_id - - torch.save(states_to_save, os.path.join(output_dir, "states.pt")) + torch.save(self.states, os.path.join(output_dir, "states.pt")) def load_model_hook(models, input_dir) -> None: self.states = torch.load(os.path.join(input_dir, "states.pt")) From b4f73d2ad6666623cc95113ba8c48c48b24acb3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 15:17:06 +0300 Subject: [PATCH 19/24] style --- tests/test_trackers.py | 111 +++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 66 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 199e91a4..36cb5170 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -66,69 +66,48 @@ def test_wandb_session_resumption_with_checkpoint(self): 3. Resume training from checkpoint at step 5 for additional steps 4. Verify that the same wandb session ID is maintained """ - logger = logging.getLogger("finetrainers") - - with CaptureLogger(logger) as cap_log: - # Phase 1: Initial training run (6 steps, checkpoint at step 5) - args_phase1 = self.get_args() - args_phase1.train_steps = 6 # Train for 6 steps (will checkpoint at step 5) - - model_specification_1 = self.model_specification_cls() - trainer_phase1 = SFTTrainer(args_phase1, model_specification_1) - trainer_phase1.run() - - # Verify checkpoint was created at step 5 - checkpoint_dir = pathlib.Path(self.tmpdir.name) / "finetrainers_step_5" - self.assertTrue(checkpoint_dir.exists(), f"Checkpoint should exist at {checkpoint_dir}") - - # Extract the wandb run ID from the first training run - # This should be stored in the checkpoint - original_wandb_run_id = trainer_phase1.checkpointer.get_wandb_run_id_from_checkpoint() - self.assertIsNotNone(original_wandb_run_id, "WandB run ID should be saved in checkpoint") - - # Clean up the first trainer - del trainer_phase1 - # For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually - # make sure to destroy it here. - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - time.sleep(3) - # Phase 2: Resume training from checkpoint - args_phase2 = self.get_args() - args_phase2.resume_from_checkpoint = 5 # Resume from step 5 checkpoint - - model_specification_2 = self.model_specification_cls() - trainer_phase2 = SFTTrainer(args_phase2, model_specification_2) - trainer_phase2.run() - - # Verify that the resumed training uses the same wandb run ID - resumed_wandb_run_id = None - if ( - hasattr(trainer_phase2.state.parallel_backend, "tracker") - and trainer_phase2.state.parallel_backend.tracker - ): - resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() - - self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID") - self.assertEqual( - original_wandb_run_id, - resumed_wandb_run_id, - f"WandB run ID should be the same after resumption. " - f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}", - ) - - # Verify that training continued from the correct step - final_step = trainer_phase2.state.train_state.step - self.assertGreaterEqual(final_step, 10, "Training should have reached at least step 10") - - # Clean up the second trainer - del trainer_phase2 - - # Verify logging contains resumption messages - log_output = cap_log.out - self.assertIn("WandB logging enabled", log_output) - - # If resumption happened correctly, we should see a resumption message - # The exact message depends on the WandB implementation, but we can check for key indicators - if "Resuming WandB run with ID:" in log_output: - self.assertIn(f"Resuming WandB run with ID: {original_wandb_run_id}", log_output) + # Phase 1: Initial training run (6 steps, checkpoint at step 5) + args_phase1 = self.get_args() + args_phase1.train_steps = 6 # Train for 6 steps (will checkpoint at step 5) + + model_specification_1 = self.model_specification_cls() + trainer_phase1 = SFTTrainer(args_phase1, model_specification_1) + trainer_phase1.run() + + # Verify checkpoint was created at step 5 + checkpoint_dir = pathlib.Path(self.tmpdir.name) / "finetrainers_step_5" + self.assertTrue(checkpoint_dir.exists(), f"Checkpoint should exist at {checkpoint_dir}") + + # Extract the wandb run ID from the first training run + # This should be stored in the checkpoint + original_wandb_run_id = trainer_phase1.checkpointer.get_wandb_run_id_from_checkpoint() + self.assertIsNotNone(original_wandb_run_id, "WandB run ID should be saved in checkpoint") + + del trainer_phase1 + # For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually + # make sure to destroy it here. + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + time.sleep(3) + # Phase 2: Resume training from the checkpoint + args_phase2 = self.get_args() + args_phase2.resume_from_checkpoint = 5 + + model_specification_2 = self.model_specification_cls() + trainer_phase2 = SFTTrainer(args_phase2, model_specification_2) + trainer_phase2.run() + + # Verify that the resumed training uses the same wandb run ID + resumed_wandb_run_id = None + if hasattr(trainer_phase2.state.parallel_backend, "tracker") and trainer_phase2.state.parallel_backend.tracker: + resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() + + self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID") + self.assertEqual( + original_wandb_run_id, + resumed_wandb_run_id, + f"WandB run ID should be the same after resumption. " + f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}", + ) + + del trainer_phase2 From ea2b8adc19c5072626672a12b874a53bfca7d67d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 15:59:31 +0300 Subject: [PATCH 20/24] fix: Refactor wandb session resumption test to iterate over parallel backends --- tests/test_trackers.py | 93 +++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 46 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 36cb5170..23f3db00 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -49,7 +49,6 @@ class SFTTrainerLoRAWandbResumeTests(SFTTrainerFastTestsMixin, unittest.TestCase def get_args(self) -> BaseArgs: args = self.get_base_args() args.checkpointing_steps = 5 - args.parallel_backend = "ptd" args.training_type = TrainingType.LORA args.rank = 4 args.lora_alpha = 4 @@ -66,48 +65,50 @@ def test_wandb_session_resumption_with_checkpoint(self): 3. Resume training from checkpoint at step 5 for additional steps 4. Verify that the same wandb session ID is maintained """ - # Phase 1: Initial training run (6 steps, checkpoint at step 5) - args_phase1 = self.get_args() - args_phase1.train_steps = 6 # Train for 6 steps (will checkpoint at step 5) - - model_specification_1 = self.model_specification_cls() - trainer_phase1 = SFTTrainer(args_phase1, model_specification_1) - trainer_phase1.run() - - # Verify checkpoint was created at step 5 - checkpoint_dir = pathlib.Path(self.tmpdir.name) / "finetrainers_step_5" - self.assertTrue(checkpoint_dir.exists(), f"Checkpoint should exist at {checkpoint_dir}") - - # Extract the wandb run ID from the first training run - # This should be stored in the checkpoint - original_wandb_run_id = trainer_phase1.checkpointer.get_wandb_run_id_from_checkpoint() - self.assertIsNotNone(original_wandb_run_id, "WandB run ID should be saved in checkpoint") - - del trainer_phase1 - # For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually - # make sure to destroy it here. - if torch.distributed.is_initialized(): - torch.distributed.destroy_process_group() - time.sleep(3) - # Phase 2: Resume training from the checkpoint - args_phase2 = self.get_args() - args_phase2.resume_from_checkpoint = 5 - - model_specification_2 = self.model_specification_cls() - trainer_phase2 = SFTTrainer(args_phase2, model_specification_2) - trainer_phase2.run() - - # Verify that the resumed training uses the same wandb run ID - resumed_wandb_run_id = None - if hasattr(trainer_phase2.state.parallel_backend, "tracker") and trainer_phase2.state.parallel_backend.tracker: - resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() - - self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID") - self.assertEqual( - original_wandb_run_id, - resumed_wandb_run_id, - f"WandB run ID should be the same after resumption. " - f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}", - ) - - del trainer_phase2 + for parallel_backend in ("ptd", "accelerate"): + # Phase 1: Initial training run (6 steps, checkpoint at step 5) + args_phase1 = self.get_args() + args_phase1.parallel_backend = parallel_backend + args_phase1.train_steps = 6 # Train for 6 steps (will checkpoint at step 5) + + model_specification_1 = self.model_specification_cls() + trainer_phase1 = SFTTrainer(args_phase1, model_specification_1) + trainer_phase1.run() + + # Verify checkpoint was created at step 5 + checkpoint_dir = pathlib.Path(self.tmpdir.name) / "finetrainers_step_5" + self.assertTrue(checkpoint_dir.exists(), f"Checkpoint should exist at {checkpoint_dir}") + + # Extract the wandb run ID from the first training run + # This should be stored in the checkpoint + original_wandb_run_id = trainer_phase1.checkpointer.get_wandb_run_id_from_checkpoint() + self.assertIsNotNone(original_wandb_run_id, "WandB run ID should be saved in checkpoint") + + del trainer_phase1 + # Reinitialize process group for resumed training + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend="nccl") # or 'gloo' for CPU + + # Phase 2: Resume training from the checkpoint + args_phase2 = self.get_args() + args_phase2.parallel_backend = parallel_backend + args_phase2.resume_from_checkpoint = 5 + + model_specification_2 = self.model_specification_cls() + trainer_phase2 = SFTTrainer(args_phase2, model_specification_2) + trainer_phase2.run() + + # Verify that the resumed training uses the same wandb run ID + resumed_wandb_run_id = None + if hasattr(trainer_phase2.state.parallel_backend, "tracker") and trainer_phase2.state.parallel_backend.tracker: + resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() + + self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID") + self.assertEqual( + original_wandb_run_id, + resumed_wandb_run_id, + f"WandB run ID should be the same after resumption. " + f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}", + ) + + del trainer_phase2 From e056d07d2049f00d8f8c77f997e12b05b993c6cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 15:59:38 +0300 Subject: [PATCH 21/24] fix: Update load_model_hook to include weights_only parameter for state loading --- finetrainers/parallel/accelerate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py index 654e50c0..8a3003e1 100644 --- a/finetrainers/parallel/accelerate.py +++ b/finetrainers/parallel/accelerate.py @@ -299,7 +299,7 @@ def save_model_hook(models, weights, output_dir: str) -> None: torch.save(self.states, os.path.join(output_dir, "states.pt")) def load_model_hook(models, input_dir) -> None: - self.states = torch.load(os.path.join(input_dir, "states.pt")) + self.states = torch.load(os.path.join(input_dir, "states.pt"), weights_only=False) self.accelerator.register_save_state_pre_hook(save_model_hook) self.accelerator.register_load_state_pre_hook(load_model_hook) From c4981ce4724521da2951d8f1b4cf6763c07ff306 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 16:00:56 +0300 Subject: [PATCH 22/24] fix: Simplify retrieval of resumed wandb run ID in SFTTrainerLoRAWandbResumeTests --- tests/test_trackers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 23f3db00..9be52d30 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -99,9 +99,7 @@ def test_wandb_session_resumption_with_checkpoint(self): trainer_phase2.run() # Verify that the resumed training uses the same wandb run ID - resumed_wandb_run_id = None - if hasattr(trainer_phase2.state.parallel_backend, "tracker") and trainer_phase2.state.parallel_backend.tracker: - resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() + resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id() self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID") self.assertEqual( From d8e02f774c98da7bbd3dcfea4596f2fb96cdf5f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 16:02:14 +0300 Subject: [PATCH 23/24] fix: Remove unnecessary pytest fixture and sleep to improve test performance --- tests/test_trackers.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 9be52d30..629146d4 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -2,10 +2,8 @@ import os import pathlib import tempfile -import time import unittest -import pytest import torch from diffusers.utils.testing_utils import CaptureLogger @@ -20,15 +18,6 @@ os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO" -@pytest.fixture(autouse=True) -def slow_down_tests(): - yield - # Sleep between each test so that process groups are cleaned and resources are released. - # Not doing so seems to randomly trigger some test failures, which wouldn't fail if run individually. - # !!!Look into this in future!!! - time.sleep(5) - - class WandbFastTests(unittest.TestCase): def test_wandb_logdir(self): logger = logging.getLogger("finetrainers") From b26eeb935f3912ccf19f62fd98edaff9a0b280e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 16 Jul 2025 16:08:30 +0300 Subject: [PATCH 24/24] fix: Update process group initialization condition for resumed training --- tests/test_trackers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trackers.py b/tests/test_trackers.py index 629146d4..2694c748 100644 --- a/tests/test_trackers.py +++ b/tests/test_trackers.py @@ -75,7 +75,7 @@ def test_wandb_session_resumption_with_checkpoint(self): del trainer_phase1 # Reinitialize process group for resumed training - if not torch.distributed.is_initialized(): + if parallel_backend != "ptd" and not torch.distributed.is_initialized(): torch.distributed.init_process_group(backend="nccl") # or 'gloo' for CPU # Phase 2: Resume training from the checkpoint