Skip to content

Commit 30da1d2

Browse files
committed
# type: ignore[arg-type]
1 parent 40096bf commit 30da1d2

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/lightning/fabric/plugins/collectives/torch_collective.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def world_size(self) -> int:
5050

5151
@override
5252
def broadcast(self, tensor: Tensor, src: int) -> Tensor:
53-
dist.broadcast(tensor, src, group=self.group)
53+
dist.broadcast(tensor, src, group=self.group) # type: ignore[arg-type]
5454
return tensor
5555

5656
@override
@@ -62,7 +62,7 @@ def all_reduce(self, tensor: Tensor, op: Union[str, ReduceOp, RedOpType] = "sum"
6262
@override
6363
def reduce(self, tensor: Tensor, dst: int, op: Union[str, ReduceOp, RedOpType] = "sum") -> Tensor:
6464
op = self._convert_to_native_op(op)
65-
dist.reduce(tensor, dst, op=op, group=self.group)
65+
dist.reduce(tensor, dst, op=op, group=self.group) # type: ignore[arg-type]
6666
return tensor
6767

6868
@override
@@ -72,12 +72,12 @@ def all_gather(self, tensor_list: list[Tensor], tensor: Tensor) -> list[Tensor]:
7272

7373
@override
7474
def gather(self, tensor: Tensor, gather_list: list[Tensor], dst: int = 0) -> list[Tensor]:
75-
dist.gather(tensor, gather_list, dst, group=self.group)
75+
dist.gather(tensor, gather_list, dst, group=self.group) # type: ignore[arg-type]
7676
return gather_list
7777

7878
@override
7979
def scatter(self, tensor: Tensor, scatter_list: list[Tensor], src: int = 0) -> Tensor:
80-
dist.scatter(tensor, scatter_list, src, group=self.group)
80+
dist.scatter(tensor, scatter_list, src, group=self.group) # type: ignore[arg-type]
8181
return tensor
8282

8383
@override
@@ -109,27 +109,27 @@ def all_gather_object(self, object_list: list[Any], obj: Any) -> list[Any]:
109109
def broadcast_object_list(
110110
self, object_list: list[Any], src: int, device: Optional[torch.device] = None
111111
) -> list[Any]:
112-
dist.broadcast_object_list(object_list, src, group=self.group, device=device)
112+
dist.broadcast_object_list(object_list, src, group=self.group, device=device) # type: ignore[arg-type]
113113
return object_list
114114

115115
def gather_object(self, obj: Any, object_gather_list: list[Any], dst: int = 0) -> list[Any]:
116-
dist.gather_object(obj, object_gather_list, dst, group=self.group)
116+
dist.gather_object(obj, object_gather_list, dst, group=self.group) # type: ignore[arg-type]
117117
return object_gather_list
118118

119119
def scatter_object_list(
120120
self, scatter_object_output_list: list[Any], scatter_object_input_list: list[Any], src: int = 0
121121
) -> list[Any]:
122-
dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group)
122+
dist.scatter_object_list(scatter_object_output_list, scatter_object_input_list, src, group=self.group) # type: ignore[arg-type]
123123
return scatter_object_output_list
124124

125125
@override
126126
def barrier(self, device_ids: Optional[list[int]] = None) -> None:
127127
if self.group == dist.GroupMember.NON_GROUP_MEMBER:
128128
return
129-
dist.barrier(group=self.group, device_ids=device_ids)
129+
dist.barrier(group=self.group, device_ids=device_ids) # type: ignore[arg-type]
130130

131131
def monitored_barrier(self, timeout: Optional[datetime.timedelta] = None, wait_all_ranks: bool = False) -> None:
132-
dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks)
132+
dist.monitored_barrier(group=self.group, timeout=timeout, wait_all_ranks=wait_all_ranks) # type: ignore[arg-type]
133133

134134
@override
135135
def setup(self, main_address: Optional[str] = None, main_port: Optional[str] = None, **kwargs: Any) -> Self:

0 commit comments

Comments
 (0)