
Commit e66dd11

Add DeepSpeed ZeRO Stage 3 MiCS support for Fabric (issue #20378, PR #20461)
1 parent ff1efa0 commit e66dd11

File tree

2 files changed: +160 −5 lines

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 15 additions & 5 deletions
@@ -373,11 +373,21 @@ def module_sharded_context(self) -> AbstractContextManager:
         import deepspeed
 
         assert self._config_initialized
-        return deepspeed.zero.Init(
-            enabled=self.zero_stage_3,
-            remote_device=self.remote_device,
-            config_dict_or_path=self.config,
-        )
+        assert self.config is not None
+
+        if 'zero_optimization' in self.config and 'mics_shard_size' in self.config['zero_optimization'] \
+                and self.config['zero_optimization']['mics_shard_size'] > 0 and self.zero_stage_3:
+            return deepspeed.zero.MiCS_Init(
+                enabled=self.zero_stage_3,
+                remote_device=self.remote_device,
+                config_dict_or_path=self.config,
+            )
+        else:
+            return deepspeed.zero.Init(
+                enabled=self.zero_stage_3,
+                remote_device=self.remote_device,
+                config_dict_or_path=self.config,
+            )
 
     @override
     def save_checkpoint(
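
For context, the choice between the two init paths is driven entirely by the strategy's DeepSpeed config: a positive zero_optimization.mics_shard_size together with ZeRO stage 3 routes module instantiation through deepspeed.zero.MiCS_Init instead of deepspeed.zero.Init. Below is a minimal usage sketch in Python based on the Fabric/DeepSpeedStrategy calls used in the tests further down; the shard size and model are illustrative and not part of this commit.

import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# Start from a plain ZeRO Stage 3 strategy, then opt into MiCS by adding the
# MiCS keys to the strategy's generated DeepSpeed config (values illustrative).
strategy = DeepSpeedStrategy(stage=3)
strategy.config["zero_optimization"]["mics_shard_size"] = 2  # > 0 selects MiCS_Init
strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False

fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy, precision="16-mixed")
fabric.launch()

# Parameters created inside init_module() are sharded by the MiCS context.
with fabric.init_module():
    model = nn.Linear(32, 4)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)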

tests/tests_fabric/strategies/test_deepspeed_integration.py

Lines changed: 145 additions & 0 deletions
@@ -414,3 +414,148 @@ def test_deepspeed_init_module_with_stages_1_2(stage, empty_init):
     zero_init_mock.assert_called_with(enabled=False, remote_device=None, config_dict_or_path=ANY)
     assert init_mock.call_count == int(not empty_init)
     assert model.layer.weight.dtype == torch.bfloat16
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_MiCS_support():
+    """Test to ensure ZeRO Stage 3 MiCS works with a parallel model."""
+    strategy = DeepSpeedStrategy(stage=3)
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 1
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=2,
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    x = model(x)
+    x = x.float()  # Ensure output is in float32 for the softmax operation
+    logits = F.softmax(x, dim=1)
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
+
+@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_MiCS_offload_param_support():
+    """Test to ensure we can use DeepSpeed ZeRO Stage 3 MiCS with parameter offload."""
+    strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu")
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 1
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=2,
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    x = model(x)
+    x = x.float()  # Ensure output is in float32 for the softmax operation
+    logits = F.softmax(x, dim=1)
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_MiCS_offload_param_optimizer_support():
+    """Test to ensure we can use DeepSpeed ZeRO Stage 3 MiCS with parameter & optimizer offload."""
+    strategy = DeepSpeedStrategy(stage=3, offload_params_device="cpu", offload_optimizer_device="cpu")
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 1
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = False
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=2,
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    x = model(x)
+    x = x.float()  # Ensure output is in float32 for the softmax operation
+    logits = F.softmax(x, dim=1)
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
+
+@RunIf(min_cuda_gpus=4, standalone=True, deepspeed=True)
+def test_deepspeed_multigpu_stage_3_hierarchical_MiCS_support():
+    """Test to ensure we can use DeepSpeed ZeRO Stage 3 MiCS with 'mics_hierarchical_params_gather' set to
+    True."""
+    strategy = DeepSpeedStrategy(stage=3)
+    strategy.config["zero_optimization"]["stage"] = 3
+    strategy.config["zero_optimization"]["mics_shard_size"] = 2
+    strategy.config["zero_optimization"]["offload_param"] = {}
+    strategy.config["zero_optimization"]["offload_optimizer"] = {}
+    strategy.config["zero_optimization"]["mics_hierarchical_params_gather"] = True
+
+    fabric = Fabric(
+        strategy=strategy,
+        accelerator="cuda",
+        devices=2,
+        precision="16-mixed",
+    )
+    fabric.launch()
+
+    def _make_block():
+        return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
+
+    with fabric.init_module():
+        model = nn.Sequential(*(_make_block() for _ in range(5)), nn.Linear(32, 3))
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+    model, optimizer = fabric.setup(model, optimizer)
+
+    x = torch.rand(2, 32, device=fabric.device)
+    y = torch.ones(x.size(0), device=x.device, dtype=torch.long)
+    x = model(x)
+    x = x.float()  # Ensure output is in float32 for the softmax operation
+    logits = F.softmax(x, dim=1)
+    loss = F.cross_entropy(logits, y)
+    fabric.backward(loss)
+    optimizer.step()
+    optimizer.zero_grad()
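
The tests above enable MiCS by mutating the strategy's default config in place. A sketch of the same settings expressed as a complete DeepSpeed config dict follows, assuming DeepSpeedStrategy also accepts such a dict via its config argument; that constructor argument is an assumption here, not something shown in this diff.

from lightning.fabric.strategies import DeepSpeedStrategy

# Assumed-equivalent form: the same MiCS keys under "zero_optimization",
# supplied as a full DeepSpeed config dict instead of in-place edits.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "mics_shard_size": 2,  # > 0 makes module_sharded_context() use MiCS_Init
        "mics_hierarchical_params_gather": False,  # exercised with True in the 4-GPU test above
    },
}
strategy = DeepSpeedStrategy(config=ds_config)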
