@@ -16,12 +16,12 @@
 import logging
 import os
 import platform
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from contextlib import AbstractContextManager, ExitStack
 from datetime import timedelta
 from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
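A note on the import shuffle above: since Python 3.9 (PEP 585), collections.abc.Callable is subscriptable and typing.Callable is deprecated in its favor, so the callable annotations used later in this file keep working unchanged. A standalone sketch of the equivalent spelling (the FilterFn alias and keep_all function are illustrative, not part of the patch):

from collections.abc import Callable
from typing import Any

# The same parameterized form typing.Callable used to provide:
# a callable taking (str, Any) and returning bool.
FilterFn = Callable[[str, Any], bool]

def keep_all(name: str, value: Any) -> bool:
    return True

f: FilterFn = keep_all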
@@ -57,10 +57,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
 
     def __init__(
         self,
-        accelerator: Optional[Accelerator] = None,
+        accelerator: Accelerator | None = None,
         zero_optimization: bool = True,
         stage: int = 2,
-        remote_device: Optional[str] = None,
+        remote_device: str | None = None,
         offload_optimizer: bool = False,
         offload_parameters: bool = False,
         offload_params_device: str = "cpu",
@@ -84,11 +84,11 @@ def __init__(
         allgather_bucket_size: int = 200_000_000,
         reduce_bucket_size: int = 200_000_000,
         zero_allow_untested_optimizer: bool = True,
-        logging_batch_size_per_gpu: Optional[int] = None,
-        config: Optional[Union[_PATH, dict[str, Any]]] = None,
+        logging_batch_size_per_gpu: int | None = None,
+        config: _PATH | dict[str, Any] | None = None,
         logging_level: int = logging.WARN,
-        parallel_devices: Optional[list[torch.device]] = None,
-        cluster_environment: Optional[ClusterEnvironment] = None,
+        parallel_devices: list[torch.device] | None = None,
+        cluster_environment: ClusterEnvironment | None = None,
         loss_scale: float = 0,
         initial_scale_power: int = 16,
         loss_scale_window: int = 1000,
@@ -99,9 +99,9 @@ def __init__(
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
         load_full_weights: bool = False,
-        precision: Optional[Precision] = None,
-        process_group_backend: Optional[str] = None,
-        timeout: Optional[timedelta] = default_pg_timeout,
+        precision: Precision | None = None,
+        process_group_backend: str | None = None,
+        timeout: timedelta | None = default_pg_timeout,
         exclude_frozen_parameters: bool = False,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
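The pattern running through these signature hunks is PEP 604: Optional[X] becomes X | None and Union[A, B] becomes A | B. Since several of the rewritten annotations are evaluated when the module is imported (annotations in def headers are evaluated at function definition time, and the attribute annotation on self._timeout further down is evaluated on each call), this spelling presumably assumes a Python 3.10+ floor, where | on types builds a types.UnionType; on 3.9 it would need from __future__ import annotations. A quick standalone equivalence check:

from typing import Optional, Union, get_args

# On Python 3.10+, both spellings compare equal and carry the same args.
assert (int | None) == Optional[int]
assert (str | int | None) == Optional[Union[str, int]]
assert get_args(int | None) == (int, type(None))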
@@ -262,7 +262,7 @@ def __init__(
             process_group_backend=process_group_backend,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
-        self._timeout: Optional[timedelta] = timeout
+        self._timeout: timedelta | None = timeout
 
         self.config = self._load_config(config)
         if self.config is None:
@@ -316,7 +316,7 @@ def __init__(
         self.hysteresis = hysteresis
         self.min_loss_scale = min_loss_scale
 
-        self._deepspeed_engine: Optional[DeepSpeedEngine] = None
+        self._deepspeed_engine: DeepSpeedEngine | None = None
 
     @property
     def zero_stage_3(self) -> bool:
@@ -374,7 +374,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         raise NotImplementedError(self._err_msg_joint_setup_required())
 
     @override
-    def module_init_context(self, empty_init: Optional[bool] = None) -> AbstractContextManager:
+    def module_init_context(self, empty_init: bool | None = None) -> AbstractContextManager:
         if self.zero_stage_3 and empty_init is False:
             raise NotImplementedError(
                 f"`{empty_init=}` is not a valid choice with `DeepSpeedStrategy` when ZeRO stage 3 is enabled."
@@ -404,9 +404,9 @@ def module_sharded_context(self) -> AbstractContextManager:
     def save_checkpoint(
         self,
         path: _PATH,
-        state: dict[str, Union[Module, Optimizer, Any]],
-        storage_options: Optional[Any] = None,
-        filter: Optional[dict[str, Callable[[str, Any], bool]]] = None,
+        state: dict[str, Module | Optimizer | Any],
+        storage_options: Any | None = None,
+        filter: dict[str, Callable[[str, Any], bool]] | None = None,
     ) -> None:
         """Save model, optimizer, and other state in a checkpoint directory.
 
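For context on the filter parameter retyped above: each dict entry pairs a state key with a predicate that receives a (name, value) item from that object's state dict and returns whether to keep it. A hedged sketch of a predicate matching the annotated Callable[[str, Any], bool] shape (the "model" key and the predicate itself are illustrative, not from this diff):

from typing import Any
import torch

def keep_float_tensors(name: str, value: Any) -> bool:
    # Keep only floating-point tensors; drop entries such as integer counters.
    return isinstance(value, torch.Tensor) and value.is_floating_point()

checkpoint_filter = {"model": keep_float_tensors}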
@@ -471,9 +471,9 @@ def save_checkpoint(
     def load_checkpoint(
         self,
         path: _PATH,
-        state: Optional[Union[Module, Optimizer, dict[str, Union[Module, Optimizer, Any]]]] = None,
+        state: Module | Optimizer | dict[str, Module | Optimizer | Any] | None = None,
         strict: bool = True,
-        weights_only: Optional[bool] = None,
+        weights_only: bool | None = None,
     ) -> dict[str, Any]:
         """Load the contents from a checkpoint and restore the state of the given objects.
 
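The state hunk above also flattens a nested Optional[Union[...]] into a single | chain. Unions flatten and deduplicate, so the nesting carries no extra meaning and the flat spelling is equivalent; a standalone sanity check:

from typing import Optional, Union

# Nested unions flatten, and Optional is just "| None":
assert Union[int, Union[str, bytes]] == Union[int, str, bytes]
assert Optional[Union[int, str]] == (int | str | None)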
@@ -554,8 +554,8 @@ def clip_gradients_norm(
         self,
         module: "DeepSpeedEngine",
         optimizer: Optimizer,
-        max_norm: Union[float, int],
-        norm_type: Union[float, int] = 2.0,
+        max_norm: float | int,
+        norm_type: float | int = 2.0,
         error_if_nonfinite: bool = True,
     ) -> torch.Tensor:
         raise NotImplementedError(
@@ -564,9 +564,7 @@ def clip_gradients_norm(
         )
 
     @override
-    def clip_gradients_value(
-        self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
-    ) -> None:
+    def clip_gradients_value(self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: float | int) -> None:
         raise NotImplementedError(
             "DeepSpeed handles gradient clipping automatically within the optimizer. "
             "Make sure to set the `gradient_clipping` value in your Config."
@@ -614,7 +612,7 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         )
 
     def _initialize_engine(
-        self, model: Module, optimizer: Optional[Optimizer] = None, scheduler: Optional["_LRScheduler"] = None
+        self, model: Module, optimizer: Optimizer | None = None, scheduler: Optional["_LRScheduler"] = None
     ) -> tuple["DeepSpeedEngine", Optimizer, Any]:
         """Initialize one model and one optimizer with an optional learning rate scheduler.
 
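One holdout in the hunk above: scheduler keeps Optional["_LRScheduler"] while its neighbors move to |. The likely reason (assuming _LRScheduler is only imported under TYPE_CHECKING, as is common for such names) is that a quoted forward reference cannot be an operand of | in an annotation that gets evaluated at runtime, since str defines no | operator. A minimal standalone reproduction:

from typing import Optional

# Fine: typing wraps the string in a lazy ForwardRef, resolved only on demand.
def f(scheduler: Optional["_LRScheduler"] = None) -> None: ...

# TypeError at definition time: unsupported operand type(s) for |: 'str' and 'NoneType'
# def g(scheduler: "_LRScheduler" | None = None) -> None: ...

# Also fine: quoting the entire annotation defers evaluation altogether.
def h(scheduler: "_LRScheduler | None" = None) -> None: ...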
@@ -716,7 +714,7 @@ def _create_default_config(
         self,
         zero_optimization: bool,
         zero_allow_untested_optimizer: bool,
-        logging_batch_size_per_gpu: Optional[int],
+        logging_batch_size_per_gpu: int | None,
         partition_activations: bool,
         cpu_checkpointing: bool,
         contiguous_memory_optimization: bool,
@@ -825,7 +823,7 @@ def load(module: torch.nn.Module, prefix: str = "") -> None:
 
         load(module, prefix="")
 
-    def _load_config(self, config: Optional[Union[_PATH, dict[str, Any]]]) -> Optional[dict[str, Any]]:
+    def _load_config(self, config: _PATH | dict[str, Any] | None) -> dict[str, Any] | None:
         if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
             rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
             config = os.environ[self.DEEPSPEED_ENV_VAR]