Commit 41f76cd

Add @override for files in src/lightning/fabric/strategies (#19185)
1 parent: 9d25e9a

File tree: 13 files changed (+146, -1 lines)
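
The change is mechanical: every method in these strategies that overrides a base `Strategy` method gains the PEP 698 `@override` decorator, imported from `typing_extensions` for Python versions before 3.12. Below is a minimal sketch (toy `Base`/`Child` classes, not the actual Fabric ones) of what the decorator buys: a type checker that implements PEP 698, e.g. recent mypy or pyright, reports an error when a decorated method no longer matches anything in a base class, for example after a rename or a typo.

# Toy sketch, not part of the commit. Assumes typing_extensions >= 4.4
# (the backport that provides PEP 698's `override`).
from typing_extensions import override


class Base:
    def setup_environment(self) -> None:
        ...


class Child(Base):
    @override
    def setup_environment(self) -> None:  # OK: overrides Base.setup_environment
        ...

    @override
    def setup_enviroment(self) -> None:  # typo on purpose: flagged by the checker, overrides nothing
        ...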

src/lightning/fabric/strategies/ddp.py

Lines changed: 14 additions & 0 deletions

@@ -21,6 +21,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
+from typing_extensions import override

 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
@@ -78,6 +79,7 @@ def __init__(
         self._ddp_kwargs = kwargs

     @property
+    @override
     def root_device(self) -> torch.device:
         assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
@@ -96,24 +98,28 @@ def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0

     @property
+    @override
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         return {"num_replicas": (self.num_nodes * self.num_processes), "rank": self.global_rank}

     @property
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend

+    @override
     def _configure_launcher(self) -> None:
         assert self.cluster_environment is not None
         if self._start_method == "popen":
             self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
         else:
             self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

+    @override
     def setup_environment(self) -> None:
         self._setup_distributed()
         super().setup_environment()

+    @override
     def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
@@ -122,9 +128,11 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

+    @override
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)

+    @override
     def all_reduce(
         self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> Tensor:
@@ -144,6 +152,7 @@ def all_reduce(
             return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor

+    @override
     def barrier(self, *args: Any, **kwargs: Any) -> None:
         if not _distributed_is_initialized():
             return
@@ -152,6 +161,7 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
         else:
             torch.distributed.barrier()

+    @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not _distributed_is_initialized():
             return obj
@@ -160,11 +170,13 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]

+    @override
     def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
         if isinstance(module, DistributedDataParallel):
             module = module.module
         return super().get_module_state_dict(module)

+    @override
     def load_module_state_dict(
         self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True
     ) -> None:
@@ -173,6 +185,7 @@ def load_module_state_dict(
         super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict)

     @classmethod
+    @override
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         entries = (
             ("ddp", "popen"),
@@ -210,6 +223,7 @@ def _determine_ddp_device_ids(self) -> Optional[List[int]]:


 class _DDPBackwardSyncControl(_BackwardSyncControl):
+    @override
     def no_backward_sync(self, module: Module) -> ContextManager:
         """Blocks gradient synchronization inside the :class:`~torch.nn.parallel.distributed.DistributedDataParallel`
         wrapper."""
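
Note the ordering used throughout the diff: `@override` sits below `@property` (and below `@classmethod`), so it is applied to the plain function before the descriptor wraps it. A small stand-in sketch of that pattern follows (toy classes, not the actual Fabric ones; `torch` is assumed to be installed):

# Stand-in sketch of the @property / @override ordering used in this commit.
import torch
from typing_extensions import override


class Base:
    @property
    def root_device(self) -> torch.device:
        return torch.device("cpu")


class Child(Base):
    @property
    @override  # decorates the function first, then @property wraps the result
    def root_device(self) -> torch.device:
        return torch.device("cuda", 0)


print(Child().root_device)  # cuda:0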

src/lightning/fabric/strategies/deepspeed.py

Lines changed: 13 additions & 0 deletions

@@ -25,6 +25,7 @@
 from lightning_utilities.core.imports import RequirementCache
 from torch.nn import Module
 from torch.optim import Optimizer
+from typing_extensions import override

 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
@@ -299,13 +300,15 @@ def zero_stage_3(self) -> bool:
         return zero_optimization is not None and zero_optimization.get("stage") == 3

     @property
+    @override
     def distributed_sampler_kwargs(self) -> Dict[str, int]:
         return {"num_replicas": self.world_size, "rank": self.global_rank}

     @property
     def model(self) -> "DeepSpeedEngine":
         return self._deepspeed_engine

+    @override
     def setup_module_and_optimizers(
         self, module: Module, optimizers: List[Optimizer]
     ) -> Tuple["DeepSpeedEngine", List[Optimizer]]:
@@ -328,6 +331,7 @@ def setup_module_and_optimizers(
         self._set_deepspeed_activation_checkpointing()
         return self._deepspeed_engine, [optimizer]

+    @override
     def setup_module(self, module: Module) -> "DeepSpeedEngine":
         """Set up a module for inference (no optimizers).

@@ -337,6 +341,7 @@ def setup_module(self, module: Module) -> "DeepSpeedEngine":
         self._deepspeed_engine, _ = self._initialize_engine(module)
         return self._deepspeed_engine

+    @override
     def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         """Optimizers can only be set up jointly with the model in this strategy.

@@ -345,6 +350,7 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         """
         raise NotImplementedError(self._err_msg_joint_setup_required())

+    @override
     def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
         if self.zero_stage_3 and empty_init is False:
             raise NotImplementedError(
@@ -357,6 +363,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
             stack.enter_context(module_sharded_ctx)
         return stack

+    @override
     def module_sharded_context(self) -> ContextManager:
         # Current limitation in Fabric: The config needs to be fully determined at the time of calling the context
         # manager. Later modifications through e.g. `Fabric.setup()` won't have an effect here.
@@ -370,6 +377,7 @@ def module_sharded_context(self) -> ContextManager:
             config_dict_or_path=self.config,
         )

+    @override
     def save_checkpoint(
         self,
         path: _PATH,
@@ -434,6 +442,7 @@ def save_checkpoint(
         # use deepspeed's internal checkpointing function to handle partitioned weights across processes
         engine.save_checkpoint(path, client_state=state, tag="checkpoint")

+    @override
     def load_checkpoint(
         self,
         path: _PATH,
@@ -514,6 +523,7 @@ def load_checkpoint(
         _move_state_into(source=client_state, destination=state, keys=keys)
         return client_state

+    @override
     def clip_gradients_norm(
         self,
         module: "DeepSpeedEngine",
@@ -527,6 +537,7 @@ def clip_gradients_norm(
             "Make sure to set the `gradient_clipping` value in your Config."
         )

+    @override
     def clip_gradients_value(
         self, module: "DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
     ) -> None:
@@ -536,6 +547,7 @@ def clip_gradients_value(
         )

     @classmethod
+    @override
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
         strategy_registry.register("deepspeed_stage_1", cls, description="DeepSpeed with ZeRO Stage 1 enabled", stage=1)
@@ -591,6 +603,7 @@ def _initialize_engine(
         )
         return deepspeed_engine, deepspeed_optimizer

+    @override
     def _setup_distributed(self) -> None:
         if not isinstance(self.accelerator, CUDAAccelerator):
             raise RuntimeError(

src/lightning/fabric/strategies/dp.py

Lines changed: 13 additions & 0 deletions

@@ -16,6 +16,7 @@
 import torch
 from torch import Tensor
 from torch.nn import DataParallel, Module
+from typing_extensions import override

 from lightning.fabric.accelerators import Accelerator
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
@@ -47,25 +48,31 @@ def __init__(
         )

     @property
+    @override
     def root_device(self) -> torch.device:
         assert self.parallel_devices is not None
         return self.parallel_devices[0]

     @property
+    @override
     def distributed_sampler_kwargs(self) -> None:
         return None

+    @override
     def setup_module(self, module: Module) -> DataParallel:
         """Wraps the given model into a :class:`~torch.nn.DataParallel` module."""
         return DataParallel(module=module, device_ids=self.parallel_devices)

+    @override
     def module_to_device(self, module: Module) -> None:
         module.to(self.root_device)

+    @override
     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
         # DataParallel handles the transfer of batch to the device
         return batch

+    @override
     def all_reduce(
         self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
     ) -> TReduce:
@@ -75,20 +82,25 @@ def mean(t: Tensor) -> Tensor:

         return apply_to_collection(collection, Tensor, mean)

+    @override
     def barrier(self, *args: Any, **kwargs: Any) -> None:
         pass

+    @override
     def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj

+    @override
     def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
         return decision

+    @override
     def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]:
         if isinstance(module, DataParallel):
             module = module.module
         return super().get_module_state_dict(module)

+    @override
     def load_module_state_dict(
         self, module: Module, state_dict: Dict[str, Union[Any, Tensor]], strict: bool = True
     ) -> None:
@@ -97,5 +109,6 @@ def load_module_state_dict(
         super().load_module_state_dict(module=module, state_dict=state_dict, strict=strict)

     @classmethod
+    @override
     def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("dp", cls, description=cls.__name__)
