 
 if TYPE_CHECKING:
     from deepspeed import DeepSpeedEngine
+    from torch.optim.lr_scheduler import _LRScheduler
 
 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 _DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1")
@@ -316,25 +317,24 @@ def model(self) -> "DeepSpeedEngine":
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer]
-    ) -> tuple["DeepSpeedEngine", list[Optimizer]]:
-        """Set up a model and multiple optimizers together.
-
-        Currently, only a single optimizer is supported.
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple["DeepSpeedEngine", list[Optimizer], Any]:
+        """Set up a model and multiple optimizers together, along with an optional
+        learning rate scheduler. Currently, only a single optimizer is supported.
 
         Return:
-            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
-            deepspeed optimizer.
+            The model wrapped into a :class:`deepspeed.DeepSpeedEngine`, a list with a single
+            deepspeed optimizer, and an optional learning rate scheduler.
 
         """
         if len(optimizers) != 1:
             raise ValueError(
                 f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead."
             )
 
-        self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0])
+        self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
         self._set_deepspeed_activation_checkpointing()
-        return self._deepspeed_engine, [optimizer]
+        return self._deepspeed_engine, [optimizer], scheduler
 
     @override
     def setup_module(self, module: Module) -> "DeepSpeedEngine":
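For context, a minimal usage sketch of the new `setup_module_and_optimizers` signature. The toy module, optimizer, and the already-configured `strategy` object are illustrative assumptions, not part of this diff:

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Assumed setup: a toy module with a plain PyTorch optimizer/scheduler pair.
model = torch.nn.Linear(32, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# `strategy` is a hypothetical, already-set-up DeepSpeedStrategy instance.
# The scheduler is threaded through to deepspeed.initialize; whatever
# DeepSpeed hands back (or None) comes back as the third element.
engine, [ds_optimizer], ds_scheduler = strategy.setup_module_and_optimizers(
    model, [optimizer], scheduler
)
```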
@@ -343,7 +343,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
         For training, see :meth:`setup_module_and_optimizers`.
 
         """
-        self._deepspeed_engine, _ = self._initialize_engine(module)
+        self._deepspeed_engine, _, _ = self._initialize_engine(module)
         return self._deepspeed_engine
 
     @override
@@ -596,10 +596,8 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )
 
     def _initialize_engine(
-        self,
-        model: Module,
-        optimizer: Optional[Optimizer] = None,
-    ) -> tuple["DeepSpeedEngine", Optimizer]:
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
         This calls ``deepspeed.initialize`` internally.
@@ -608,15 +606,16 @@ def _initialize_engine(
         import deepspeed
 
         model_parameters = filter(lambda p: p.requires_grad, model.parameters())
-        deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
+        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
             args=argparse.Namespace(device_rank=self.root_device.index),
             config=self.config,
             model=model,
             model_parameters=model_parameters,
             optimizer=optimizer,
+            lr_scheduler=scheduler,
             dist_init_required=False,
         )
-        return deepspeed_engine, deepspeed_optimizer
+        return deepspeed_engine, deepspeed_optimizer, deepspeed_scheduler
 
     @override
     def setup_environment(self) -> None:
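Reviewer note: ``deepspeed.initialize`` returns a 4-tuple ``(engine, optimizer, training_dataloader, lr_scheduler)``, and this change captures the last element instead of discarding it. A standalone sketch of that contract, assuming a DeepSpeed-capable single-GPU environment and a placeholder config:

```python
import argparse

import deepspeed
import torch

# Placeholder model/optimizer/scheduler; the config is a minimal stub.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# deepspeed.initialize returns (engine, optimizer, training_dataloader, lr_scheduler).
engine, ds_optimizer, _, ds_scheduler = deepspeed.initialize(
    args=argparse.Namespace(),
    config={"train_batch_size": 1},
    model=model,
    model_parameters=model.parameters(),
    optimizer=optimizer,
    lr_scheduler=scheduler,
    dist_init_required=False,
)

# ds_scheduler is the scheduler DeepSpeed steps during training; it is None
# when no scheduler is supplied here or via the DeepSpeed config.
```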