24 commits
6cd935d - feat: Add support for resuming W&B runs from checkpoints (tolgacangoz, Jul 5, 2025)
1985fb6 - style (tolgacangoz, Jul 5, 2025)
44c57c3 - Adds tests for resuming wandb runs from checkpoints (tolgacangoz, Jul 5, 2025)
00a1976 - refactor: Simplify wandb resumption tests by removing redundant code (tolgacangoz, Jul 12, 2025)
3ccbba7 - Test: Add tests to reproduce wandb run resumption failure (tolgacangoz, Jul 12, 2025)
e750039 - refactor: Remove redundant tests for SFTTrainer and ControlTrainer wa… (tolgacangoz, Jul 12, 2025)
51a7bfa - refactor: Replace SimpleModel with Mock in PTDCheckpointer tests for … (tolgacangoz, Jul 12, 2025)
7f20a1e - refactor: Clean up whitespace and improve readability in wandb resump… (tolgacangoz, Jul 12, 2025)
00c2290 - down (tolgacangoz, Jul 14, 2025)
0ee553d - Adds test for WandB session resumption from checkpoint (tolgacangoz, Jul 14, 2025)
365e175 - style (tolgacangoz, Jul 14, 2025)
5ab2bb9 - feat: Add WandB run ID tracking in AccelerateCheckpointer in init (tolgacangoz, Jul 14, 2025)
b07872d - fix: Correctly assign _parallel_backend in AccelerateCheckpointer ini… (tolgacangoz, Jul 14, 2025)
8b9c285 - fix: Add sleep and process group cleanup to prevent test failures (tolgacangoz, Jul 14, 2025)
4f62428 - style (tolgacangoz, Jul 14, 2025)
dde8129 - fix: Ensure wandb run ID is saved correctly in PTDCheckpointer state (tolgacangoz, Jul 16, 2025)
4a58dbe - fix: Update parallel_backend to 'ptd' and correct resume_from_checkpo… (tolgacangoz, Jul 16, 2025)
2b6fb21 - fix: Ensure wandb_run_id is set to None when not available in Acceler… (tolgacangoz, Jul 16, 2025)
b4f73d2 - style (tolgacangoz, Jul 16, 2025)
ea2b8ad - fix: Refactor wandb session resumption test to iterate over parallel … (tolgacangoz, Jul 16, 2025)
e056d07 - fix: Update load_model_hook to include weights_only parameter for sta… (tolgacangoz, Jul 16, 2025)
c4981ce - fix: Simplify retrieval of resumed wandb run ID in SFTTrainerLoRAWand… (tolgacangoz, Jul 16, 2025)
d8e02f7 - fix: Remove unnecessary pytest fixture and sleep to improve test perf… (tolgacangoz, Jul 16, 2025)
b26eeb9 - fix: Update process group initialization condition for resumed training (tolgacangoz, Jul 16, 2025)
17 changes: 16 additions & 1 deletion finetrainers/parallel/accelerate.py
@@ -191,6 +191,7 @@ def _get_mesh():
return _get_mesh()

def get_checkpointer(self, *args, **kwargs):
kwargs["_parallel_backend"] = self
return AccelerateCheckpointer(self._accelerator, *args, **kwargs)

@property
@@ -263,10 +264,19 @@ def __init__(
enable: bool = True,
_callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
_prefix: str = "finetrainers_step",
_parallel_backend: Optional["BaseParallelBackend"] = None,
*args,
**kwargs,
) -> None:
self.accelerator = accelerator
self._parallel_backend = _parallel_backend

if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker:
wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id()
if wandb_run_id:
states["wandb_run_id"] = wandb_run_id
else:
states["wandb_run_id"] = None
self.states = states

self.checkpointing_steps = checkpointing_steps
@@ -285,10 +295,11 @@ def save_model_hook(models, weights, output_dir: str) -> None:
assert len(models) == 1

_callback_fn(weights[0])

torch.save(self.states, os.path.join(output_dir, "states.pt"))

def load_model_hook(models, input_dir) -> None:
- self.states = torch.load(os.path.join(input_dir, "states.pt"))
+ self.states = torch.load(os.path.join(input_dir, "states.pt"), weights_only=False)

self.accelerator.register_save_state_pre_hook(save_model_hook)
self.accelerator.register_load_state_pre_hook(load_model_hook)
@@ -334,6 +345,10 @@ def load(self, step: int = -1) -> bool:

return True

def get_wandb_run_id_from_checkpoint(self) -> Optional[str]:
"""Get the wandb run ID from the loaded checkpoint states."""
return self.states.get("wandb_run_id", None)

def _should_checkpoint(self, step: int, force: bool) -> bool:
if not self.enable:
return False
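To make the hook pair above concrete, here is a minimal sketch of the states.pt round-trip (the step number, run ID, and directory are hypothetical; the real states dict also carries train, dataloader, and scheduler state):

import os
import tempfile

import torch

# Hypothetical values standing in for the checkpointer's states dict.
states = {"finetrainers_step": 5, "wandb_run_id": "abc123xy"}
output_dir = tempfile.mkdtemp()

# What save_model_hook does: persist the states next to the model weights.
torch.save(states, os.path.join(output_dir, "states.pt"))

# What load_model_hook does: weights_only=True (the default since PyTorch 2.6)
# restricts loading to tensors and a few primitive containers, so the hook
# passes weights_only=False to get the full training-state dict back.
restored = torch.load(os.path.join(output_dir, "states.pt"), weights_only=False)
assert restored["wandb_run_id"] == "abc123xy"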
9 changes: 7 additions & 2 deletions finetrainers/parallel/base.py
@@ -45,10 +45,15 @@ def get_checkpointer(self, *args, **kwargs) -> None:
raise NotImplementedError("Method `get_checkpointer` must be implemented by subclass.")

def initialize_trackers(
- self, trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+ self,
+ trackers: List[str],
+ experiment_name: str,
+ config: Dict[str, Any],
+ log_dir: str,
+ resume_run_id: Optional[str] = None,
) -> TrackerType:
if self.is_main_process:
- self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir)
+ self.tracker = initialize_trackers(trackers, experiment_name, config, log_dir, resume_run_id)
else:
self.tracker = DummyTracker()

15 changes: 13 additions & 2 deletions finetrainers/parallel/ptd.py
@@ -209,7 +209,7 @@ def _get_mesh():
return _get_mesh()

def get_checkpointer(self, *args, **kwargs):
- return PTDCheckpointer(*args, **kwargs)
+ return PTDCheckpointer(*args, **kwargs, _parallel_backend=self)

@property
def world_size(self):
@@ -309,7 +309,9 @@ def __init__(
enable: bool = True,
_callback_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None,
_prefix: str = "finetrainers_step",
_parallel_backend: Optional["BaseParallelBackend"] = None,
) -> None:
self._parallel_backend = _parallel_backend
self.states = states
self.states.update(
{
@@ -319,7 +321,12 @@
}
)
self.states.update(schedulers.get_lr_scheduler_state())

if self._parallel_backend and hasattr(self._parallel_backend, "tracker") and self._parallel_backend.tracker:
wandb_run_id = self._parallel_backend.tracker.get_wandb_run_id()
if wandb_run_id:
self.states["wandb_run_id"] = wandb_run_id
else:
self.states["wandb_run_id"] = None
self.checkpointing_steps = checkpointing_steps
self.checkpointing_limit = checkpointing_limit
self.output_dir = pathlib.Path(output_dir)
@@ -385,6 +392,10 @@ def load(self, step: int = -1) -> bool:

return True

def get_wandb_run_id_from_checkpoint(self) -> Optional[str]:
"""Get the wandb run ID from the loaded checkpoint states."""
return self.states.get("wandb_run_id", None)

def _should_checkpoint(self, step: int, force: bool) -> bool:
if not self.enable:
return False
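The guard is identical in both checkpointers: at construction time, the checkpointer follows the _parallel_backend back-reference to the live tracker and records its run ID in the states dict. A stripped-down sketch with stand-in classes (every name here is a stand-in, not the real finetrainers API):

from typing import Optional

class StubTracker:
    def get_wandb_run_id(self) -> Optional[str]:
        return "abc123xy"  # hypothetical live run ID

class StubBackend:
    tracker = StubTracker()

class StubCheckpointer:
    def __init__(self, states: dict, _parallel_backend: Optional[StubBackend] = None):
        self._parallel_backend = _parallel_backend
        # Mirror of the guard above: only query the tracker when one is
        # attached; otherwise store None explicitly so the key always exists.
        if self._parallel_backend and getattr(self._parallel_backend, "tracker", None):
            states["wandb_run_id"] = self._parallel_backend.tracker.get_wandb_run_id()
        else:
            states["wandb_run_id"] = None
        self.states = states

ckpt = StubCheckpointer({}, _parallel_backend=StubBackend())
assert ckpt.states["wandb_run_id"] == "abc123xy"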
43 changes: 39 additions & 4 deletions finetrainers/trackers.py
@@ -37,6 +37,10 @@ def log(self, metrics: Dict[str, Any], step: int) -> None:
def finish(self) -> None:
pass

def get_wandb_run_id(self) -> Optional[str]:
r"""Get the wandb run ID if available."""
return None


class DummyTracker(BaseTracker):
def __init__(self):
@@ -52,7 +56,13 @@ def finish(self) -> None:
class WandbTracker(BaseTracker):
r"""Logger implementation for Weights & Biases."""

- def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
+ def __init__(
+     self,
+     experiment_name: str,
+     log_dir: str,
+     config: Optional[Dict[str, Any]] = None,
+     resume_run_id: Optional[str] = None,
+ ) -> None:
super().__init__()

import wandb
@@ -62,7 +72,11 @@ def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str
# WandB does not create a directory if it does not exist and instead starts using the system temp directory.
pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)

- self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
+ if resume_run_id is not None:
+     logger.info(f"Resuming WandB run with ID: {resume_run_id}")
+     self.run = wandb.init(project=experiment_name, dir=log_dir, config=config, id=resume_run_id, resume="must")
+ else:
+     self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
logger.info("WandB logging enabled")

def log(self, metrics: Dict[str, Any], step: int) -> None:
@@ -73,6 +87,15 @@ def log(self, metrics: Dict[str, Any], step: int) -> None:
def finish(self) -> None:
self.run.finish()

@property
def run_id(self) -> Optional[str]:
"""Return the current wandb run ID for checkpointing purposes."""
return self.run.id if self.run is not None else None

def get_wandb_run_id(self) -> Optional[str]:
"""Return the wandb run ID if this tracker supports it."""
return self.run_id


class SequentialTracker(BaseTracker):
r"""Sequential tracker that logs to multiple trackers in sequence."""
@@ -106,6 +129,14 @@ def finish(self) -> None:
for tracker in self.trackers:
tracker.finish()

def get_wandb_run_id(self) -> Optional[str]:
"""Return the wandb run ID from the first WandB tracker in the sequence."""
for tracker in self.trackers:
run_id = tracker.get_wandb_run_id()
if run_id is not None:
return run_id
return None


class Trackers(str, Enum):
r"""Enum for supported trackers."""
@@ -118,7 +149,11 @@


def initialize_trackers(
- trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
+ trackers: List[str],
+ experiment_name: str,
+ config: Dict[str, Any],
+ log_dir: str,
+ resume_run_id: Optional[str] = None,
) -> Union[BaseTracker, SequentialTracker]:
r"""Initialize loggers based on the provided configuration."""

Expand All @@ -135,7 +170,7 @@ def initialize_trackers(
if tracker_name == Trackers.NONE:
tracker = BaseTracker()
elif tracker_name == Trackers.WANDB:
- tracker = WandbTracker(experiment_name, log_dir, config)
+ tracker = WandbTracker(experiment_name, log_dir, config, resume_run_id=resume_run_id)
tracker_instances.append(tracker)

tracker = SequentialTracker(tracker_instances)
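For reference, a minimal standalone sketch of the wandb resume semantics WandbTracker relies on, assuming a configured wandb environment (the project name is hypothetical): passing id together with resume="must" reattaches to an existing run and raises if the ID is unknown, so metrics logged after resumption continue the same run history.

import wandb

# First session: wandb mints a run ID; finetrainers stores it in the checkpoint.
run = wandb.init(project="finetrainers-experiment")
saved_id = run.id
run.log({"train/loss": 0.2}, step=5)
run.finish()

# Resumed session: reuse the saved ID; resume="must" fails fast on a bad ID.
run = wandb.init(project="finetrainers-experiment", id=saved_id, resume="must")
run.log({"train/loss": 0.1}, step=6)  # steps continue where the run left off
run.finish()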
10 changes: 7 additions & 3 deletions finetrainers/trainer/base.py
@@ -1,7 +1,7 @@
import contextlib
import functools
import os
- from typing import Callable, List, Tuple
+ from typing import Callable, List, Optional, Tuple

import torch
import torch.backends
@@ -116,12 +116,16 @@ def _init_logging(self) -> None:
logging.set_dependency_log_level(self.args.verbose, self.state.parallel_backend.is_local_main_process)
logger.info("Initialized FineTrainers")

- def _init_trackers(self) -> None:
+ def _init_trackers(self, resume_run_id: Optional[str] = None) -> None:
# TODO(aryan): handle multiple trackers
trackers = [self.args.report_to]
experiment_name = self.args.tracker_name or "finetrainers-experiment"
self.state.parallel_backend.initialize_trackers(
- trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
+ trackers,
+ experiment_name=experiment_name,
+ config=self._get_training_info(),
+ log_dir=self.args.logging_dir,
+ resume_run_id=resume_run_id,
)

def _init_config_options(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion finetrainers/trainer/control_trainer/trainer.py
@@ -262,7 +262,8 @@ def _prepare_for_training(self) -> None:

# 3. Initialize trackers, directories and repositories
self._init_logging()
- self._init_trackers()
+ if self.args.resume_from_checkpoint is None:
+     self._init_trackers()
self._init_directories_and_repositories()

def _prepare_dataset(self) -> None:
@@ -373,6 +374,9 @@ def save_model_hook(state_dict: Dict[str, Any]) -> None:
resume_from_checkpoint = -1
if resume_from_checkpoint is not None:
self.checkpointer.load(resume_from_checkpoint)
# Extract wandb run ID from loaded checkpoint and initialize trackers with it
wandb_run_id = self.checkpointer.get_wandb_run_id_from_checkpoint()
self._init_trackers(resume_run_id=wandb_run_id)

def _train(self) -> None:
logger.info("Starting training")
6 changes: 5 additions & 1 deletion finetrainers/trainer/sft_trainer/trainer.py
@@ -230,7 +230,8 @@ def _prepare_for_training(self) -> None:

# 3. Initialize trackers, directories and repositories
self._init_logging()
- self._init_trackers()
+ if self.args.resume_from_checkpoint is None:
+     self._init_trackers()
self._init_directories_and_repositories()

def _prepare_dataset(self) -> None:
@@ -324,6 +325,9 @@ def save_model_hook(state_dict: Dict[str, Any]) -> None:
resume_from_checkpoint = -1
if resume_from_checkpoint is not None:
self.checkpointer.load(resume_from_checkpoint)
# Extract wandb run ID from loaded checkpoint and initialize trackers with it
wandb_run_id = self.checkpointer.get_wandb_run_id_from_checkpoint()
self._init_trackers(resume_run_id=wandb_run_id)

def _train(self) -> None:
logger.info("Starting training")
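Both trainers now share the same resumption ordering. A condensed sketch of that control flow (the helper and the trainer argument are hypothetical; the method names come from this diff):

def init_or_resume_trackers(trainer, resume_from_checkpoint):
    # Hypothetical helper mirroring _prepare_for_training plus checkpoint load.
    if resume_from_checkpoint is None:
        trainer._init_trackers()  # fresh run: wandb mints a new run ID
    else:
        # Load the checkpoint first, recover the stored run ID, then initialize
        # trackers so WandbTracker reattaches instead of starting a new run.
        trainer.checkpointer.load(resume_from_checkpoint)
        run_id = trainer.checkpointer.get_wandb_run_id_from_checkpoint()
        trainer._init_trackers(resume_run_id=run_id)  # None falls back to a new run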
75 changes: 75 additions & 0 deletions tests/test_trackers.py
@@ -4,12 +4,18 @@
import tempfile
import unittest

import torch
from diffusers.utils.testing_utils import CaptureLogger

from finetrainers import BaseArgs, SFTTrainer, TrainingType
from finetrainers.trackers import WandbTracker
from tests.trainer import SFTTrainerFastTestsMixin

from .models.cogview4.base_specification import DummyCogView4ModelSpecification # noqa


os.environ["WANDB_MODE"] = "offline"
os.environ["FINETRAINERS_LOG_LEVEL"] = "INFO"


class WandbFastTests(unittest.TestCase):
@@ -24,3 +30,72 @@ def test_wandb_logdir(self):
self.assertTrue(pathlib.Path(tempdir).exists())

self.assertTrue("WandB logging enabled" in cap_log.out)


class SFTTrainerLoRAWandbResumeTests(SFTTrainerFastTestsMixin, unittest.TestCase):
model_specification_cls = DummyCogView4ModelSpecification

def get_args(self) -> BaseArgs:
args = self.get_base_args()
args.checkpointing_steps = 5
args.training_type = TrainingType.LORA
args.rank = 4
args.lora_alpha = 4
args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
return args

def test_wandb_session_resumption_with_checkpoint(self):
"""
Test the core issue: wandb session should be continued when resuming from checkpoint.

Steps:
1. Start training for 6 steps (with checkpointing every 5 steps)
2. Verify checkpoint is created at step 5
3. Resume training from checkpoint at step 5 for additional steps
4. Verify that the same wandb session ID is maintained
"""
for parallel_backend in ("ptd", "accelerate"):
# Phase 1: Initial training run (6 steps, checkpoint at step 5)
args_phase1 = self.get_args()
args_phase1.parallel_backend = parallel_backend
args_phase1.train_steps = 6 # Train for 6 steps (will checkpoint at step 5)

model_specification_1 = self.model_specification_cls()
trainer_phase1 = SFTTrainer(args_phase1, model_specification_1)
trainer_phase1.run()

# Verify checkpoint was created at step 5
checkpoint_dir = pathlib.Path(self.tmpdir.name) / "finetrainers_step_5"
self.assertTrue(checkpoint_dir.exists(), f"Checkpoint should exist at {checkpoint_dir}")

# Extract the wandb run ID from the first training run
# This should be stored in the checkpoint
original_wandb_run_id = trainer_phase1.checkpointer.get_wandb_run_id_from_checkpoint()
self.assertIsNotNone(original_wandb_run_id, "WandB run ID should be saved in checkpoint")

del trainer_phase1
# Reinitialize process group for resumed training
if parallel_backend != "ptd" and not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl") # or 'gloo' for CPU

# Phase 2: Resume training from the checkpoint
args_phase2 = self.get_args()
args_phase2.parallel_backend = parallel_backend
args_phase2.resume_from_checkpoint = 5

model_specification_2 = self.model_specification_cls()
trainer_phase2 = SFTTrainer(args_phase2, model_specification_2)
trainer_phase2.run()

# Verify that the resumed training uses the same wandb run ID
resumed_wandb_run_id = trainer_phase2.state.parallel_backend.tracker.get_wandb_run_id()

self.assertIsNotNone(resumed_wandb_run_id, "Resumed training should have a wandb run ID")
self.assertEqual(
original_wandb_run_id,
resumed_wandb_run_id,
f"WandB run ID should be the same after resumption. "
f"Original: {original_wandb_run_id}, Resumed: {resumed_wandb_run_id}",
)

del trainer_phase2
1 change: 1 addition & 0 deletions tests/trainer/__init__.py
@@ -0,0 +1 @@
from .test_sft_trainer import SFTTrainerFastTestsMixin