@@ -50,7 +50,7 @@ def world_size(self) -> int:
5050
5151 @override
5252 def broadcast (self , tensor : Tensor , src : int ) -> Tensor :
53- dist .broadcast (tensor , src , group = self .group )
53+ dist .broadcast (tensor , src , group = self .group ) # type: ignore[arg-type]
5454 return tensor
5555
5656 @override
@@ -62,7 +62,7 @@ def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum"
6262 @override
6363 def reduce (self , tensor : Tensor , dst : int , op : Union [str , ReduceOp , RedOpType ] = "sum" ) -> Tensor :
6464 op = self ._convert_to_native_op (op )
65- dist .reduce (tensor , dst , op = op , group = self .group )
65+ dist .reduce (tensor , dst , op = op , group = self .group ) # type: ignore[arg-type]
6666 return tensor
6767
6868 @override
@@ -72,12 +72,12 @@ def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]:
7272
7373 @override
7474 def gather (self , tensor : Tensor , gather_list : list [Tensor ], dst : int = 0 ) -> list [Tensor ]:
75- dist .gather (tensor , gather_list , dst , group = self .group )
75+ dist .gather (tensor , gather_list , dst , group = self .group ) # type: ignore[arg-type]
7676 return gather_list
7777
7878 @override
7979 def scatter (self , tensor : Tensor , scatter_list : list [Tensor ], src : int = 0 ) -> Tensor :
80- dist .scatter (tensor , scatter_list , src , group = self .group )
80+ dist .scatter (tensor , scatter_list , src , group = self .group ) # type: ignore[arg-type]
8181 return tensor
8282
8383 @override
@@ -109,27 +109,27 @@ def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]:
109109 def broadcast_object_list (
110110 self , object_list : list [Any ], src : int , device : Optional [torch .device ] = None
111111 ) -> list [Any ]:
112- dist .broadcast_object_list (object_list , src , group = self .group , device = device )
112+ dist .broadcast_object_list (object_list , src , group = self .group , device = device ) # type: ignore[arg-type]
113113 return object_list
114114
115115 def gather_object (self , obj : Any , object_gather_list : list [Any ], dst : int = 0 ) -> list [Any ]:
116- dist .gather_object (obj , object_gather_list , dst , group = self .group )
116+ dist .gather_object (obj , object_gather_list , dst , group = self .group ) # type: ignore[arg-type]
117117 return object_gather_list
118118
119119 def scatter_object_list (
120120 self , scatter_object_output_list : list [Any ], scatter_object_input_list : list [Any ], src : int = 0
121121 ) -> list [Any ]:
122- dist .scatter_object_list (scatter_object_output_list , scatter_object_input_list , src , group = self .group )
122+ dist .scatter_object_list (scatter_object_output_list , scatter_object_input_list , src , group = self .group ) # type: ignore[arg-type]
123123 return scatter_object_output_list
124124
125125 @override
126126 def barrier (self , device_ids : Optional [list [int ]] = None ) -> None :
127127 if self .group == dist .GroupMember .NON_GROUP_MEMBER :
128128 return
129- dist .barrier (group = self .group , device_ids = device_ids )
129+ dist .barrier (group = self .group , device_ids = device_ids ) # type: ignore[arg-type]
130130
131131 def monitored_barrier (self , timeout : Optional [datetime .timedelta ] = None , wait_all_ranks : bool = False ) -> None :
132- dist .monitored_barrier (group = self .group , timeout = timeout , wait_all_ranks = wait_all_ranks )
132+ dist .monitored_barrier (group = self .group , timeout = timeout , wait_all_ranks = wait_all_ranks ) # type: ignore[arg-type]
133133
134134 @override
135135 def setup (self , main_address : Optional [str ] = None , main_port : Optional [str ] = None , ** kwargs : Any ) -> Self :
0 commit comments