
Commit 087726d

Commit message: typing

Parent: d43e1d1

File tree: 5 files changed, +20 / -15 lines


src/lightning/fabric/fabric.py

Lines changed: 5 additions & 2 deletions
@@ -18,6 +18,7 @@
 from functools import partial
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Optional,
@@ -32,7 +33,6 @@
 from lightning_utilities.core.overrides import is_overridden
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler
 
 import lightning.fabric
@@ -75,6 +75,9 @@
     _unwrap_objects,
 )
 
+if TYPE_CHECKING:
+    from torch.optim.lr_scheduler import _LRScheduler
+
 
 def _do_nothing(*_: Any) -> None:
     pass
@@ -207,7 +210,7 @@ def setup(
         self,
         module: nn.Module,
         *optimizers: Optimizer,
-        scheduler: Optional[_LRScheduler] = None,
+        scheduler: Optional["_LRScheduler"] = None,
         move_to_device: bool = True,
         _reapply_compile: bool = True,
     ) -> Any:  # no specific return because the way we want our API to look does not play well with mypy
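The fabric.py hunks above show the full pattern this commit applies: the `_LRScheduler` import moves under a `TYPE_CHECKING` guard, and every annotation that uses it is quoted. A minimal, self-contained sketch of the idiom (the `configure` function is illustrative, not taken from the diff):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Only seen by static type checkers; TYPE_CHECKING is False at runtime,
    # so the private torch name is never imported when the module loads.
    from torch.optim.lr_scheduler import _LRScheduler


def configure(scheduler: Optional["_LRScheduler"] = None) -> None:
    # The quoted annotation is a forward reference that mypy resolves through
    # the guarded import, while the interpreter sees only a string.
    if scheduler is not None:
        scheduler.step()

Because `typing.TYPE_CHECKING` is `False` at runtime, the import never executes when the module loads; only static checkers such as mypy evaluate it.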

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 3 additions & 3 deletions
@@ -27,7 +27,6 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from typing_extensions import override
 
 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
@@ -45,6 +44,7 @@
 
 if TYPE_CHECKING:
     from deepspeed import DeepSpeedEngine
+    from torch.optim.lr_scheduler import _LRScheduler
 
 _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
 _DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1")
@@ -317,7 +317,7 @@ def model(self) -> "DeepSpeedEngine":
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
     ) -> tuple["DeepSpeedEngine", list[Optimizer], Any]:
         """Set up a model and multiple optimizers together, along with an optional learning rate scheduler. Currently,
         only a single optimizer is supported.
@@ -596,7 +596,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )
 
     def _initialize_engine(
-        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None
+        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
     ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
src/lightning/fabric/strategies/fsdp.py

Lines changed: 3 additions & 3 deletions
@@ -33,7 +33,6 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from typing_extensions import TypeGuard, override
 
 from lightning.fabric.accelerators import Accelerator
@@ -72,6 +71,7 @@
     from torch.distributed.device_mesh import DeviceMesh
     from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
     from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+    from torch.optim.lr_scheduler import _LRScheduler
 
     _POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
     _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
@@ -262,8 +262,8 @@ def setup_environment(self) -> None:
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
-    ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
         """Wraps the model into a :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel`
         module and sets `use_orig_params=True` to keep the reference to the original parameters in the optimizer."""
         use_orig_params = self._fsdp_kwargs.get("use_orig_params")
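As in fsdp.py above, the quotes are what keep these signatures importable once the runtime import is gone: assuming the module does not enable `from __future__ import annotations`, annotations in a `def` are evaluated when the function is defined, so an unquoted name that only exists under `TYPE_CHECKING` would raise `NameError` at import time. A short illustrative sketch (the `setup` function here is hypothetical):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from torch.optim.lr_scheduler import _LRScheduler

# Unquoted, the annotation below would be evaluated when the function is
# defined and raise NameError, because the guarded import never runs:
#     def setup(scheduler: Optional[_LRScheduler] = None) -> None: ...
# Quoting turns it into a forward reference that only type checkers resolve.
def setup(scheduler: Optional["_LRScheduler"] = None) -> None:
    ...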

src/lightning/fabric/strategies/strategy.py

Lines changed: 6 additions & 4 deletions
@@ -15,13 +15,12 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from contextlib import AbstractContextManager, ExitStack
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
 
 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
 from lightning.fabric.accelerators import Accelerator
@@ -34,6 +33,9 @@
 from lightning.fabric.utilities.init import _EmptyInit
 from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp, _Stateful
 
+if TYPE_CHECKING:
+    from torch.optim.lr_scheduler import _LRScheduler
+
 TBroadcast = TypeVar("TBroadcast")
 TReduce = TypeVar("TReduce")
 
@@ -146,8 +148,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont
         return stack
 
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
-    ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
         """Set up a model and multiple optimizers together.
 
         The returned objects are expected to be in the same order they were passed in. The default implementation will
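Since the base class now quotes the annotation, a downstream strategy can override this hook with the same convention and still avoid the runtime import. A hypothetical subclass sketch (the class name is made up, and the other abstract hooks of `Strategy` are omitted):

from typing import TYPE_CHECKING, Optional

from torch.nn import Module
from torch.optim import Optimizer

from lightning.fabric.strategies.strategy import Strategy

if TYPE_CHECKING:
    from torch.optim.lr_scheduler import _LRScheduler


class MyStrategy(Strategy):
    # The remaining abstract methods of Strategy are omitted for brevity.
    def setup_module_and_optimizers(
        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
    ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
        # Defer to the default implementation, which returns the objects in the
        # same order they were passed in.
        return super().setup_module_and_optimizers(module, optimizers, scheduler)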

src/lightning/fabric/strategies/xla_fsdp.py

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,6 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 from typing_extensions import override
 
@@ -45,6 +44,7 @@
 from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp
 
 if TYPE_CHECKING:
+    from torch.optim.lr_scheduler import _LRScheduler
     from torch_xla.distributed.parallel_loader import MpDeviceLoader
 
 _POLICY_SET = set[type[Module]]
@@ -197,8 +197,8 @@ def setup_environment(self) -> None:
 
     @override
     def setup_module_and_optimizers(
-        self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None
-    ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]:
+        self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None
+    ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]:
        """Returns NotImplementedError since for XLAFSDP optimizer setup must happen after module setup."""
         raise NotImplementedError(
             f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
