5757 AllreduceOptions ,
5858 BroadcastOptions ,
5959 ReduceOp ,
60+ ReduceScatterOptions ,
6061 Work ,
6162)
6263from torch .futures import Future
@@ -159,6 +160,20 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
159160 opts .rootRank = root
160161 return self .broadcast ([tensor ], opts )
161162
# pyre-fixme[14]: inconsistent override
def reduce_scatter(
    self,
    output_tensors: List[torch.Tensor],
    input_tensors: List[List[torch.Tensor]],
    opts: ReduceScatterOptions,
) -> Work:
    """
    Reduce values across the group, then scatter the reduced chunks so each
    rank receives one output tensor. Concrete backends must override this;
    the base implementation always raises.

    See torch.distributed.reduce_scatter for more details.
    """
    raise NotImplementedError("not implemented")
def size(self) -> int:
    """Return the world size of this group; the base class has no value."""
    raise NotImplementedError("not implemented")
164179
@@ -267,6 +282,14 @@ def allgather(
267282 def broadcast (self , tensor_list : List [torch .Tensor ], opts : object ) -> Work :
268283 return self .parent .broadcast (tensor_list , opts )
269284
def reduce_scatter(
    self,
    output_tensors: List[torch.Tensor],
    input_tensors: List[List[torch.Tensor]],
    opts: object,
) -> Work:
    """Forward the reduce_scatter collective to the wrapped process group."""
    delegate = self.parent
    return delegate.reduce_scatter(output_tensors, input_tensors, opts)
def size(self) -> int:
    """Report the world size of the wrapped process group."""
    return self.parent.size()
272295
@@ -295,6 +318,25 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
def getBackendName(self) -> str:
    """Backend name this wrapper registers with torch.distributed."""
    backend_name = "torchft-gloo"
    return backend_name
297320
# pyre-fixme[14,15]: inconsistent override
def reduce_scatter(
    self,
    output_tensors: List[torch.Tensor],
    input_tensors: List[List[torch.Tensor]],
    opts: ReduceScatterOptions,
) -> None:
    """
    Unsupported collective: the Gloo backend cannot perform reduce_scatter,
    so this override exists only to fail loudly instead of silently.

    Raises:
        RuntimeError: Always raised since reduce_scatter is not
            supported by ProcessGroupGloo.
    """
    raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")
298340
299341class ProcessGroupNCCL (ProcessGroupWrapper ):
300342 """
@@ -354,11 +396,6 @@ def __init__(self, rank: int, world: int) -> None:
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
    """Record that configure() was called; the dummy group does no real setup."""
    self.configure_count = self.configure_count + 1
356398
357- def broadcast (self , tensor_list : List [torch .Tensor ], opts : object ) -> Work :
358- res = _DummyWork (tensor_list )
359- self ._work .append (res )
360- return res
361-
362399 def allgather (
363400 self ,
364401 output_tensors : List [List [torch .Tensor ]],
@@ -377,6 +414,24 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
377414 self ._work .append (res )
378415 return res
379416
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
    """Fake broadcast: tensors are left untouched and a completed-style
    _DummyWork wrapping them is recorded and returned."""
    work = _DummyWork(tensor_list)
    self._work.append(work)
    return work
421+
def reduce_scatter(
    self,
    output_tensors: List[torch.Tensor],
    input_tensors: List[List[torch.Tensor]],
    opts: object,
) -> Work:
    """Single-process stand-in: copy this rank's input shards straight into
    the output tensors, then record and return a _DummyWork."""
    for dst, src in zip(output_tensors, input_tensors[0]):
        dst.copy_(src)

    work = _DummyWork(output_tensors)
    self._work.append(work)
    return work
434+
def size(self) -> int:
    """World size the dummy group was constructed with."""
    return self._world
382437
@@ -970,6 +1025,25 @@ def broadcast(
9701025
9711026 return self ._run_func ("broadcast" , tensor_list , opts )
9721027
def reduce_scatter(
    self,
    output_tensors: List[torch.Tensor],
    input_tensors: List[List[torch.Tensor]],
    opts: ReduceScatterOptions,
) -> Work:
    """
    Reduces, then scatters a list of tensors to all processes in a group,
    by dispatching the call to the subprocess that hosts the real group.

    All tensors are moved into shared memory first so the child process
    can read the inputs and write the outputs without copies.

    See torch.distributed.reduce_scatter for more details.
    """
    # Fix: the first message previously said "input must be list" even
    # though it guards the output argument.
    assert isinstance(output_tensors, list), "output must be list"
    assert isinstance(input_tensors, list), "input must be list"

    for tensor in output_tensors:
        if not tensor.is_shared():
            tensor.share_memory_()

    for tensor_list in input_tensors:
        for tensor in tensor_list:
            if not tensor.is_shared():
                tensor.share_memory_()

    return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
1046+
def size(self) -> int:
    """World size recorded when the baby process group was configured."""
    return self._world_size
9751049
@@ -992,7 +1066,15 @@ def safe_args(cls, args: T) -> T:
9921066 return tuple (cls .safe_args (arg ) for arg in args )
9931067 elif isinstance (args , list ):
9941068 return [cls .safe_args (arg ) for arg in args ]
995- elif isinstance (args , (AllreduceOptions , AllgatherOptions , BroadcastOptions )):
1069+ elif isinstance (
1070+ args ,
1071+ (
1072+ AllreduceOptions ,
1073+ AllgatherOptions ,
1074+ BroadcastOptions ,
1075+ ReduceScatterOptions ,
1076+ ),
1077+ ):
9961078 return cls .from_torch (args )
9971079 else :
9981080 return args
@@ -1038,6 +1120,25 @@ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGrou
def getBackendName(self) -> str:
    """Backend name this class registers with torch.distributed."""
    backend_name = "torchft-baby-gloo"
    return backend_name
10401122
# pyre-fixme[15]: inconsistent override
def reduce_scatter(
    self,
    output_tensors: List[torch.Tensor],
    input_tensors: List[List[torch.Tensor]],
    opts: ReduceScatterOptions,
) -> None:
    """
    This function is a placeholder for the reduce_scatter operation in the
    ProcessGroupBabyGloo class. However, this operation is not supported by
    the Gloo backend, and thus, calling this function will raise a
    RuntimeError.

    Raises:
        RuntimeError: Always raised since reduce_scatter is not
            supported by ProcessGroupBabyGloo.
    """
    # Fix: the docstring previously named ProcessGroupGloo, but this method
    # belongs to ProcessGroupBabyGloo (see the error message below).
    raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")
10411142
10421143class ProcessGroupBabyNCCL (ProcessGroupBaby ):
10431144 """
0 commit comments