
Commit b240864

Apply formatting changes from merge
1 parent 75f89a1 commit b240864

19 files changed, 228 insertions(+), 225 deletions(-)

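The substance of the commit is a sweep of type annotations from `typing.Optional`/`typing.Union` to PEP 604 union syntax (`X | None`, `A | B`), together with moving `Callable` imports from `typing` to `collections.abc`. A minimal before/after sketch of the pattern, modeled on the `parse_devices` signature below; the function names here are illustrative and not taken from the repository:

from typing import Optional, Union

# Before: typing-module constructs
def parse_before(devices: Union[int, str, list[int]], default: Optional[int] = None) -> Union[int, list[int]]:
    ...

# After: PEP 604 unions (evaluated at runtime on Python 3.10+; older interpreters
# would need `from __future__ import annotations` to defer evaluation)
def parse_after(devices: int | str | list[int], default: int | None = None) -> int | list[int]:
    ...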

src/lightning/fabric/accelerators/xla.py

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-from typing import Any, Union
+from typing import Any

 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -47,13 +47,13 @@ def teardown(self) -> None:

     @staticmethod
     @override
-    def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+    def parse_devices(devices: int | str | list[int]) -> int | list[int]:
         """Accelerator device parsing logic."""
         return _parse_tpu_devices(devices)

     @staticmethod
     @override
-    def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
+    def get_parallel_devices(devices: int | list[int]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
         devices = _parse_tpu_devices(devices)
         if isinstance(devices, int):
@@ -131,7 +131,7 @@ def _using_pjrt() -> bool:
     return pjrt.using_pjrt()


-def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
+def _parse_tpu_devices(devices: int | str | list[int]) -> int | list[int]:
     """Parses the TPU devices given in the format as accepted by the
     :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.

@@ -168,7 +168,7 @@ def _check_tpu_devices_valid(devices: object) -> None:
     )


-def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
+def _parse_tpu_devices_str(devices: str) -> int | list[int]:
     devices = devices.strip()
     try:
         return int(devices)

src/lightning/fabric/plugins/environments/slurm.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def job_name() -> str | None:
         return os.environ.get("SLURM_JOB_NAME")

     @staticmethod
-    def job_id() -> Optional[int]:
+    def job_id() -> int | None:
         # in interactive mode, don't make logs use the same job id
         if _is_slurm_interactive_mode():
             return None

src/lightning/fabric/plugins/io/xla.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
 import os
-from typing import Any, Optional
+from typing import Any

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -41,7 +41,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)

     @override
-    def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
+    def save_checkpoint(self, checkpoint: dict[str, Any], path: _PATH, storage_options: Any | None = None) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:

src/lightning/fabric/plugins/precision/bitsandbytes.py

Lines changed: 13 additions & 12 deletions
@@ -17,10 +17,11 @@
 import os
 import warnings
 from collections import OrderedDict
+from collections.abc import Callable
 from contextlib import AbstractContextManager, ExitStack
 from functools import partial
 from types import ModuleType
-from typing import Any, Callable, Literal, Optional, cast
+from typing import Any, Literal, cast

 import torch
 from lightning_utilities import apply_to_collection
@@ -70,8 +71,8 @@ class BitsandbytesPrecision(Precision):
     def __init__(
         self,
         mode: Literal["nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"],
-        dtype: Optional[torch.dtype] = None,
-        ignore_modules: Optional[set[str]] = None,
+        dtype: torch.dtype | None = None,
+        ignore_modules: set[str] | None = None,
     ) -> None:
         _import_bitsandbytes()

@@ -176,7 +177,7 @@ def _ignore_missing_weights_hook(module: torch.nn.Module, incompatible_keys: _In


 def _replace_param(
-    param: torch.nn.Parameter, data: torch.Tensor, quant_state: Optional[tuple] = None
+    param: torch.nn.Parameter, data: torch.Tensor, quant_state: tuple | None = None
 ) -> torch.nn.Parameter:
     bnb = _import_bitsandbytes()

@@ -223,18 +224,18 @@ class _Linear8bitLt(bnb.nn.Linear8bitLt):
     """Wraps `bnb.nn.Linear8bitLt` and enables instantiation directly on the device and re-quantizaton when loading
     the state dict."""

-    def __init__(self, *args: Any, device: Optional[_DEVICE] = None, threshold: float = 6.0, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, device: _DEVICE | None = None, threshold: float = 6.0, **kwargs: Any) -> None:
         super().__init__(*args, device=device, threshold=threshold, **kwargs)
         self.weight = cast(bnb.nn.Int8Params, self.weight)  # type: ignore[has-type]
-        self.bias: Optional[torch.nn.Parameter] = self.bias
+        self.bias: torch.nn.Parameter | None = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
             self.quantize_()
         self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
         self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)

-    def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+    def quantize_(self, weight: torch.Tensor | None = None, device: torch.device | None = None) -> None:
         """Inplace quantize."""
         if weight is None:
             weight = self.weight.data
@@ -246,7 +247,7 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc

     @staticmethod
     def quantize(
-        int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: Optional[torch.device]
+        int8params: bnb.nn.Int8Params, weight: torch.Tensor, device: torch.device | None
     ) -> bnb.nn.Int8Params:
         device = device or torch.device("cuda")
         if device.type != "cuda":
@@ -310,18 +311,18 @@ class _Linear4bit(bnb.nn.Linear4bit):
     """Wraps `bnb.nn.Linear4bit` to enable: instantiation directly on the device, re-quantizaton when loading the
     state dict, meta-device initialization, and materialization."""

-    def __init__(self, *args: Any, device: Optional[_DEVICE] = None, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, device: _DEVICE | None = None, **kwargs: Any) -> None:
         super().__init__(*args, device=device, **kwargs)
         self.weight = cast(bnb.nn.Params4bit, self.weight)  # type: ignore[has-type]
-        self.bias: Optional[torch.nn.Parameter] = self.bias
+        self.bias: torch.nn.Parameter | None = self.bias
         # if the device is CUDA or we are under a CUDA context manager, quantize the weight here, so we don't end up
         # filling the device memory with float32 weights which could lead to OOM
         if torch.tensor(0, device=device).device.type == "cuda":
             self.quantize_()
         self._register_load_state_dict_pre_hook(partial(_quantize_on_load_hook, self.quantize_))
         self.register_load_state_dict_post_hook(_ignore_missing_weights_hook)

-    def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torch.device] = None) -> None:
+    def quantize_(self, weight: torch.Tensor | None = None, device: torch.device | None = None) -> None:
         """Inplace quantize."""
         if weight is None:
             weight = self.weight.data
@@ -334,7 +335,7 @@ def quantize_(self, weight: Optional[torch.Tensor] = None, device: Optional[torc

     @staticmethod
     def quantize(
-        params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: Optional[torch.device]
+        params4bit: bnb.nn.Params4bit, weight: torch.Tensor, device: torch.device | None
     ) -> bnb.nn.Params4bit:
         device = device or torch.device("cuda")
         if device.type != "cuda":
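Besides the union rewrite, this file (like the DeepSpeed strategy and the XLA launcher below) switches its `Callable` import from `typing` to `collections.abc`; the `typing` aliases for the ABCs have been deprecated since Python 3.9 in favour of the `collections.abc` originals (PEP 585). A small self-contained sketch of the preferred spelling; the hook type and predicate are made up for illustration and are not part of the repository:

from collections.abc import Callable  # preferred over the deprecated typing.Callable

# Hypothetical filter-hook type: maps (parameter name, value) to keep-or-drop.
FilterHook = Callable[[str, object], bool]

def keep_param(name: str, value: object) -> bool:
    # Example predicate only: drop bias parameters, keep everything else.
    return not name.endswith("bias")

hook: FilterHook = keep_param
print(hook("layer.weight", 0))  # True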

src/lightning/fabric/plugins/precision/transformer_engine.py

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@
 import logging
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
 from lightning_utilities import apply_to_collection
@@ -68,9 +68,9 @@ def __init__(
         self,
         *,
         weights_dtype: torch.dtype,
-        recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
-        replace_layers: Optional[bool] = None,
-        fallback_compute_dtype: Optional[torch.dtype] = None,
+        recipe: Union[Mapping[str, Any], "DelayedScaling"] | None = None,
+        replace_layers: bool | None = None,
+        fallback_compute_dtype: torch.dtype | None = None,
     ) -> None:
         if not _TRANSFORMER_ENGINE_AVAILABLE:
             raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
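One detail in this hunk: `recipe` keeps an explicit `Union[...]` inside its new `| None`, and elsewhere in the commit (e.g. `scheduler: Optional["_LRScheduler"]` in the DeepSpeed strategy) `Optional` survives. Assuming these modules evaluate annotations eagerly (no `from __future__ import annotations`), a quoted forward reference is just a `str` at runtime, and a plain string does not support the `|` operator, while a `typing.Union[...]` object does. A minimal sketch of the distinction, using a stand-in class name rather than the real transformer_engine type:

from typing import Optional, Union

class DelayedScalingStub:  # stand-in for a type normally only importable under TYPE_CHECKING
    ...

# Fine at runtime: the left operand of `|` is a typing object, which implements the
# operator even though it carries a quoted forward reference internally.
RecipeArg = Union[dict, "DelayedScalingStub"] | None

# Would raise TypeError at import time, because a plain str has no `|` operator:
#   SchedulerArg = "DelayedScalingStub" | None
# so annotations built around quoted names stay on Optional/Union instead:
SchedulerArg = Optional["DelayedScalingStub"]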

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 25 additions & 27 deletions
@@ -16,12 +16,12 @@
 import logging
 import os
 import platform
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from contextlib import AbstractContextManager, ExitStack
 from datetime import timedelta
 from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -57,10 +57,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):

     def __init__(
         self,
-        accelerator: Optional[Accelerator] = None,
+        accelerator: Accelerator | None = None,
         zero_optimization: bool = True,
         stage: int = 2,
-        remote_device: Optional[str] = None,
+        remote_device: str | None = None,
         offload_optimizer: bool = False,
         offload_parameters: bool = False,
         offload_params_device: str = "cpu",
@@ -84,11 +84,11 @@ def __init__(
         allgather_bucket_size: int = 200_000_000,
         reduce_bucket_size: int = 200_000_000,
         zero_allow_untested_optimizer: bool = True,
-        logging_batch_size_per_gpu: Optional[int] = None,
-        config: Optional[Union[_PATH, dict[str, Any]]] = None,
+        logging_batch_size_per_gpu: int | None = None,
+        config: _PATH | dict[str, Any] | None = None,
         logging_level: int = logging.WARN,
-        parallel_devices: Optional[list[torch.device]] = None,
-        cluster_environment: Optional[ClusterEnvironment] = None,
+        parallel_devices: list[torch.device] | None = None,
+        cluster_environment: ClusterEnvironment | None = None,
         loss_scale: float = 0,
         initial_scale_power: int = 16,
         loss_scale_window: int = 1000,
@@ -99,9 +99,9 @@ def __init__(
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
         load_full_weights: bool = False,
-        precision: Optional[Precision] = None,
-        process_group_backend: Optional[str] = None,
-        timeout: Optional[timedelta] = default_pg_timeout,
+        precision: Precision | None = None,
+        process_group_backend: str | None = None,
+        timeout: timedelta | None = default_pg_timeout,
         exclude_frozen_parameters: bool = False,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
@@ -262,7 +262,7 @@ def __init__(
             process_group_backend=process_group_backend,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
-        self._timeout: Optional[timedelta] = timeout
+        self._timeout: timedelta | None = timeout

         self.config = self._load_config(config)
         if self.config is None:
@@ -316,7 +316,7 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale

-        self._deepspeed_engine: Optional[DeepSpeedEngine] = None
+        self._deepspeed_engine: DeepSpeedEngine | None = None

     @property
     def zero_stage_3(self) -> bool:
@@ -374,7 +374,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         raise NotImplementedError(self._err_msg_joint_setup_required())

     @override
-    def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
+    def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager:
         if self.zero_stage_3 and empty_init is False:
             raise NotImplementedError(
                 f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
@@ -404,9 +404,9 @@ def module_sharded_context(self) -> AbstractContextManager:
     def save_checkpoint(
         self,
         path: _PATH,
-        state: dict[str, Union[Module, Optimizer, Any]],
-        storage_options: Optional[Any] = None,
-        filter: Optional[dict[str, Callable[[str, Any], bool]]] = None,
+        state: dict[str, Module | Optimizer | Any],
+        storage_options: Any | None = None,
+        filter: dict[str, Callable[[str, Any], bool]] | None = None,
     ) -> None:
         """Save model, optimizer, and other state in a checkpoint directory.

@@ -471,9 +471,9 @@ def save_checkpoint(
     def load_checkpoint(
         self,
         path: _PATH,
-        state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
+        state: Module | Optimizer | dict[str, Module | Optimizer | Any] | None = None,
         strict: bool = True,
-        weights_only: Optional[bool] = None,
+        weights_only: bool | None = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.

@@ -554,8 +554,8 @@ def clip_gradients_norm(
         self,
         module: "DeepSpeedEngine",
         optimizer: Optimizer,
-        max_norm: Union[float, int],
-        norm_type: Union[float, int] = 2.0,
+        max_norm: float | int,
+        norm_type: float | int = 2.0,
         error_if_nonfinite: bool = True,
     ) -> torch.Tensor:
         raise NotImplementedError(
@@ -564,9 +564,7 @@ def clip_gradients_norm(
         )

     @override
-    def clip_gradients_value(
-        self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
-    ) -> None:
+    def clip_gradients_value(self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: float | int) -> None:
         raise NotImplementedError(
             "DeepSpeed handles gradient clipping automatically within the optimizer. "
             "Make sure to set the `gradient_clipping` value in your Config."
@@ -614,7 +612,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )

     def _initialize_engine(
-        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
+        self, model: Module, optimizer: Optimizer | None = None, scheduler: Optional["_LRScheduler"] = None
     ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.

@@ -716,7 +714,7 @@ def _create_default_config(
         self,
         zero_optimization: bool,
         zero_allow_untested_optimizer: bool,
-        logging_batch_size_per_gpu: Optional[int],
+        logging_batch_size_per_gpu: int | None,
         partition_activations: bool,
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
@@ -825,7 +823,7 @@ def load(module: torch.nn.Module, prefix: str = "") -> None:

         load(module, prefix="")

-    def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]:
+    def _load_config(self, config: _PATH | dict[str, Any] | None) -> dict[str, Any] | None:
         if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
             config = os.environ[self.DEEPSPEED_ENV_VAR]

src/lightning/fabric/strategies/launchers/xla.py

Lines changed: 5 additions & 4 deletions
@@ -13,7 +13,8 @@
 # limitations under the License.
 import queue
 import time
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, Union

 import torch.multiprocessing as mp
 from typing_extensions import override
@@ -68,7 +69,7 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
             **kwargs: Optional keyword arguments to be passed to the given function.

         """
-        return_queue: Union[queue.Queue, mp.SimpleQueue]
+        return_queue: queue.Queue | mp.SimpleQueue
         return_queue = mp.Manager().Queue()

         import torch_xla.distributed.xla_multiprocessing as xmp
@@ -96,8 +97,8 @@ def _wrapping_function(
         function: Callable,
         args: Any,
         kwargs: Any,
-        return_queue: Union[mp.SimpleQueue, queue.Queue],
-        global_states: Optional[_GlobalStateSnapshot] = None,
+        return_queue: mp.SimpleQueue | queue.Queue,
+        global_states: _GlobalStateSnapshot | None = None,
     ) -> None:
         import torch_xla.core.xla_model as xm
