Commit 188aadc

manual fixes
1 parent 5c71611 commit 188aadc

File tree

17 files changed (+52, -56 lines changed)
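Every file in this commit gets the same treatment: typing.Union[...] and typing.Optional[...] are rewritten with the PEP 604 "|" syntax available since Python 3.10. A minimal sketch of the equivalence, using the _PATH alias from src/lightning/fabric/utilities/types.py as the example (the _PATH_OLD / _PATH_NEW names are illustrative, not part of the source):

# Hedged sketch, assuming Python 3.10+: both spellings describe the same union.
from pathlib import Path
from typing import Union, get_args

_PATH_OLD = Union[str, Path]   # spelling before this commit
_PATH_NEW = str | Path         # PEP 604 spelling after this commit

assert get_args(_PATH_OLD) == get_args(_PATH_NEW) == (str, Path)
assert isinstance("checkpoints/last.ckpt", _PATH_NEW)  # "|" unions also work with isinstance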

src/lightning/fabric/connector.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 import os
 from collections import Counter
 from collections.abc import Iterable
-from typing import Any, Union, cast, get_args
+from typing import Any, cast, get_args

 import torch

@@ -67,7 +67,7 @@
 from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
 from lightning.fabric.utilities.imports import _IS_INTERACTIVE

-_PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO]
+_PLUGIN_INPUT = Precision | ClusterEnvironment | CheckpointIO


 class _Connector:

src/lightning/fabric/plugins/precision/precision.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import AbstractContextManager, nullcontext
-from typing import Any, Literal, Union
+from typing import Any, Literal

 from torch import Tensor
 from torch.nn import Module
@@ -33,7 +33,7 @@
     "32-true",
     "64-true",
 ]
-_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS]
+_PRECISION_INPUT = _PRECISION_INPUT_INT | _PRECISION_INPUT_STR | _PRECISION_INPUT_STR_ALIAS


 class Precision:

src/lightning/fabric/strategies/fsdp.py

Lines changed: 2 additions & 2 deletions
@@ -73,8 +73,8 @@
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.optim.lr_scheduler import _LRScheduler

-_POLICY = Union[set[type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
-_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
+_POLICY = set[type[Module]] | Callable[[Module, bool, int], bool] | ModuleWrapPolicy
+_SHARDING_STRATEGY = ShardingStrategy | Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]


 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")

src/lightning/fabric/strategies/xla_fsdp.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 from contextlib import AbstractContextManager, ExitStack, nullcontext
 from functools import partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional

 import torch
 from torch import Tensor
@@ -49,7 +49,7 @@
 from torch_xla.distributed.parallel_loader import MpDeviceLoader

 _POLICY_SET = set[type[Module]]
-_POLICY = Union[_POLICY_SET, Callable[[Module, bool, int], bool]]
+_POLICY = _POLICY_SET | Callable[[Module, bool, int], bool]


 class XLAFSDPStrategy(ParallelStrategy, _Sharded):

src/lightning/fabric/utilities/types.py

Lines changed: 3 additions & 5 deletions
@@ -16,11 +16,9 @@
 from pathlib import Path
 from typing import (
     Any,
-    Optional,
     Protocol,
     TypeAlias,
     TypeVar,
-    Union,
     runtime_checkable,
 )

@@ -30,9 +28,9 @@

 UntypedStorage: TypeAlias = torch.UntypedStorage

-_PATH = Union[str, Path]
-_DEVICE = Union[torch.device, str, int]
-_MAP_LOCATION_TYPE = Optional[_DEVICE | Callable[[UntypedStorage, str], UntypedStorage | None] | dict[_DEVICE, _DEVICE]]
+_PATH = str | Path
+_DEVICE = torch.device | str | int
+_MAP_LOCATION_TYPE = _DEVICE | Callable[[UntypedStorage, str], UntypedStorage | None] | dict[_DEVICE, _DEVICE] | None
 _PARAMETERS = Iterator[torch.nn.Parameter]

 if torch.distributed.is_available():
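The Optional[...] wrapper removed above follows from the same rule: Optional[X] is just Union[X, None], so it becomes X | None, as in the new _MAP_LOCATION_TYPE. A small sketch of that specific equivalence (the alias names here are illustrative only):

# Hedged sketch, assuming Python 3.10+: Optional[X] and "X | None" expose the same members.
import torch
from typing import Optional, get_args

_LOCATION_OLD = Optional[torch.device]   # Optional[...] spelling
_LOCATION_NEW = torch.device | None      # "| None" spelling used by this commit

assert get_args(_LOCATION_OLD) == (torch.device, type(None))
assert get_args(_LOCATION_NEW) == (torch.device, type(None))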

src/lightning/pytorch/cli.py

Lines changed: 5 additions & 5 deletions
@@ -18,7 +18,7 @@
 from functools import partial, update_wrapper
 from pathlib import Path
 from types import MethodType
-from typing import Any, Optional, TypeVar, Union
+from typing import Any, TypeVar

 import torch
 import yaml
@@ -80,12 +80,12 @@ def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any

 # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
 LRSchedulerTypeTuple = (LRScheduler, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[LRScheduler, ReduceLROnPlateau]
-LRSchedulerType = Union[type[LRScheduler], type[ReduceLROnPlateau]]
+LRSchedulerTypeUnion = LRScheduler | ReduceLROnPlateau
+LRSchedulerType = type[LRScheduler] | type[ReduceLROnPlateau]


 # Type aliases intended for convenience of CLI developers
-ArgsType = Optional[list[str] | dict[str, Any] | Namespace]
+ArgsType = list[str] | dict[str, Any] | Namespace | None
 OptimizerCallable = Callable[[Iterable], Optimizer]
 LRSchedulerCallable = Callable[[Optimizer], LRScheduler | ReduceLROnPlateau]

@@ -448,7 +448,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None
         """Adds default arguments to the parser."""
         parser.add_argument(
             "--seed_everything",
-            type=Union[bool, int],
+            type=bool | int,
             default=self.seed_everything_default,
             help=(
                 "Set to an int to run seed_everything with this value before classes instantiation."
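The --seed_everything change passes a bare bool | int expression as the argument type. The sketch below only shows what that expression evaluates to at runtime on Python 3.10+ (a types.UnionType that introspection and isinstance understand); how LightningArgumentParser consumes it is unchanged and not reproduced here:

# Hedged sketch: the runtime object produced by "bool | int" on Python 3.10+.
import types
from typing import get_args

seed_type = bool | int

assert isinstance(seed_type, types.UnionType)  # PEP 604 unions are types.UnionType instances
assert get_args(seed_type) == (bool, int)      # both members remain introspectable
assert isinstance(42, seed_type)               # isinstance accepts "|" unions directly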

src/lightning/pytorch/core/module.py

Lines changed: 8 additions & 4 deletions
@@ -27,7 +27,6 @@
     Any,
     Literal,
     Optional,
-    Union,
     cast,
     overload,
 )
@@ -89,9 +88,14 @@
 warning_cache = WarningCache()
 log = logging.getLogger(__name__)

-MODULE_OPTIMIZERS = Union[
-    Optimizer, LightningOptimizer, _FabricOptimizer, list[Optimizer], list[LightningOptimizer], list[_FabricOptimizer]
-]
+MODULE_OPTIMIZERS = (
+    Optimizer
+    | LightningOptimizer
+    | _FabricOptimizer
+    | list[Optimizer]
+    | list[LightningOptimizer]
+    | list[_FabricOptimizer]
+)


 class LightningModule(

src/lightning/pytorch/loops/training_epoch_loop.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 import time
 from collections import OrderedDict
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any

 import torch
 from typing_extensions import override
@@ -38,7 +38,7 @@
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

-_BATCH_OUTPUTS_TYPE = Optional[_OPTIMIZER_LOOP_OUTPUTS_TYPE | _MANUAL_LOOP_OUTPUTS_TYPE]
+_BATCH_OUTPUTS_TYPE = _OPTIMIZER_LOOP_OUTPUTS_TYPE | _MANUAL_LOOP_OUTPUTS_TYPE | None


 @dataclass

src/lightning/pytorch/plugins/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -1,5 +1,3 @@
-from typing import Union
-
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, TorchCheckpointIO, XLACheckpointIO
 from lightning.pytorch.plugins.io.async_plugin import AsyncCheckpointIO
 from lightning.pytorch.plugins.layer_sync import LayerSync, TorchSyncBatchNorm
@@ -13,7 +11,7 @@
 from lightning.pytorch.plugins.precision.transformer_engine import TransformerEnginePrecision
 from lightning.pytorch.plugins.precision.xla import XLAPrecision

-_PLUGIN_INPUT = Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync]
+_PLUGIN_INPUT = Precision | ClusterEnvironment | CheckpointIO | LayerSync

 __all__ = [
     "AsyncCheckpointIO",

src/lightning/pytorch/profilers/pytorch.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from contextlib import AbstractContextManager
 from functools import lru_cache, partial
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any

 import torch
 from torch import Tensor, nn
@@ -42,7 +42,7 @@
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()

-_PROFILER = Union[torch.profiler.profile, torch.autograd.profiler.profile, torch.autograd.profiler.emit_nvtx]
+_PROFILER = torch.profiler.profile | torch.autograd.profiler.profile | torch.autograd.profiler.emit_nvtx
 _KINETO_AVAILABLE = torch.profiler.kineto_available()

