
Commit afa7d56

amorehead, pre-commit-ci[bot], lantiga, and Borda authored
Add learning rate scheduling support for DeepSpeedStrategy (#20320)
* Update fabric.py
* Update deepspeed.py
* Update fsdp.py
* Update strategy.py
* Update xla_fsdp.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
1 parent d9f0e44 commit afa7d56

File tree: 10 files changed (+48, -31 lines)


docs/source-fabric/api/fabric_methods.rst

Lines changed: 4 additions & 0 deletions
@@ -40,13 +40,17 @@ Moves the model and optimizer to the correct device automatically.

     model = nn.Linear(32, 64)
     optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
+    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)

     # Set up model and optimizer for accelerated training
     model, optimizer = fabric.setup(model, optimizer)

     # If you don't want Fabric to set the device
     model, optimizer = fabric.setup(model, optimizer, move_to_device=False)

+    # If you want to additionally register a learning rate scheduler with compatible strategies such as DeepSpeed
+    model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler)
+

 The setup method also prepares the model for the selected precision choice so that operations during ``forward()`` get
 cast automatically. Advanced users should read :doc:`the notes on models wrapped by Fabric <../api/wrappers>`.
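For orientation, here is a hedged end-to-end sketch of the documented pattern with a DeepSpeed-backed Fabric instance. The strategy, accelerator, and device values are illustrative (they assume `deepspeed` is installed and CUDA devices are available), and because the new `setup` signature places `scheduler` after `*optimizers`, the sketch passes it by keyword:

import torch
import torch.nn as nn
from lightning.fabric import Fabric

# Illustrative configuration: requires the `deepspeed` package and at least two CUDA devices.
fabric = Fabric(strategy="deepspeed", accelerator="cuda", devices=2)
fabric.launch()

model = nn.Linear(32, 64)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.3, total_iters=10)

# `scheduler` follows `*optimizers` in the new signature, so the keyword form is unambiguous.
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)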

docs/source-fabric/api/wrappers.rst

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ If you were to run this model in Fabric with multiple devices (DDP or FSDP), you

     # OK: Calling the model directly
     output = model(torch.randn(10))

-    # OK: Calling the model's forward (equivalent to the abvoe)
+    # OK: Calling the model's forward (equivalent to the above)
     output = model.forward(torch.randn(10))

     # ERROR: Calling another method that calls forward indirectly

docs/source-fabric/conf.py

Lines changed: 1 addition & 0 deletions
@@ -287,6 +287,7 @@
     ("py:class", "torch.distributed.fsdp.wrap.ModuleWrapPolicy"),
     ("py:class", "torch.distributed.fsdp.sharded_grad_scaler.ShardedGradScaler"),
     ("py:class", "torch.amp.grad_scaler.GradScaler"),
+    ("py:class", "torch.optim.lr_scheduler._LRScheduler"),
     # Mocked optional packages
     ("py:class", "deepspeed.*"),
     ("py:.*", "torch_xla.*"),

src/lightning/fabric/fabric.py

Lines changed: 11 additions & 4 deletions
@@ -18,6 +18,7 @@
 from functools import partial
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Optional,
@@ -74,6 +75,9 @@
     _unwrap_objects,
 )

+if TYPE_CHECKING:
+    from torch.optim.lr_scheduler import _LRScheduler
+

 def _do_nothing(*_: Any) -> None:
     pass
@@ -206,6 +210,7 @@ def setup(
         self,
         module: nn.Module,
         *optimizers: Optimizer,
+        scheduler: Optional["_LRScheduler"] = None,
         move_to_device: bool = True,
         _reapply_compile: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
@@ -214,6 +219,7 @@ def setup(
         Args:
             module: A :class:`torch.nn.Module` to set up
             *optimizers: The optimizer(s) to set up (no optimizers is also possible)
+            scheduler: The learning rate scheduler to set up (no learning rate scheduler is also possible)
             move_to_device: If set ``True`` (default), moves the model to the correct device. Set this to ``False``
                 and alternatively use :meth:`to_device` manually.
             _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
@@ -222,7 +228,8 @@ def setup(
                 FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.

         Returns:
-            The tuple containing wrapped module and the optimizers, in the same order they were passed in.
+            The tuple containing wrapped module, optimizers, and an optional learning rate scheduler,
+            in the same order they were passed in.

         """
         self._validate_setup(module, optimizers)
@@ -236,8 +243,8 @@ def setup(

         # Let accelerator/plugin wrap and connect the models and optimizers
         if optimizers:
-            module, optimizers = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
-                module, list(optimizers)
+            module, optimizers, scheduler = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
+                module, list(optimizers), scheduler
             )
         else:
             module = self._strategy.setup_module(module)
@@ -266,7 +273,7 @@ def setup(

         if optimizers:
             # join both types in a tuple for API convenience
-            return (module, *optimizers)
+            return (module, *optimizers, scheduler) if scheduler is not None else (module, *optimizers)
         return module

     def setup_module(
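In short, the return shape of `setup` now depends on whether a scheduler was passed. A minimal sketch of the three call shapes (shown as alternatives, not a sequence to run; `fabric`, `model`, `optimizer`, and `scheduler` are assumed to be defined as in the docs example above):

# No scheduler: the familiar 2-tuple, unchanged.
model, optimizer = fabric.setup(model, optimizer)

# Scheduler passed (keyword-only, since it follows *optimizers): a 3-tuple comes back.
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)

# No optimizers: the bare module, as before.
model = fabric.setup(model)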

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 15 additions & 16 deletions
@@ -44,6 +44,7 @@

 if TYPE_CHECKING:
     from deepspeed import DeepSpeedEngine
+    from torch.optim.lr_scheduler import _LRScheduler

 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 _DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1")
@@ -316,25 +317,24 @@ def model(self) -> "DeepSpeedEngine":

     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple["DeepSpeedEngine", list[Optimizer]]:
-        """Set up a model and multiple optimizers together.
-
-        Currently, only a single optimizer is supported.
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple["DeepSpeedEngine", list[Optimizer], Any]:
+        """Set up a model and multiple optimizers together, along with an optional learning rate scheduler. Currently,
+        only a single optimizer is supported.

         Return:
-            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
-            deepspeed optimizer.
+            The model wrapped into a :class:`deepspeed.DeepSpeedEngine`, a list with a single
+            deepspeed optimizer, and an optional learning rate scheduler.

         """
         if len(optimizers) != 1:
             raise ValueError(
                 f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
             )

-        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
+        self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
         self._set_deepspeed_activation_checkpointing()
-        return self._deepspeed_engine, [optimizer]
+        return self._deepspeed_engine, [optimizer], scheduler

     @override
     def setup_module(self, module: Module) -> "DeepSpeedEngine":
@@ -343,7 +343,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
         For training, see :meth:`setup_module_and_optimizers`.

         """
-        self._deepspeed_engine, _ = self._initialize_engine(module)
+        self._deepspeed_engine, _, _ = self._initialize_engine(module)
         return self._deepspeed_engine

     @override
@@ -596,10 +596,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )

     def _initialize_engine(
-        self,
-        model: Module,
-        optimizer: Optional[Optimizer] = None,
-    ) -> tuple["DeepSpeedEngine", Optimizer]:
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.

         This calls ``deepspeed.initialize`` internally.
@@ -608,15 +606,16 @@ def _initialize_engine(
         import deepspeed

         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
-        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
             args=argparse.Namespace(device_rank=self.root_device.index),
             config=self.config,
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
+            lr_scheduler=scheduler,
             dist_init_required=False,
         )
-        return deepspeed_engine, deepspeed_optimizer
+        return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler

     @override
     def setup_environment(self) -> None:
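The substance of this change is that the fourth element returned by `deepspeed.initialize` is no longer discarded. A minimal sketch of that call outside the strategy, assuming `deepspeed` is installed and its distributed backend can initialize in the current environment; the config dict is a placeholder, not the strategy's real config:

import torch
import deepspeed  # assumption: the `deepspeed` package is available

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# `deepspeed.initialize` returns (engine, optimizer, dataloader, lr_scheduler). The commit forwards
# the user scheduler via `lr_scheduler=` and keeps the fourth return value so the strategy can hand
# it back to Fabric instead of dropping it.
engine, ds_optimizer, _, ds_scheduler = deepspeed.initialize(
    config={"train_batch_size": 1},  # placeholder config
    model=model,
    model_parameters=filter(lambda p: p.requires_grad, model.parameters()),
    optimizer=optimizer,
    lr_scheduler=scheduler,
    dist_init_required=False,
)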

src/lightning/fabric/strategies/fsdp.py

Lines changed: 4 additions & 3 deletions
@@ -71,6 +71,7 @@
     from torch.distributed.device_mesh import DeviceMesh
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
     from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+    from torch.optim.lr_scheduler import _LRScheduler

     _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
     _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
@@ -261,8 +262,8 @@ def setup_environment(self) -> None:

     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple[Module, list[Optimizer]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
         """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
         module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
         use_orig_params = self._fsdp_kwargs.get("use_orig_params")
@@ -274,7 +275,7 @@ def setup_module_and_optimizers(
                 " call `setup_optimizer`."
             )
         module = self.setup_module(module)
-        return module, optimizers
+        return module, optimizers, scheduler

     @override
     def setup_module(self, module: Module) -> Module:

src/lightning/fabric/strategies/strategy.py

Lines changed: 7 additions & 4 deletions
@@ -15,7 +15,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from contextlib import AbstractContextManager, ExitStack
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

 import torch
 from torch import Tensor
@@ -33,6 +33,9 @@
 from lightning.fabric.utilities.init import _EmptyInit
 from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp, _Stateful

+if TYPE_CHECKING:
+    from torch.optim.lr_scheduler import _LRScheduler
+
 TBroadcast = TypeVar("TBroadcast")
 TReduce = TypeVar("TReduce")

@@ -145,8 +148,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
         return stack

     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple[Module, list[Optimizer]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
         """Set up a model and multiple optimizers together.

         The returned objects are expected to be in the same order they were passed in. The default implementation will
@@ -155,7 +158,7 @@ def setup_module_and_optimizers(
         """
         module = self.setup_module(module)
         optimizers = [self.setup_optimizer(optimizer) for optimizer in optimizers]
-        return module, optimizers
+        return module, optimizers, scheduler

     def setup_module(self, module: Module) -> Module:
         """Performs setup for the model, e.g., by wrapping it by another class."""

src/lightning/fabric/strategies/xla_fsdp.py

Lines changed: 3 additions & 2 deletions
@@ -44,6 +44,7 @@
 from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp

 if TYPE_CHECKING:
+    from torch.optim.lr_scheduler import _LRScheduler
     from torch_xla.distributed.parallel_loader import MpDeviceLoader

 _POLICY_SET = set[type[Module]]
@@ -196,8 +197,8 @@ def setup_environment(self) -> None:

     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple[Module, list[Optimizer]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
         """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
         raise NotImplementedError(
             f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."

tests/tests_fabric/strategies/test_deepspeed.py

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ def test_deepspeed_setup_module(init_mock):
         model=model,
         model_parameters=ANY,
         optimizer=None,
+        lr_scheduler=None,
         dist_init_required=False,
     )
tests/tests_fabric/strategies/test_model_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def test_parallelize_fn_call():
     strategy = ModelParallelStrategy(parallelize_fn=parallelize_fn)
     strategy._device_mesh = Mock()
     strategy.parallel_devices = [torch.device("cpu")]
-    model_setup, [optimizer_setup] = strategy.setup_module_and_optimizers(model, [optimizer])
+    model_setup, [optimizer_setup], _ = strategy.setup_module_and_optimizers(model, [optimizer])
     assert model_setup is parallel_model_mock
     assert optimizer_setup is optimizer
     parallelize_fn.assert_called_with(model, strategy.device_mesh)
