
Commit 5071a04

SkafteNicki and Borda authored

Add support for DeepSpeed's exclude_frozen_parameters (#21060)

* add to deepspeed strategies
* add testing
* changelog
* GLOO_SOCKET_IFNAME

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>

1 parent 4824cc1 commit 5071a04

File tree: 7 files changed, +95 −4 lines

.github/workflows/ci-tests-pytorch.yml
src/lightning/fabric/CHANGELOG.md
src/lightning/fabric/strategies/deepspeed.py
src/lightning/pytorch/CHANGELOG.md
src/lightning/pytorch/strategies/deepspeed.py
tests/tests_fabric/strategies/test_deepspeed.py
tests/tests_pytorch/strategies/test_deepspeed.py
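
For orientation, here is a minimal sketch of how the new flag is meant to be used from the Trainer side. The module below is illustrative and not part of this commit, and it assumes a CUDA machine with deepspeed installed:

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy


class FineTuneModel(LightningModule):
    """Hypothetical module: frozen backbone, trainable head."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.head = torch.nn.Linear(32, 2)
        for p in self.backbone.parameters():
            p.requires_grad = False  # frozen; skipped in checkpoints when the flag is set

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.head(self.backbone(x)), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.head.parameters(), lr=0.1)


# The new keyword: frozen parameters are omitted from saved checkpoints,
# which can shrink them considerably for partially-frozen models.
trainer = Trainer(
    strategy=DeepSpeedStrategy(exclude_frozen_parameters=True),
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
)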

.github/workflows/ci-tests-pytorch.yml

Lines changed: 4 additions & 0 deletions
@@ -83,6 +83,10 @@ jobs:
     - name: basic setup
       run: pip install -q -r .actions/requirements.txt

+    - name: Append Env. vars for Linux
+      if: ${{ runner.os == 'Linux' }}
+      run: echo "GLOO_SOCKET_IFNAME=eth0" >> $GITHUB_ENV
+
     - name: Set min. dependencies
       if: ${{ matrix.config.requires == 'oldest' }}
       run: |
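
For context, GLOO_SOCKET_IFNAME is a standard PyTorch/Gloo environment variable that pins rendezvous to a specific network interface; the CI step above forces eth0 on Linux runners so the Gloo backend does not bind to a virtual interface. A rough single-process illustration (the address and port are illustrative):

import os

# Pin Gloo to a specific interface before initializing the process group;
# "eth0" matches what the CI step above exports. (Illustrative value.)
os.environ.setdefault("GLOO_SOCKET_IFNAME", "eth0")

import torch.distributed as dist

dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)
dist.destroy_process_group()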

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))
+
+
 -


src/lightning/fabric/strategies/deepspeed.py

Lines changed: 7 additions & 1 deletion
@@ -100,6 +100,7 @@ def __init__(
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
+        exclude_frozen_parameters: bool = False,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -229,6 +230,8 @@ def __init__(
                 when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
                 per worker.

+            exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
+
         """
         if not _DEEPSPEED_AVAILABLE:
             raise ImportError(
@@ -289,6 +292,7 @@ def __init__(

         self.remote_device = remote_device
         self.load_full_weights = load_full_weights
+        self.exclude_frozen_parameters = exclude_frozen_parameters

         # default FP16 parameters.
         self.loss_scale = loss_scale
@@ -444,7 +448,9 @@ def save_checkpoint(
         # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
         state = self._convert_stateful_objects_in_state(state, filter={})
         # use deepspeed's internal checkpointing function to handle partitioned weights across processes
-        engine.save_checkpoint(path, client_state=state, tag="checkpoint")
+        engine.save_checkpoint(
+            path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters
+        )

     @override
     def load_checkpoint(
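
The Fabric-side equivalent, sketched under the assumption of a CUDA machine with deepspeed installed (the model and optimizer here are illustrative, not part of the diff):

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

fabric = Fabric(
    strategy=DeepSpeedStrategy(stage=3, exclude_frozen_parameters=True),
    accelerator="cuda",
    devices=1,
    precision="16-mixed",
)
fabric.launch()

model = torch.nn.Linear(32, 2)
model.weight.requires_grad = False  # frozen; dropped from the checkpoint by the new flag
optimizer = torch.optim.SGD([model.bias], lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

# save_checkpoint (patched above) forwards exclude_frozen_parameters to engine.save_checkpoint
fabric.save("checkpoints/run1", {"model": model, "optimizer": optimizer})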

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -16,8 +16,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Torch-Tensorrt integration with `LightningModule` ([#20808](https://github.com/Lightning-AI/pytorch-lightning/pull/20808))


+- Added `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))
+
+
 - Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))

+
 ### Changed

 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise. ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))

src/lightning/pytorch/strategies/deepspeed.py

Lines changed: 10 additions & 1 deletion
@@ -122,6 +122,7 @@ def __init__(
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
+        exclude_frozen_parameters: bool = False,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -253,6 +254,8 @@ def __init__(
                 when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
                 per worker.

+            exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
+
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
@@ -311,6 +314,7 @@ def __init__(

         self.remote_device = remote_device
         self.load_full_weights = load_full_weights
+        self.exclude_frozen_parameters = exclude_frozen_parameters

         # default FP16 parameters.
         self.loss_scale = loss_scale
@@ -648,7 +652,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op
         # dump states as a checkpoint dictionary object
         _exclude_keys = ["state_dict", "optimizer_states"]
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
-        self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint")
+        self.deepspeed_engine.save_checkpoint(
+            filepath,
+            client_state=checkpoint,
+            tag="checkpoint",
+            exclude_frozen_parameters=self.exclude_frozen_parameters,
+        )

     @override
     def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
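
A note on what "frozen" means here: as the integration test further down suggests, parameters count as frozen when requires_grad is False, and those are exactly the parameters the new flag omits. A plain-PyTorch sketch of that partition:

import torch

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2))
for p in model[0].parameters():
    p.requires_grad = False  # these would be omitted with exclude_frozen_parameters=True

frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(frozen)     # ['0.weight', '0.bias']
print(trainable)  # ['1.weight', '1.bias']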

tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 27 additions & 2 deletions
@@ -194,15 +194,19 @@ def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
     model.modules.return_value = [model]
     strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
     # the client_state should not contain any deepspeed engine or deepspeed optimizer
-    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
+    model.save_checkpoint.assert_called_with(
+        tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
+    )

     # Model and optimizer
     optimizer = Mock()
     model = Mock(spec=DeepSpeedEngine, optimizer=optimizer)
     model.modules.return_value = [model]
     strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
     # the client_state should not contain any deepspeed engine or deepspeed optimizer
-    model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
+    model.save_checkpoint.assert_called_with(
+        tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
+    )


 @RunIf(deepspeed=True)
@@ -219,6 +223,27 @@ def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
     strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})


+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("exclude_frozen_parameters", [True, False])
+def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_parameters):
+    """Test that the DeepSpeed strategy can save checkpoints with the `exclude_frozen_parameters` argument."""
+    from deepspeed import DeepSpeedEngine
+
+    strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters)
+    assert strategy.exclude_frozen_parameters is exclude_frozen_parameters
+
+    model = Mock(spec=DeepSpeedEngine, optimizer=None)
+    model.modules.return_value = [model]
+    strategy.save_checkpoint(path="test_path", state={"model": model, "extra": "data"})
+
+    model.save_checkpoint.assert_called_with(
+        "test_path",
+        client_state={"extra": "data"},
+        tag="checkpoint",
+        exclude_frozen_parameters=exclude_frozen_parameters,
+    )
+
+
 @RunIf(deepspeed=True)
 def test_deepspeed_load_checkpoint_validate_path(tmp_path):
     """Test that we validate the checkpoint path for a DeepSpeed checkpoint and give suggestions for user error."""

tests/tests_pytorch/strategies/test_deepspeed.py

Lines changed: 40 additions & 0 deletions
@@ -562,6 +562,46 @@ def test_deepspeed_multigpu_single_file(tmp_path):
     trainer.test(model, ckpt_path=checkpoint_path)


+@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
+def test_deepspeed_strategy_exclude_frozen_parameters_integration(tmp_path):
+    """Test end-to-end integration of exclude_frozen_parameters with actual model training and checkpointing."""
+
+    class TestModelWithFrozenParams(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.frozen_layer = torch.nn.Linear(32, 32)
+
+        def configure_model(self) -> None:
+            super().configure_model()
+            # Freeze the additional layer parameters
+            for param in self.frozen_layer.parameters():
+                param.requires_grad = False
+
+        def forward(self, x):
+            x = self.frozen_layer(x)
+            return super().forward(x)
+
+    model = TestModelWithFrozenParams()
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        strategy=DeepSpeedStrategy(exclude_frozen_parameters=True),
+        accelerator="gpu",
+        devices=1,
+        fast_dev_run=True,
+        precision="16-mixed",
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model)
+    checkpoint_path = os.path.join(tmp_path, "checkpoint_exclude_frozen.ckpt")
+    trainer.save_checkpoint(checkpoint_path)
+
+    # Verify checkpoint was created
+    assert os.path.exists(checkpoint_path)
+
+
 class ModelParallelClassificationModel(LightningModule):
     def __init__(self, lr: float = 0.01, num_blocks: int = 5):
         super().__init__()
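
Since DeepSpeed checkpoints are directories of shard files rather than a single file, the integration test above only asserts that the path exists. A sketch for inspecting what was actually written (the layout shown in the comment is typical for DeepSpeed, not guaranteed):

import os

ckpt_dir = "checkpoint_exclude_frozen.ckpt"  # path used in the test above
# Typical layout: <ckpt_dir>/checkpoint/ (the tag passed to save_checkpoint)
# containing files such as mp_rank_00_model_states.pt plus optimizer shards.
for root, _, files in os.walk(ckpt_dir):
    for name in files:
        path = os.path.join(root, name)
        print(path, os.path.getsize(path))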
