Skip to content

Commit 4e8beca

Browse files
authored
move async comm import into function (#73)
1 parent 06e0381 commit 4e8beca

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

opendit/utils/operation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from einops import rearrange
77
from torch import Tensor
88
from torch.distributed import ProcessGroup
9-
from torch.distributed._functional_collectives import all_gather_tensor, reduce_scatter_tensor
109

1110

1211
class AllToAll(torch.autograd.Function):
@@ -67,6 +66,8 @@ def forward(
6766
outputs: Tensor
6867
handle: Optional[Work], if overlap is True
6968
"""
69+
from torch.distributed._functional_collectives import all_gather_tensor
70+
7071
ctx.group = group
7172
ctx.sp_rank = sp_rank
7273
ctx.sp_size = sp_size
@@ -93,6 +94,8 @@ def forward(
9394

9495
@staticmethod
9596
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
97+
from torch.distributed._functional_collectives import reduce_scatter_tensor
98+
9699
group = ctx.group
97100
sp_rank = ctx.sp_rank
98101
sp_size = ctx.sp_size

0 commit comments

Comments
 (0)