 
 if TYPE_CHECKING:
     from deepspeed import DeepSpeedEngine
+    from torch.optim.lr_scheduler import _LRScheduler
 
 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 _DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1")
@@ -316,25 +317,24 @@ def model(self) -> "DeepSpeedEngine":
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple["DeepSpeedEngine", list[Optimizer]]:
-        """Set up a model and multiple optimizers together.
-
-        Currently, only a single optimizer is supported.
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple["DeepSpeedEngine", list[Optimizer], Any]:
+        """Set up a model and multiple optimizers together, along with an optional learning rate scheduler. Currently,
+        only a single optimizer is supported.
 
         Return:
-            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
-            deepspeed optimizer.
+            The model wrapped into a :class:`deepspeed.DeepSpeedEngine`, a list with a single
+            deepspeed optimizer, and an optional learning rate scheduler.
 
         """
         if len(optimizers) != 1:
             raise ValueError(
                 f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
             )
 
-        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
+        self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
         self._set_deepspeed_activation_checkpointing()
-        return self._deepspeed_engine, [optimizer]
+        return self._deepspeed_engine, [optimizer], scheduler
 
     @override
     def setup_module(self, module: Module) -> "DeepSpeedEngine":
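
For illustration, a minimal usage sketch of the extended `setup_module_and_optimizers` signature. The `strategy` instance and the model, optimizer, and scheduler names below are hypothetical and only show how the new third argument and return value flow through; they are not taken from this diff:

    import torch
    from torch.optim.lr_scheduler import StepLR

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=10)  # optional; may still be omitted as before

    # `strategy` is assumed to be an already-configured DeepSpeed strategy instance.
    engine, optimizers, scheduler = strategy.setup_module_and_optimizers(model, [optimizer], scheduler)
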
@@ -343,7 +343,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
         For training, see :meth:`setup_module_and_optimizers`.
 
         """
-        self._deepspeed_engine, _ = self._initialize_engine(module)
+        self._deepspeed_engine, _, _ = self._initialize_engine(module)
         return self._deepspeed_engine
 
     @override
@@ -596,10 +596,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )
 
     def _initialize_engine(
-        self,
-        model: Module,
-        optimizer: Optional[Optimizer] = None,
-    ) -> tuple["DeepSpeedEngine", Optimizer]:
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
         This calls ``deepspeed.initialize`` internally.
@@ -608,15 +606,16 @@ def _initialize_engine(
         import deepspeed
 
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
-        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
             args=argparse.Namespace(device_rank=self.root_device.index),
             config=self.config,
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
+            lr_scheduler=scheduler,
             dist_init_required=False,
         )
-        return deepspeed_engine, deepspeed_optimizer
+        return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler
 
     @override
     def setup_environment(self) -> None:
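
For context on the `_initialize_engine` change: `deepspeed.initialize` returns a 4-tuple of engine, optimizer, training dataloader, and LR scheduler, so forwarding `lr_scheduler=` and keeping the fourth return value lets the scheduler round-trip through DeepSpeed. A minimal standalone sketch with an illustrative config and model that are not part of this repository (it assumes an environment where DeepSpeed can set up its distributed state):

    import deepspeed
    import torch

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    config = {"train_micro_batch_size_per_gpu": 1}  # illustrative minimal config

    # The fourth return value is the (possibly wrapped) scheduler when `lr_scheduler=`
    # is passed, and None otherwise; this is what the strategy now propagates.
    engine, ds_optimizer, _, ds_scheduler = deepspeed.initialize(
        config=config,
        model=model,
        model_parameters=model.parameters(),
        optimizer=optimizer,
        lr_scheduler=scheduler,
    )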