Skip to content

Commit c985400

Browse files
authored
Add @override for files in src/lightning/fabric/plugins/collectives (#19156)
1 parent 234ded8 commit c985400

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

src/lightning/fabric/plugins/collectives/single_device.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, List
22

33
from torch import Tensor
4+
from typing_extensions import override
45

56
from lightning.fabric.plugins.collectives.collective import Collective
67
from lightning.fabric.utilities.types import CollectibleGroup
@@ -14,28 +15,36 @@ class SingleDeviceCollective(Collective):
1415
"""
1516

1617
@property
18+
@override
1719
def rank(self) -> int:
1820
return 0
1921

2022
@property
23+
@override
2124
def world_size(self) -> int:
2225
return 1
2326

27+
@override
2428
def broadcast(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
2529
return tensor
2630

31+
@override
2732
def all_reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
2833
return tensor
2934

35+
@override
3036
def reduce(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
3137
return tensor
3238

39+
@override
3340
def all_gather(self, tensor_list: List[Tensor], tensor: Tensor, **__: Any) -> List[Tensor]:
3441
return [tensor]
3542

43+
@override
3644
def gather(self, tensor: Tensor, *_: Any, **__: Any) -> List[Tensor]:
3745
return [tensor]
3846

47+
@override
3948
def scatter(
4049
self,
4150
tensor: Tensor,
@@ -45,43 +54,54 @@ def scatter(
4554
) -> Tensor:
4655
return scatter_list[0]
4756

57+
@override
4858
def reduce_scatter(self, output: Tensor, input_list: List[Tensor], *_: Any, **__: Any) -> Tensor:
4959
return input_list[0]
5060

61+
@override
5162
def all_to_all(
5263
self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor], *_: Any, **__: Any
5364
) -> List[Tensor]:
5465
return input_tensor_list
5566

67+
@override
5668
def send(self, *_: Any, **__: Any) -> None:
5769
pass
5870

71+
@override
5972
def recv(self, tensor: Tensor, *_: Any, **__: Any) -> Tensor:
6073
return tensor
6174

75+
@override
6276
def barrier(self, *_: Any, **__: Any) -> None:
6377
pass
6478

6579
@classmethod
80+
@override
6681
def is_available(cls) -> bool:
6782
return True # vacuous truth
6883

6984
@classmethod
85+
@override
7086
def is_initialized(cls) -> bool:
7187
return True # vacuous truth
7288

7389
@classmethod
90+
@override
7491
def init_group(cls, **_: Any) -> None:
7592
pass
7693

7794
@classmethod
95+
@override
7896
def new_group(cls, **_: Any) -> CollectibleGroup:
7997
return object() # type: ignore[return-value]
8098

8199
@classmethod
100+
@override
82101
def destroy_group(cls, group: CollectibleGroup) -> None:
83102
pass
84103

85104
@classmethod
105+
@override
86106
def _convert_to_native_op(cls, op: str) -> str:
87107
return op

src/lightning/fabric/plugins/collectives/torch_collective.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
import torch.distributed as dist
77
from torch import Tensor
8-
from typing_extensions import Self
8+
from typing_extensions import Self, override
99

1010
from lightning.fabric.plugins.collectives.collective import Collective
1111
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
@@ -32,60 +32,73 @@ def __init__(self) -> None:
3232
super().__init__()
3333

3434
@property
35+
@override
3536
def group(self) -> CollectibleGroup:
3637
if self._group is None:
3738
self._group = dist.GroupMember.WORLD
3839
return super().group
3940

4041
@property
42+
@override
4143
def rank(self) -> int:
4244
# local rank
4345
return dist.get_rank(self.group) # type: ignore[arg-type]
4446

4547
@property
48+
@override
4649
def world_size(self) -> int:
4750
return dist.get_world_size(self.group) # type: ignore[arg-type]
4851

52+
@override
4953
def broadcast(self, tensor: Tensor, src: int) -> Tensor:
5054
dist.broadcast(tensor, src, group=self.group)
5155
return tensor
5256

57+
@override
5358
def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
5459
op = self._convert_to_native_op(op)
5560
dist.all_reduce(tensor, op=op, group=self.group)
5661
return tensor
5762

63+
@override
5864
def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
5965
op = self._convert_to_native_op(op)
6066
dist.reduce(tensor, dst, op=op, group=self.group)
6167
return tensor
6268

69+
@override
6370
def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
6471
dist.all_gather(tensor_list, tensor, group=self.group)
6572
return tensor_list
6673

74+
@override
6775
def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
6876
dist.gather(tensor, gather_list, dst, group=self.group)
6977
return gather_list
7078

79+
@override
7180
def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
7281
dist.scatter(tensor, scatter_list, src, group=self.group)
7382
return tensor
7483

84+
@override
7585
def reduce_scatter(
7686
self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
7787
) -> Tensor:
7888
op = self._convert_to_native_op(op)
7989
dist.reduce_scatter(output, input_list, op=op, group=self.group)
8090
return output
8191

92+
@override
8293
def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
8394
dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
8495
return output_tensor_list
8596

97+
@override
8698
def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
8799
dist.send(tensor, dst, tag=tag, group=self.group)
88100

101+
@override
89102
def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
90103
dist.recv(tensor, src, tag=tag, group=self.group)
91104
return tensor
@@ -110,6 +123,7 @@ def scatter_object_list(
110123
dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
111124
return scatter_object_output_list
112125

126+
@override
113127
def barrier(self, device_ids: Optional[List[int]] = None) -> None:
114128
if self.group == dist.GroupMember.NON_GROUP_MEMBER:
115129
return
@@ -118,6 +132,7 @@ def barrier(self, device_ids: Optional[List[int]] = None) -> None:
118132
def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
119133
dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
120134

135+
@override
121136
def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:
122137
if self.is_initialized():
123138
return self
@@ -144,6 +159,7 @@ def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = N
144159
os.environ.pop("MASTER_PORT", None)
145160
return self
146161

162+
@override
147163
def teardown(self) -> Self:
148164
group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
149165
super().teardown() # will destroy its own group
@@ -162,29 +178,35 @@ def teardown(self) -> Self:
162178
return self
163179

164180
@classmethod
181+
@override
165182
def is_available(cls) -> bool:
166183
return dist.is_available()
167184

168185
@classmethod
186+
@override
169187
def is_initialized(cls) -> bool:
170188
return cls.is_available() and dist.is_initialized()
171189

172190
@classmethod
191+
@override
173192
def init_group(cls, **kwargs: Any) -> None:
174193
dist.init_process_group(**kwargs)
175194

176195
@classmethod
196+
@override
177197
def new_group(cls, **kwargs: Any) -> CollectibleGroup:
178198
return dist.new_group(**kwargs)
179199

180200
@classmethod
201+
@override
181202
def destroy_group(cls, group: CollectibleGroup) -> None:
182203
# can be called by all processes in the default group, group will be `object()` if they are not part of the
183204
# current group
184205
if group in dist.distributed_c10d._pg_map:
185206
dist.destroy_process_group(group) # type: ignore[arg-type]
186207

187208
@classmethod
209+
@override
188210
def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
189211
# in 1.13, `ReduceOp` has become an empty shell for `RedOpType`, the latter being the actually returned class.
190212
# for example, `ReduceOp.SUM` returns a `RedOpType.SUM`. the only exception is `RedOpType.PREMUL_SUM` where

0 commit comments

Comments
 (0)