@@ -233,6 +233,21 @@ def post_forward(self, module, output):
         return output[0] if is_tensor else tuple(output)
 
 
+class AllGatherFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, tensor, dim, group):
+        ctx.dim = dim
+        ctx.group = group
+        ctx.world_size = torch.distributed.get_world_size(group)
+        ctx.rank = torch.distributed.get_rank(group)
+        return funcol.all_gather_tensor(tensor, dim, group=group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+        return grad_chunks[ctx.rank], None, None
+
+
 class EquipartitionSharder:
     @classmethod
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
@@ -246,7 +261,7 @@ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_me
     @classmethod
     def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         tensor = tensor.contiguous()
-        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
+        tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
         return tensor
 
 