File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed
Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change 66from einops import rearrange
77from torch import Tensor
88from torch .distributed import ProcessGroup
9- from torch .distributed ._functional_collectives import all_gather_tensor , reduce_scatter_tensor
109
1110
1211class AllToAll (torch .autograd .Function ):
@@ -67,6 +66,8 @@ def forward(
6766 outputs: Tensor
6867 handle: Optional[Work], if overlap is True
6968 """
69+ from torch .distributed ._functional_collectives import all_gather_tensor
70+
7071 ctx .group = group
7172 ctx .sp_rank = sp_rank
7273 ctx .sp_size = sp_size
@@ -93,6 +94,8 @@ def forward(
9394
9495 @staticmethod
9596 def backward (ctx : Any , * grad_outputs ) -> Tuple [Tensor , None , None ]:
97+ from torch .distributed ._functional_collectives import reduce_scatter_tensor
98+
9699 group = ctx .group
97100 sp_rank = ctx .sp_rank
98101 sp_size = ctx .sp_size
You can't perform that action at this time.
0 commit comments