15 | 15 | from abc import ABC, abstractmethod
16 | 16 | from collections.abc import Iterable
17 | 17 | from contextlib import AbstractContextManager, ExitStack
18 | | -from typing import Any, Callable, Optional, TypeVar, Union
| 18 | +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union |
19 | 19 |
20 | 20 | import torch |
21 | 21 | from torch import Tensor |
22 | 22 | from torch.nn import Module |
23 | 23 | from torch.optim import Optimizer |
24 | | -from torch.optim.lr_scheduler import _LRScheduler |
25 | 24 | from torch.utils.data import DataLoader |
26 | 25 |
27 | 26 | from lightning.fabric.accelerators import Accelerator |
34 | 33 | from lightning.fabric.utilities.init import _EmptyInit |
35 | 34 | from lightning.fabric.utilities.types import _PATH, Optimizable, ReduceOp, _Stateful |
36 | 35 |
| 36 | +if TYPE_CHECKING: |
| 37 | + from torch.optim.lr_scheduler import _LRScheduler |
| 38 | + |
37 | 39 | TBroadcast = TypeVar("TBroadcast") |
38 | 40 | TReduce = TypeVar("TReduce") |
39 | 41 |
@@ -146,8 +148,8 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractCont |
146 | 148 | return stack |
147 | 149 |
148 | 150 | def setup_module_and_optimizers( |
149 | | - self, module: Module, optimizers: list[Optimizer], scheduler: Optional[_LRScheduler] = None |
150 | | - ) -> tuple[Module, list[Optimizer], Optional[_LRScheduler]]: |
| 151 | + self, module: Module, optimizers: list[Optimizer], scheduler: Optional["_LRScheduler"] = None |
| 152 | + ) -> tuple[Module, list[Optimizer], Optional["_LRScheduler"]]: |
151 | 153 | """Set up a model and multiple optimizers together. |
152 | 154 |
153 | 155 | The returned objects are expected to be in the same order they were passed in. The default implementation will |
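Below is a minimal, self-contained sketch of the pattern this change applies: the `_LRScheduler` import is moved behind `typing.TYPE_CHECKING` so it is only seen by static type checkers, and the annotations that reference it are quoted as forward references so the name is not needed at runtime. The `setup` function here is illustrative only and is not part of the Fabric API.

```python
from typing import TYPE_CHECKING, Optional

from torch.nn import Module
from torch.optim import Optimizer

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, pyright); at runtime
    # TYPE_CHECKING is False, so the private symbol is never imported.
    from torch.optim.lr_scheduler import _LRScheduler


def setup(module: Module, optimizer: Optimizer, scheduler: Optional["_LRScheduler"] = None) -> Module:
    # The annotation is a quoted forward reference: the type checker resolves
    # it, but it is not evaluated when the function is defined, so the absence
    # of the runtime import does not raise a NameError.
    return module
```

An equivalent alternative is `from __future__ import annotations`, which makes every annotation in the module lazily evaluated and removes the need for explicit quoting.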