@@ -663,14 +663,29 @@ def send_tensor_dict(
         tensor_dict: dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
@@ -699,14 +714,23 @@ def send_tensor_dict(
         # `send_object_list` has serialization & deserialization,
         # all happening on CPU. Therefore, we can use the CPU group.
         self.send_object(metadata_list, dst=dst)
-        for tensor in tensor_list:
+
+        tensor_keys = [
+            k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
+        ]
+        assert len(tensor_keys) == len(tensor_list)
+
+        for key, tensor in zip(tensor_keys, tensor_list):
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue
 
             # send-allgather: send only a slice, then do allgather.
-            if (all_gather_group is not None
-                    and tensor.numel() % all_gather_size == 0):
+            use_all_gather = (all_gather_group is not None
+                              and tensor.numel() % all_gather_size == 0)
+            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                if all_gather_tensors else use_all_gather
+            if use_all_gather:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
             if tensor.is_cpu:
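
For readability, the per-tensor decision the send path now makes can be written as a small standalone rule. The sketch below is illustrative only and not part of this patch: `should_use_all_gather` is a hypothetical helper, and the `all_gather_size > 1` check stands in for the `all_gather_group is not None` test used above.

from typing import Optional

import torch


def should_use_all_gather(
    key: str,
    tensor: torch.Tensor,
    all_gather_size: int,
    all_gather_tensors: Optional[dict[str, bool]] = None,
) -> bool:
    # Default: slice + all-gather whenever the element count divides evenly
    # across the group.
    use_all_gather = (all_gather_size > 1
                      and tensor.numel() % all_gather_size == 0)
    # Per-tensor override from the caller, e.g. {"residual": False} when
    # sequence parallelism leaves "residual" unreplicated across the group.
    if all_gather_tensors:
        use_all_gather = all_gather_tensors.get(key, use_all_gather)
    return use_all_gather


# With a group of size 4, a replicated 8x16 tensor takes the sliced path,
# but an override forces "residual" to be sent whole.
t = torch.zeros(8, 16)
assert should_use_all_gather("hidden_states", t, 4)
assert not should_use_all_gather("residual", t, 4, {"residual": False})
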
@@ -725,14 +749,29 @@ def recv_tensor_dict(
         self,
         src: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
@@ -766,6 +805,8 @@ def recv_tensor_dict(
                 # send-allgather: send only a slice, then do allgather.
                 use_all_gather = (all_gather_group is not None
                                   and tensor.numel() % all_gather_size == 0)
+                use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                    if all_gather_tensors else use_all_gather
 
                 if use_all_gather:
                     orig_shape = tensor.shape
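
From the caller's side, send and receive must agree on the per-tensor override so the receiver knows which tensors arrive as slices and which arrive whole. Below is a hypothetical usage sketch, not part of this diff; it assumes vLLM's get_pp_group()/get_tp_group() helpers and the "hidden_states"/"residual" key names used for pipeline-parallel intermediates.

from vllm.distributed import get_pp_group, get_tp_group  # assumed import path


def pp_send(hidden_states, residual, sequence_parallel: bool) -> None:
    # Under sequence parallelism the residual is not replicated across the
    # tensor-parallel group, so opt it out of the slice/all-gather path.
    overrides = {"residual": False} if sequence_parallel else None
    get_pp_group().send_tensor_dict(
        {"hidden_states": hidden_states, "residual": residual},
        all_gather_group=get_tp_group(),
        all_gather_tensors=overrides,
    )


def pp_recv(sequence_parallel: bool):
    # The receiver must pass the same overrides to reconstruct correctly.
    overrides = {"residual": False} if sequence_parallel else None
    return get_pp_group().recv_tensor_dict(
        all_gather_group=get_tp_group(),
        all_gather_tensors=overrides,
    )

The key point is symmetry: if the two sides disagree, the receiver would all-gather a tensor that was sent whole (or skip the all-gather for one that was sent as a slice) and reconstruct the wrong shape.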