Skip to content

Commit 6c1554a

Browse files
committed
test: add mock-based CPU tests for DeepSpeed strategy import paths
1 parent 59dda02 commit 6c1554a

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright The Lightning AI team.
2+
# This test file provides CPU-only coverage for DeepSpeed lazy-import paths by mocking a minimal
3+
# `deepspeed` module. It does not require GPUs or the real DeepSpeed package.
4+
5+
import sys
6+
from types import ModuleType
7+
from unittest.mock import Mock
8+
9+
import pytest
10+
11+
from lightning.fabric.strategies import DeepSpeedStrategy
12+
13+
14+
class _FakeLogger:
15+
def __init__(self):
16+
self.levels = []
17+
18+
def setLevel(self, lvl):
19+
self.levels.append(lvl)
20+
21+
22+
class _FakeZeroInit:
23+
def __init__(self, *args, **kwargs):
24+
# record for assertions
25+
self.args = args
26+
self.kwargs = kwargs
27+
28+
def __enter__(self):
29+
return self
30+
31+
def __exit__(self, exc_type, exc, tb):
32+
return False
33+
34+
35+
@pytest.fixture
36+
def fake_deepspeed(monkeypatch):
37+
"""Inject a minimal fake `deepspeed` package into sys.modules."""
38+
ds = ModuleType("deepspeed")
39+
# Mark as a package with a spec and path so importlib won't complain
40+
import importlib.machinery
41+
42+
ds.__spec__ = importlib.machinery.ModuleSpec("deepspeed", loader=Mock(), is_package=True)
43+
ds.__path__ = [] # type: ignore[attr-defined]
44+
45+
# utils.logging.logger
46+
utils_mod = ModuleType("deepspeed.utils")
47+
logging_mod = ModuleType("deepspeed.utils.logging")
48+
utils_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils", loader=Mock(), is_package=True)
49+
logging_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.utils.logging", loader=Mock(), is_package=False)
50+
logger = _FakeLogger()
51+
logging_mod.logger = logger
52+
utils_mod.logging = logging_mod
53+
ds.utils = utils_mod
54+
55+
# zero.Init
56+
zero_mod = ModuleType("deepspeed.zero")
57+
zero_mod.__spec__ = importlib.machinery.ModuleSpec("deepspeed.zero", loader=Mock(), is_package=False)
58+
zero_mod.Init = _FakeZeroInit
59+
ds.zero = zero_mod
60+
61+
# checkpointing.configure
62+
checkpointing_mod = ModuleType("deepspeed.checkpointing")
63+
checkpointing_mod.__spec__ = importlib.machinery.ModuleSpec(
64+
"deepspeed.checkpointing", loader=Mock(), is_package=False
65+
)
66+
recorded = {"configure_calls": []}
67+
68+
def _configure(**kwargs):
69+
recorded["configure_calls"].append(kwargs)
70+
71+
checkpointing_mod.configure = _configure
72+
ds.checkpointing = checkpointing_mod
73+
74+
# initialize
75+
recorded["initialize_calls"] = []
76+
77+
def _initialize(**kwargs):
78+
recorded["initialize_calls"].append(kwargs)
79+
# return values: (engine, optimizer, _, scheduler)
80+
return Mock(name="engine"), Mock(name="optimizer"), None, Mock(name="scheduler")
81+
82+
ds.initialize = _initialize
83+
84+
# init_distributed
85+
recorded["init_distributed_calls"] = []
86+
87+
def _init_distributed(*args, **kwargs):
88+
recorded["init_distributed_calls"].append((args, kwargs))
89+
90+
ds.init_distributed = _init_distributed
91+
92+
# install into sys.modules
93+
monkeypatch.setitem(sys.modules, "deepspeed", ds)
94+
monkeypatch.setitem(sys.modules, "deepspeed.utils", utils_mod)
95+
monkeypatch.setitem(sys.modules, "deepspeed.utils.logging", logging_mod)
96+
monkeypatch.setitem(sys.modules, "deepspeed.zero", zero_mod)
97+
monkeypatch.setitem(sys.modules, "deepspeed.checkpointing", checkpointing_mod)
98+
99+
# Pretend deepspeed is installed by forcing availability flag to True
100+
monkeypatch.setattr("lightning.fabric.strategies.deepspeed._DEEPSPEED_AVAILABLE", True, raising=False)
101+
102+
return ds, logger, recorded
103+
104+
105+
def _make_strategy_with_defaults():
106+
# Use defaults; we'll tweak attributes per test as needed
107+
return DeepSpeedStrategy()
108+
109+
110+
def _get_backend() -> str:
111+
# simple helper used to override strategy._get_process_group_backend
112+
return "gloo"
113+
114+
115+
def test_module_sharded_context_sets_logger_and_returns_zero_init(fake_deepspeed):
116+
ds_mod, logger, recorded = fake_deepspeed
117+
118+
strategy = _make_strategy_with_defaults()
119+
# The context asserts that the config was initialized
120+
strategy._config_initialized = True # type: ignore[attr-defined]
121+
122+
ctx = strategy.module_sharded_context()
123+
assert isinstance(ctx, _FakeZeroInit)
124+
# logger.setLevel should be called at least once
125+
assert len(logger.levels) >= 1
126+
127+
128+
def test_initialize_engine_import_and_logger_and_call(fake_deepspeed):
129+
ds_mod, logger, recorded = fake_deepspeed
130+
131+
strategy = _make_strategy_with_defaults()
132+
# root_device.index is read; use a CUDA device number even on CPU-only hosts (no allocation happens)
133+
import torch
134+
135+
strategy.parallel_devices = [torch.device("cuda", 0)] # type: ignore[attr-defined]
136+
137+
class _Param:
138+
requires_grad = True
139+
140+
model = Mock()
141+
model.parameters.return_value = [_Param()]
142+
143+
engine, optimizer, scheduler = strategy._initialize_engine(model)
144+
145+
# assertions
146+
assert len(logger.levels) >= 1
147+
assert recorded["initialize_calls"], "deepspeed.initialize was not called"
148+
call = recorded["initialize_calls"][0]
149+
assert call["config"] == strategy.config
150+
assert call["model"] is model
151+
assert call["dist_init_required"] is False
152+
# returned mocks are propagated
153+
from unittest.mock import Mock as _M
154+
155+
assert isinstance(engine, _M)
156+
assert engine._mock_name == "engine"
157+
assert isinstance(optimizer, _M)
158+
assert optimizer._mock_name == "optimizer"
159+
assert isinstance(scheduler, _M)
160+
assert scheduler._mock_name == "scheduler"
161+
162+
163+
def test_init_deepspeed_distributed_calls_import_and_init(fake_deepspeed, monkeypatch):
164+
ds_mod, logger, recorded = fake_deepspeed
165+
166+
strategy = _make_strategy_with_defaults()
167+
168+
# minimal cluster env
169+
class _CE:
170+
main_port = 12345
171+
main_address = "127.0.0.1"
172+
173+
def global_rank(self):
174+
return 0
175+
176+
def local_rank(self):
177+
return 0
178+
179+
def node_rank(self):
180+
return 0
181+
182+
def world_size(self):
183+
return 1
184+
185+
def teardown(self):
186+
pass
187+
188+
strategy.cluster_environment = _CE()
189+
strategy._process_group_backend = "gloo" # avoid CUDA requirement
190+
strategy._timeout = 300 # type: ignore[attr-defined]
191+
192+
strategy._get_process_group_backend = _get_backend # type: ignore[assignment]
193+
194+
# ensure non-Windows path
195+
monkeypatch.setattr("platform.system", lambda: "Linux")
196+
197+
strategy._init_deepspeed_distributed()
198+
199+
assert len(logger.levels) >= 1
200+
assert recorded["init_distributed_calls"], "deepspeed.init_distributed was not called"
201+
args, kwargs = recorded["init_distributed_calls"][0]
202+
assert args[0] == "gloo"
203+
assert kwargs["distributed_port"] == 12345
204+
assert "timeout" in kwargs
205+
206+
207+
def test_set_deepspeed_activation_checkpointing_configured(fake_deepspeed):
208+
ds_mod, logger, recorded = fake_deepspeed
209+
210+
strategy = _make_strategy_with_defaults()
211+
# ensure config contains activation_checkpointing keys
212+
assert isinstance(strategy.config, dict)
213+
strategy.config.setdefault("activation_checkpointing", {})
214+
strategy.config["activation_checkpointing"].update({
215+
"partition_activations": True,
216+
"contiguous_memory_optimization": False,
217+
"cpu_checkpointing": True,
218+
"profile": False,
219+
})
220+
221+
strategy._set_deepspeed_activation_checkpointing()
222+
223+
assert len(logger.levels) >= 1
224+
assert recorded["configure_calls"], "deepspeed.checkpointing.configure was not called"
225+
cfg = recorded["configure_calls"][0]
226+
assert cfg["partition_activations"] is True
227+
assert cfg["contiguous_checkpointing"] is False
228+
assert cfg["checkpoint_in_cpu"] is True
229+
assert cfg["profile"] is False

0 commit comments

Comments
 (0)