
Commit 1025875

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 67089a1 · commit 1025875

2 files changed: +34 −23 lines

src/lightning/fabric/fabric.py

Lines changed: 21 additions & 14 deletions
@@ -209,9 +209,16 @@ def run(self, *args: Any, **kwargs: Any) -> Any:

        """

-    def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[_LRScheduler] = None, move_to_device: bool = True, _reapply_compile: bool = True,) -> Any: # no specific return because the way we want our API to look does not play well with mypy
+    def setup(
+        self,
+        module: nn.Module,
+        *optimizers: Optimizer,
+        scheduler: Optional[_LRScheduler] = None,
+        move_to_device: bool = True,
+        _reapply_compile: bool = True,
+    ) -> Any: # no specific return because the way we want our API to look does not play well with mypy
        r"""Set up a model and its optimizers for accelerated training.
-
+
        Args:
            module: A :class:`torch.nn.Module` to set up
            *optimizers: The optimizer(s) to set up (no optimizers is also possible)
@@ -222,50 +229,50 @@ def setup(self, module: nn.Module, *optimizers: Optimizer, scheduler: Optional[_
                corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
                same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
                FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
-
+
        Returns:
            The tuple containing wrapped module and the optimizers, in the same order they were passed in.
-
+
        """
        self._validate_setup(module, optimizers)
        module, compile_kwargs = _unwrap_compiled(module) if _reapply_compile else (module, None)
        original_module = module
-
+
        module = self._precision.convert_module(module)
-
+
        if move_to_device:
            module = self._move_model_to_device(model=module, optimizers=list(optimizers))
-
+
        # Let accelerator/plugin wrap and connect the models and optimizers
        if optimizers:
            module, optimizers, scheduler = self._strategy.setup_module_and_optimizers(  # type: ignore[assignment]
                module, list(optimizers), scheduler
            )
        else:
            module = self._strategy.setup_module(module)
-
+
        if compile_kwargs is not None:
            module = _to_compiled(module, compile_kwargs)
        module = _FabricModule(module, self._strategy, original_module=original_module)
-
+
        # Update the _DeviceDtypeModuleMixin's device parameter
        # NOTE: for sharded strategies or manual device placement, there's no single root device
        _update_properties(
            module, device=self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device
        )
-
+
        optimizers = [_FabricOptimizer(optimizer, self._strategy, self._callbacks) for optimizer in optimizers]
-
+
        self._models_setup += 1
-
+
        if hasattr(original_module, "_fabric"):  # this is probably a LightningModule
            original_module._fabric = self
            original_module._fabric_optimizers = optimizers
            if original_module not in self._callbacks:
                self._callbacks.append(original_module)
-
+
        self.call("on_after_setup", fabric=self, module=module)
-
+
        if optimizers:
            # join both types in a tuple for API convenience
            return (module, *optimizers, scheduler)
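
For context, a minimal sketch of the reflowed setup() signature in use, assuming the scheduler-aware API shown in this diff; the model, optimizer, and scheduler instances below are illustrative and not part of the commit:

import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")
model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# setup() returns the wrapped module and optimizers in the order they were
# passed in; with this diff, the scheduler is appended to that tuple.
model, optimizer, scheduler = fabric.setup(model, optimizer, scheduler=scheduler)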

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 13 additions & 9 deletions
@@ -311,22 +311,24 @@ def model(self) -> "DeepSpeedEngine":
        return self._deepspeed_engine

    @override
-    def setup_module_and_optimizers(self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
+    def setup_module_and_optimizers(
+        self, module: Module, optimizers: List[Optimizer], scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", List[Optimizer], Optional[_LRScheduler]]:
        """Set up a model and multiple optimizers together along with an optional learning rate scheduler.
-
+
        Currently, only a single optimizer is supported.
-
+
        Return:
            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
            deepspeed optimizer.
-
+
        """
        if len(optimizers) != 1:
            raise ValueError(
                f"Currently only one optimizer is supported with DeepSpeed."
                f" Got {len(optimizers)} optimizers instead."
            )
-
+
        self._deepspeed_engine, optimizer, scheduler = self._initialize_engine(module, optimizers[0], scheduler)
        self._set_deepspeed_activation_checkpointing()
        return self._deepspeed_engine, [optimizer], scheduler
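
A small illustration of the single-optimizer constraint enforced above, assuming a DeepSpeed-capable environment; the strategy, device, and optimizer choices are illustrative:

import torch
from torch import nn
from lightning.fabric import Fabric

fabric = Fabric(strategy="deepspeed", accelerator="cuda", devices=1)
fabric.launch()
model = nn.Linear(32, 2)
opt_a = torch.optim.Adam(model.parameters())
opt_b = torch.optim.SGD(model.parameters(), lr=0.1)

# Only a single optimizer reaches _initialize_engine(); passing two would
# trigger the ValueError raised in setup_module_and_optimizers() above.
wrapped = fabric.setup(model, opt_a)
# fabric.setup(model, opt_a, opt_b)  # raises ValueError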
@@ -590,14 +592,16 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
            offload_optimizer_device="nvme",
        )

-    def _initialize_engine(self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
+    def _initialize_engine(
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None
+    ) -> Tuple["DeepSpeedEngine", Optimizer, Optional[_LRScheduler]]:
        """Initialize one model and one optimizer with an optional learning rate scheduler.
-
+
        This calls :func:`deepspeed.initialize` internally.
-
+
        """
        import deepspeed
-
+
        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
        deepspeed_engine, deepspeed_optimizer, _, deepspeed_scheduler = deepspeed.initialize(
            args=argparse.Namespace(device_rank=self.root_device.index),
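
The hunk above is truncated before the remaining keyword arguments to deepspeed.initialize(), so here is a hedged sketch of how a learning rate scheduler is typically threaded through that call; the config dict and argument values below are assumptions for illustration, not taken from this diff:

import torch
from torch import nn
import deepspeed

model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
ds_config = {"train_micro_batch_size_per_gpu": 1}

# deepspeed.initialize() returns (engine, optimizer, dataloader, lr_scheduler);
# passing lr_scheduler lets DeepSpeed step the scheduler according to its config.
engine, ds_optimizer, _, ds_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    model_parameters=filter(lambda p: p.requires_grad, model.parameters()),
    config=ds_config,
)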
