Merged
20 changes: 16 additions & 4 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -282,10 +282,10 @@ def __init__(
sub_group_size=sub_group_size,
)

import deepspeed

self._config_initialized = False
deepspeed.utils.logging.logger.setLevel(logging_level)
# Defer importing and configuring DeepSpeed logging until it is actually needed.
# Store the desired logging level to be applied on first use.
self._logging_level = logging_level

self.remote_device = remote_device
self.load_full_weights = load_full_weights
@@ -374,6 +374,8 @@ def module_sharded_context(self) -> AbstractContextManager:

import deepspeed

deepspeed.utils.logging.logger.setLevel(self._logging_level)

assert self._config_initialized
return deepspeed.zero.Init(
enabled=self.zero_stage_3,
@@ -601,6 +603,8 @@ def _initialize_engine(
"""
import deepspeed

deepspeed.utils.logging.logger.setLevel(self._logging_level)

model_parameters = filter(lambda p: p.requires_grad, model.parameters())
deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
args=argparse.Namespace(device_rank=self.root_device.index),
@@ -628,14 +632,20 @@ def _setup_distributed(self) -> None:
_validate_device_index_selection(self.parallel_devices)
reset_seed()
self._set_world_ranks()
self._init_deepspeed_distributed()
# Avoid initializing DeepSpeed distributed for single-process runs. This also avoids importing
# DeepSpeed in environments where it may not be fully functional (e.g., missing nvcc),
# while still allowing configuration and dataloader setup logic to run.
if self.world_size > 1:
self._init_deepspeed_distributed()
if not self._config_initialized:
self._format_config()
self._config_initialized = True

def _init_deepspeed_distributed(self) -> None:
import deepspeed

deepspeed.utils.logging.logger.setLevel(self._logging_level)

assert self.cluster_environment is not None
if platform.system() != "Windows":
# do not set env variables on windows, allow deepspeed to control setup
@@ -661,6 +671,8 @@ def _set_node_environment_variables(self) -> None:
def _set_deepspeed_activation_checkpointing(self) -> None:
import deepspeed

deepspeed.utils.logging.logger.setLevel(self._logging_level)

assert isinstance(self.config, dict)
if self.config.get("activation_checkpointing"):
checkpoint_config = self.config["activation_checkpointing"]
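For reference, the pattern applied above can be sketched in isolation roughly as follows (a simplified, hypothetical class, not the actual `DeepSpeedStrategy`): the logging level is only stored at construction time, and `deepspeed` is imported and configured on first real use.

```python
# Minimal sketch of the deferred-import pattern used above (hypothetical class).
import logging


class _LazyDeepSpeedLogging:
    def __init__(self, logging_level: int = logging.WARN) -> None:
        # No `import deepspeed` here: construction stays cheap and works even in
        # environments where DeepSpeed is not fully functional (e.g. missing nvcc).
        self._logging_level = logging_level

    def _apply_logging_level(self) -> None:
        # Deferred until a DeepSpeed API is actually needed, mirroring the calls
        # added to `module_sharded_context`, `_initialize_engine`, etc.
        import deepspeed

        deepspeed.utils.logging.logger.setLevel(self._logging_level)
```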
56 changes: 53 additions & 3 deletions src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -260,6 +260,9 @@ def __init__(
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""
# When using step/time-based checkpointing with a validation-only monitored metric,
# defer the save until validation has produced the metric
self._defer_save_until_validation: bool = False

self.kth_value: Tensor
self.dirpath: Optional[_PATH]
@@ -306,14 +309,17 @@ def on_train_batch_end(
batch_idx: int,
) -> None:
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
if self._should_skip_saving_checkpoint(trainer):
return
# Do not return early here because we may need to set deferral flags even
# if a save already happened at this global step. We'll enforce the skip
# just before actually saving below.
skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

train_time_interval = self._train_time_interval
skip_time = True
now = time.monotonic()
if train_time_interval:
# Important: allow zero timedelta as a valid interval
if train_time_interval is not None:
prev_time_check = self._last_time_checked
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
@@ -326,6 +332,42 @@
self._last_time_checked = now

monitor_candidates = self._monitor_candidates(trainer)
# If monitoring a metric that is not yet available (e.g., validation-only),
# defer saving until validation end so the metric is present.
if self.monitor is not None and self.monitor not in monitor_candidates:
# Defer both top-k and last to avoid blocking with `_last_global_step_saved`
self._defer_save_until_validation = True
return

# Even if the monitored key exists, it could be stale from a previous validation.
# If validation is scheduled to run right after this batch (e.g., last batch of epoch)
# and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
if (
self.monitor is not None
and not self._should_save_on_train_epoch_end(trainer)
and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False)
):
# Only defer if a validation loop is expected to run after this batch.
will_run_val = False
if getattr(trainer, "enable_validation", False):
num_val_batches = (
sum(trainer.num_val_batches)
if isinstance(trainer.num_val_batches, list)
else trainer.num_val_batches
)
if num_val_batches and num_val_batches > 0:
cve = trainer.check_val_every_n_epoch
if cve is None or ((trainer.current_epoch + 1) % cve == 0):
will_run_val = True

if will_run_val:
self._defer_save_until_validation = True
return

# Only proceed to save if not skipping due to trainer/callback state
if skip_due_to_state:
return

self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)

@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
"""Save a checkpoint at the end of the validation stage."""
if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
# If a step/time-triggered save was deferred due to a missing monitored metric,
# perform the save now that validation metrics are available.
if self._defer_save_until_validation:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
self._defer_save_until_validation = False
return

if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
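The scenario these changes target is step- or time-based checkpointing on a metric that only exists after validation. A minimal usage sketch (placeholder model and datamodule; assumes `val_loss` is logged only in `validation_step`):

```python
# Usage sketch for the deferral above (MyModel / MyDataModule are placeholders,
# not part of this PR; "val_loss" is assumed to be logged only during validation).
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",           # produced by the validation loop only
    every_n_train_steps=100,      # step-based trigger inside the training loop
    save_top_k=3,
    mode="min",
)

trainer = Trainer(
    max_epochs=5,
    val_check_interval=100,       # validation runs right after the triggering step
    callbacks=[checkpoint_cb],
)
# trainer.fit(MyModel(), datamodule=MyDataModule())
# With the change above, the step-triggered save is deferred to `on_validation_end`,
# so the monitored "val_loss" is fresh instead of missing or stale.
```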
229 changes: 229 additions & 0 deletions tests/tests_fabric/strategies/test_deepspeed_imports_mock.py
@@ -0,0 +1,229 @@
# Copyright The Lightning AI team.
# This test file provides CPU-only coverage for DeepSpeed lazy-import paths by mocking a minimal
# `deepspeed` module. It does not require GPUs or the real DeepSpeed package.

import sys
from types import ModuleType
from unittest.mock import Mock

import pytest

from lightning.fabric.strategies import DeepSpeedStrategy


class _FakeLogger:
def __init__(self):
self.levels = []

def setLevel(self, lvl):
self.levels.append(lvl)


class _FakeZeroInit:
def __init__(self, *args, **kwargs):
# record for assertions
self.args = args
self.kwargs = kwargs

def __enter__(self):
return self

def __exit__(self, exc_type, exc, tb):
return False


@pytest.fixture
def fake_deepspeed(monkeypatch):
"""Inject a minimal fake `deepspeed` package into sys.modules."""
ds = ModuleType("deepspeed")
# Mark as a package with a spec and path so importlib won't complain
import importlib.machinery

ds.__spec__ = importlib.machinery.ModuleSpec("deepspeed", loader=Mock(), is_package=True)
ds.__path__ = [] # type: ignore[attr-defined]

# utils.logging.logger
utils_mod = ModuleType("deepspeed.utils")
logging_mod = ModuleType("deepspeed.utils.logging")
utils_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils", loader=Mock(), is_package=True)
logging_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils.logging", loader=Mock(), is_package=False)
logger = _FakeLogger()
logging_mod.logger = logger
utils_mod.logging = logging_mod
ds.utils = utils_mod

# zero.Init
zero_mod = ModuleType("deepspeed.zero")
zero_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.zero", loader=Mock(), is_package=False)
zero_mod.Init = _FakeZeroInit
ds.zero = zero_mod

# checkpointing.configure
checkpointing_mod = ModuleType("deepspeed.checkpointing")
checkpointing_mod.__spec__ = importlib.machinery.ModuleSpec(
"deepspeed.checkpointing", loader=Mock(), is_package=False
)
recorded = {"configure_calls": []}

def _configure(**kwargs):
recorded["configure_calls"].append(kwargs)

checkpointing_mod.configure = _configure
ds.checkpointing = checkpointing_mod

# initialize
recorded["initialize_calls"] = []

def _initialize(**kwargs):
recorded["initialize_calls"].append(kwargs)
# return values: (engine, optimizer, _, scheduler)
return Mock(name="engine"), Mock(name="optimizer"), None, Mock(name="scheduler")

ds.initialize = _initialize

# init_distributed
recorded["init_distributed_calls"] = []

def _init_distributed(*args, **kwargs):
recorded["init_distributed_calls"].append((args, kwargs))

ds.init_distributed = _init_distributed

# install into sys.modules
monkeypatch.setitem(sys.modules, "deepspeed", ds)
monkeypatch.setitem(sys.modules, "deepspeed.utils", utils_mod)
monkeypatch.setitem(sys.modules, "deepspeed.utils.logging", logging_mod)
monkeypatch.setitem(sys.modules, "deepspeed.zero", zero_mod)
monkeypatch.setitem(sys.modules, "deepspeed.checkpointing", checkpointing_mod)

# Pretend deepspeed is installed by forcing availability flag to True
monkeypatch.setattr("lightning.fabric.strategies.deepspeed._DEEPSPEED_AVAILABLE", True, raising=False)

return ds, logger, recorded


def _make_strategy_with_defaults():
# Use defaults; we'll tweak attributes per test as needed
return DeepSpeedStrategy()


def _get_backend() -> str:
# simple helper used to override strategy._get_process_group_backend
return "gloo"


def test_module_sharded_context_sets_logger_and_returns_zero_init(fake_deepspeed):
ds_mod, logger, recorded = fake_deepspeed

strategy = _make_strategy_with_defaults()
# The context asserts that the config was initialized
strategy._config_initialized = True # type: ignore[attr-defined]

ctx = strategy.module_sharded_context()
assert isinstance(ctx, _FakeZeroInit)
# logger.setLevel should be called at least once
assert len(logger.levels) >= 1


def test_initialize_engine_import_and_logger_and_call(fake_deepspeed):
ds_mod, logger, recorded = fake_deepspeed

strategy = _make_strategy_with_defaults()
# root_device.index is read; use a CUDA device number even on CPU-only hosts (no allocation happens)
import torch

strategy.parallel_devices = [torch.device("cuda", 0)] # type: ignore[attr-defined]

class _Param:
requires_grad = True

model = Mock()
model.parameters.return_value = [_Param()]

engine, optimizer, scheduler = strategy._initialize_engine(model)

# assertions
assert len(logger.levels) >= 1
assert recorded["initialize_calls"], "deepspeed.initialize was not called"
call = recorded["initialize_calls"][0]
assert call["config"] == strategy.config
assert call["model"] is model
assert call["dist_init_required"] is False
# returned mocks are propagated
from unittest.mock import Mock as _M

assert isinstance(engine, _M)
assert engine._mock_name == "engine"
assert isinstance(optimizer, _M)
assert optimizer._mock_name == "optimizer"
assert isinstance(scheduler, _M)
assert scheduler._mock_name == "scheduler"


def test_init_deepspeed_distributed_calls_import_and_init(fake_deepspeed, monkeypatch):
ds_mod, logger, recorded = fake_deepspeed

strategy = _make_strategy_with_defaults()

# minimal cluster env
class _CE:
main_port = 12345
main_address = "127.0.0.1"

def global_rank(self):
return 0

def local_rank(self):
return 0

def node_rank(self):
return 0

def world_size(self):
return 1

def teardown(self):
pass

strategy.cluster_environment = _CE()
strategy._process_group_backend = "gloo" # avoid CUDA requirement
strategy._timeout = 300 # type: ignore[attr-defined]

strategy._get_process_group_backend = _get_backend # type: ignore[assignment]

# ensure non-Windows path
monkeypatch.setattr("platform.system", lambda: "Linux")

strategy._init_deepspeed_distributed()

assert len(logger.levels) >= 1
assert recorded["init_distributed_calls"], "deepspeed.init_distributed was not called"
args, kwargs = recorded["init_distributed_calls"][0]
assert args[0] == "gloo"
assert kwargs["distributed_port"] == 12345
assert "timeout" in kwargs


def test_set_deepspeed_activation_checkpointing_configured(fake_deepspeed):
ds_mod, logger, recorded = fake_deepspeed

strategy = _make_strategy_with_defaults()
# ensure config contains activation_checkpointing keys
assert isinstance(strategy.config, dict)
strategy.config.setdefault("activation_checkpointing", {})
strategy.config["activation_checkpointing"].update({
"partition_activations": True,
"contiguous_memory_optimization": False,
"cpu_checkpointing": True,
"profile": False,
})

strategy._set_deepspeed_activation_checkpointing()

assert len(logger.levels) >= 1
assert recorded["configure_calls"], "deepspeed.checkpointing.configure was not called"
cfg = recorded["configure_calls"][0]
assert cfg["partition_activations"] is True
assert cfg["contiguous_checkpointing"] is False
assert cfg["checkpoint_in_cpu"] is True
assert cfg["profile"] is False