@@ -408,6 +408,7 @@ Here is some helpful information when setting up DeepSpeed ZeRO Stage 3 with Lightning
* Treat your GPU/CPU memory as one large pool. In some cases, you may not want to offload certain things (like activations) to provide even more space to offload model parameters
* When offloading to the CPU, make sure to bump up the batch size as GPU memory will be freed
* We also support sharded checkpointing. By passing ``save_full_weights=False`` to the ``DeepSpeedStrategy``, we'll save shards of the model which allows you to save extremely large models. However to load the model and run test/validation/predict you must use the Trainer object.
* DeepSpeed also provides `MiCS support <https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.runtime.zero.config.DeepSpeedZeroConfig.mics_shard_size>`_, which lets you control how model parameters are sharded across GPUs via ``mics_shard_size``. This can be useful on large GPU clusters where you want to reduce cross-node communication overhead; a configuration sketch follows this list.
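As a rough sketch based on the strategy changes in this PR (the 4-GPU layout and shard size of 2 are illustrative assumptions, not requirements), MiCS can be enabled by adding the ZeRO-3 MiCS keys to the strategy's config before handing it to the ``Trainer``:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.strategies import DeepSpeedStrategy

    # Start from plain ZeRO Stage 3 and add the MiCS keys to the generated config.
    strategy = DeepSpeedStrategy(stage=3)
    strategy.config["zero_optimization"]["mics_shard_size"] = 2  # shard parameters within groups of 2 GPUs
    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

    trainer = Trainer(accelerator="gpu", devices=4, strategy=strategy, precision="16-mixed")

With ``mics_shard_size`` > 0 and ZeRO stage 3, the strategy's sharded-model context uses ``deepspeed.zero.MiCS_Init`` instead of ``deepspeed.zero.Init``.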

.. _deepspeed-zero-stage-3-single-file:

24 changes: 19 additions & 5 deletions src/lightning/fabric/strategies/deepspeed.py
@@ -373,11 +373,25 @@ def module_sharded_context(self) -> AbstractContextManager:
import deepspeed

assert self._config_initialized
return deepspeed.zero.Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
)
assert self.config is not None

if (
"zero_optimization" in self.config
and "mics_shard_size" in self.config["zero_optimization"]
and self.config["zero_optimization"]["mics_shard_size"] > 0
and self.zero_stage_3
):
return deepspeed.zero.MiCS_Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
)
else:
return deepspeed.zero.Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
)

@override
def save_checkpoint(
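For reference, a minimal Fabric-side sketch that exercises the new branch above. It mirrors the integration tests added below; the tiny model and 2-GPU layout are illustrative assumptions, not part of the diff:

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# mics_shard_size > 0 together with ZeRO stage 3 routes module_sharded_context to deepspeed.zero.MiCS_Init.
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="16-mixed")
fabric.launch()

with fabric.init_module():  # enters module_sharded_context, i.e. MiCS_Init when the keys above are set
    model = torch.nn.Linear(32, 2)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)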
27 changes: 22 additions & 5 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -521,12 +521,29 @@ def model_sharded_context(self) -> Generator[None, None, None]:
import deepspeed

self._init_config_if_needed()
with deepspeed.zero.Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
):
yield
assert self.config is not None
# If 'mics_shard_size' > 0 is set in config["zero_optimization"], use deepspeed.zero.MiCS_Init() instead
# https://deepspeed.readthedocs.io/en/latest/zero3.html#mics-configurations
# NOTE: the default deepspeed 0.9.0 is not compatible with MiCS (see the version discussion below)
Contributor: What is the min version to support this?

Author: Thanks for your review. MiCS was implemented after DeepSpeed 0.9.2, and my test env uses DeepSpeed 0.16.0.

@Borda (Contributor), Sep 10, 2025: Let's create another PR bumping the dependency, and then we can land this piece:

deepspeed >=0.14.1,<=0.15.0; platform_system != "Windows" and platform_system != "Darwin" # strict
if (
"zero_optimization" in self.config
and "mics_shard_size" in self.config["zero_optimization"]
and self.config["zero_optimization"]["mics_shard_size"] > 0
and self.zero_stage_3
):
with deepspeed.zero.MiCS_Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
):
yield
else:
with deepspeed.zero.Init(
enabled=self.zero_stage_3,
remote_device=self.remote_device,
config_dict_or_path=self.config,
):
yield

def _set_deepspeed_activation_checkpointing(self) -> None:
import deepspeed
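Related to the version question in the review thread above: MiCS is reported to have landed after DeepSpeed 0.9.2, while the project currently pins deepspeed >=0.14.1,<=0.15.0. A small availability guard along these lines could let the MiCS branch fail back gracefully on older installs; this is only a sketch, not part of the diff, and the helper name is made up:

import deepspeed


def _mics_init_available() -> bool:
    # Per the review thread, MiCS was implemented after DeepSpeed 0.9.2 (the author tested with 0.16.0),
    # so probing for the symbol lets older installs fall back to the plain deepspeed.zero.Init path.
    return hasattr(deepspeed.zero, "MiCS_Init")

The condition in model_sharded_context could then include _mics_init_available() alongside the existing config checks.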
147 changes: 147 additions & 0 deletions tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -414,3 +414,150 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init):
zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY)
assert init_mock.call_count == int(not empty_init)
assert model.layer.weight.dtype == torch.bfloat16


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_support():
"""Test to ensure ZeRO Stage 3 MiCS works with a parallel model."""
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=2,
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support():
"""Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support."""
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=2,
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support():
"""Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support."""
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=2,
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()


@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support():
"""Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' =
True)."""
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 2
strategy.config["zero_optimization"]["offload_param"] = {}
strategy.config["zero_optimization"]["offload_optimizer"] = {}
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True

fabric = Fabric(
strategy=strategy,
accelerator="cuda",
devices=4,  # 2 x 2 hierarchy with mics_shard_size=2 (matches the min_cuda_gpus=4 requirement)
precision="16-mixed",
)
fabric.launch()

def _make_block():
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())

with fabric.init_module():
model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

x = torch.rand(2, 32, device=fabric.device)
y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
x = model(x)
x = x.float() # Ensure output is in float32 for softmax operation
logits = F.softmax(x, dim=1)
loss = F.cross_entropy(logits, y)
fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()
123 changes: 123 additions & 0 deletions tests/tests_pytorch/strategies/test_deepspeed.py
@@ -1279,3 +1279,126 @@ def test_deepspeed_load_checkpoint_validate_path(tmp_path):
checkpoint_path.touch()
with pytest.raises(FileNotFoundError, match=f"Try to load using this parent directory instead: {tmp_path}"):
strategy.load_checkpoint(checkpoint_path=checkpoint_path)


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_support(tmp_path):
"""Test to ensure we can use DeepSpeed with basic ZeRO Stage 3 MiCS Support."""
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=2,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1
assert trainer.strategy.config["zero_optimization"]["stage"] == 3


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support(tmp_path):
"""Test to ensure we can use DeepSpeed with ZeRO Stage param offload 3 MiCS Support \
However, in some past pratice, offload param + mics + torchrun will cause inner exception in multi-node environment. \
Probably this exception is caused by torchrun, not deepspeed. """
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=2,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1
assert trainer.strategy.config["zero_optimization"]["stage"] == 3


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support(tmp_path):
"""Test to ensure we can use DeepSpeed with ZeRO Stage param & optimizer offload 3 MiCS Support."""
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu")
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 1
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=2,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is False
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 1
assert trainer.strategy.config["zero_optimization"]["stage"] == 3


@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True)
def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support(tmp_path):
"""Test to ensure we can use DeepSpeed with ZeRO Stage 3 MiCS Support ('mics_hierarchical_params_gather' =
True)."""
model = ModelParallelBoringModel()
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["stage"] = 3
strategy.config["zero_optimization"]["mics_shard_size"] = 2
strategy.config["zero_optimization"]["offload_param"] = {}
strategy.config["zero_optimization"]["offload_optimizer"] = {}
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True
# Forming a 2 x 2 hierarchy
trainer = Trainer(
default_root_dir=tmp_path,
strategy=strategy,
accelerator="gpu",
devices=4,
fast_dev_run=True,
precision="16-mixed",
enable_progress_bar=False,
enable_model_summary=False,
)
trainer.test(model)
trainer.fit(model)

_assert_save_model_is_equal(model, tmp_path, trainer)
assert isinstance(trainer.strategy, DeepSpeedStrategy)
assert "zero_optimization" in trainer.strategy.config
assert trainer.strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] is True
assert trainer.strategy.config["zero_optimization"]["mics_shard_size"] == 2
assert trainer.strategy.config["zero_optimization"]["stage"] == 3