@@ -5,7 +5,7 @@
 import torch
 import torch.distributed as dist
 from torch import Tensor
-from typing_extensions import Self
+from typing_extensions import Self, override
 
 from lightning.fabric.plugins.collectives.collective import Collective
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_13
@@ -32,60 +32,73 @@ def __init__(self) -> None:
         super().__init__()
 
     @property
+    @override
     def group(self) -> CollectibleGroup:
         if self._group is None:
             self._group = dist.GroupMember.WORLD
         return super().group
 
     @property
+    @override
     def rank(self) -> int:
         # local rank
         return dist.get_rank(self.group)  # type: ignore[arg-type]
 
     @property
+    @override
     def world_size(self) -> int:
         return dist.get_world_size(self.group)  # type: ignore[arg-type]
 
+    @override
     def broadcast(self, tensor: Tensor, src: int) -> Tensor:
         dist.broadcast(tensor, src, group=self.group)
         return tensor
 
+    @override
     def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
         op = self._convert_to_native_op(op)
         dist.all_reduce(tensor, op=op, group=self.group)
         return tensor
 
+    @override
     def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce(tensor, dst, op=op, group=self.group)
         return tensor
 
+    @override
     def all_gather(self, tensor_list: List[Tensor], tensor: Tensor) -> List[Tensor]:
         dist.all_gather(tensor_list, tensor, group=self.group)
         return tensor_list
 
+    @override
     def gather(self, tensor: Tensor, gather_list: List[Tensor], dst: int = 0) -> List[Tensor]:
         dist.gather(tensor, gather_list, dst, group=self.group)
         return gather_list
 
+    @override
     def scatter(self, tensor: Tensor, scatter_list: List[Tensor], src: int = 0) -> Tensor:
         dist.scatter(tensor, scatter_list, src, group=self.group)
         return tensor
 
+    @override
     def reduce_scatter(
         self, output: Tensor, input_list: List[Tensor], op: Union[str, ReduceOp, RedOpType] = "sum"
     ) -> Tensor:
         op = self._convert_to_native_op(op)
         dist.reduce_scatter(output, input_list, op=op, group=self.group)
         return output
 
+    @override
     def all_to_all(self, output_tensor_list: List[Tensor], input_tensor_list: List[Tensor]) -> List[Tensor]:
         dist.all_to_all(output_tensor_list, input_tensor_list, group=self.group)
         return output_tensor_list
 
+    @override
     def send(self, tensor: Tensor, dst: int, tag: int = 0) -> None:
         dist.send(tensor, dst, tag=tag, group=self.group)
 
+    @override
     def recv(self, tensor: Tensor, src: Optional[int] = None, tag: int = 0) -> Tensor:
         dist.recv(tensor, src, tag=tag, group=self.group)
         return tensor
@@ -110,6 +123,7 @@ def scatter_object_list(
         dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
         return scatter_object_output_list
 
+    @override
     def barrier(self, device_ids: Optional[List[int]] = None) -> None:
         if self.group == dist.GroupMember.NON_GROUP_MEMBER:
             return
@@ -118,6 +132,7 @@ def barrier(self, device_ids: Optional[List[int]] = None) -> None:
     def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
         dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
 
+    @override
     def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:
         if self.is_initialized():
             return self
@@ -144,6 +159,7 @@ def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = N
             os.environ.pop("MASTER_PORT", None)
         return self
 
+    @override
     def teardown(self) -> Self:
         group_member = self.group != dist.GroupMember.NON_GROUP_MEMBER
         super().teardown()  # will destroy its own group
@@ -162,29 +178,35 @@ def teardown(self) -> Self:
         return self
 
     @classmethod
+    @override
     def is_available(cls) -> bool:
         return dist.is_available()
 
     @classmethod
+    @override
     def is_initialized(cls) -> bool:
         return cls.is_available() and dist.is_initialized()
 
     @classmethod
+    @override
     def init_group(cls, **kwargs: Any) -> None:
         dist.init_process_group(**kwargs)
 
     @classmethod
+    @override
     def new_group(cls, **kwargs: Any) -> CollectibleGroup:
         return dist.new_group(**kwargs)
 
     @classmethod
+    @override
     def destroy_group(cls, group: CollectibleGroup) -> None:
         # can be called by all processes in the default group, group will be `object()` if they are not part of the
         # current group
         if group in dist.distributed_c10d._pg_map:
             dist.destroy_process_group(group)  # type: ignore[arg-type]
 
     @classmethod
+    @override
     def _convert_to_native_op(cls, op: Union[str, ReduceOp, RedOpType]) -> Union[ReduceOp, RedOpType]:
         # in 1.13, `ReduceOp` has become an empty shell for `RedOpType`, the latter being the actually returned class.
        # for example, `ReduceOp.SUM` returns a `RedOpType.SUM`. the only exception is `RedOpType.PREMUL_SUM` where
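Note on the pattern applied throughout this diff: every added `@override` comes from `typing_extensions` (imported in the first hunk) and marks a method, property, or classmethod as intentionally overriding a member of the base `Collective` class, so a static type checker such as mypy or pyright can flag the decorated name if no matching base-class member exists. A minimal sketch of that behavior follows; the `Base`/`Child` classes are hypothetical and not part of this commit.

from typing_extensions import override


class Base:
    def teardown(self) -> None:
        ...


class Child(Base):
    @override
    def teardown(self) -> None:  # accepted: Base defines teardown
        ...

    @override
    def tear_down(self) -> None:  # flagged by a type checker: Base has no tear_down to override
        ...

At runtime the decorator only marks the function (setting `__override__ = True`) and returns it unchanged; the check itself is purely static.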