Skip to content

Commit 40096bf

Browse files
committed
Revert "cast(super().group, ProcessGroup)"
This reverts commit 432b0ca.
1 parent 432b0ca commit 40096bf

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/lightning/fabric/plugins/collectives/torch_collective.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import datetime
22
import os
3-
from typing import Any, Optional, Union, cast
3+
from typing import Any, Optional, Union
44

55
import torch
66
import torch.distributed as dist
77
from torch import Tensor
88
from typing_extensions import Self, override
99

1010
from lightning.fabric.plugins.collectives.collective import Collective
11-
from lightning.fabric.utilities.types import CollectibleGroup, RedOpType, ReduceOp, ProcessGroup
11+
from lightning.fabric.utilities.types import CollectibleGroup, RedOpType, ReduceOp
1212

1313
if dist.is_available():
1414
from torch.distributed.constants import default_pg_timeout
@@ -32,10 +32,10 @@ def __init__(self) -> None:
3232

3333
@property
3434
@override
35-
def group(self) -> ProcessGroup:
35+
def group(self) -> CollectibleGroup:
3636
if self._group is None:
3737
self._group = dist.GroupMember.WORLD
38-
return cast(super().group, ProcessGroup)
38+
return super().group
3939

4040
@property
4141
@override

0 commit comments

Comments
 (0)