@@ -86,12 +86,12 @@ def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
8686 Works for both logps (float) and masks (bool/int).
8787
8888 Args:
89- tensor: [1, packed_len/cp_size] - CP-split tensor (any dtype)
90- packed_seq_params: PackedSeqParams object
89+ tensor: [1, packed_len/cp_size] in padding_free mode, or [batch_size, seq_len/cp_size] otherwise
90+ packed_seq_params: PackedSeqParams object (None in non-padding_free mode)
9191 num_samples: Number of samples in the batch
9292
9393 Returns:
94- output_full: [1, packed_len] - Full sequence tensor
94+ output_full: [1, packed_len] in padding_free mode, or [batch_size, seq_len] otherwise
9595 """
9696 args = get_args ()
9797 cp_size = args .context_parallel_size
@@ -102,36 +102,61 @@ def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
102102 torch .distributed .all_gather (output_list , tensor .contiguous (), group = mpu .get_context_parallel_group ())
103103 output_list [cp_rank ] = tensor
104104
105- # Reconstruct full sequence
106- # Shape : [1, packed_len/cp_size] -> [1, packed_len]
107- cu_seqlens_full = packed_seq_params .cu_seqlens_q
108- cu_seqlens_cp = cu_seqlens_full // cp_size
105+ if packed_seq_params is not None :
106+            # padding_free mode: [1, packed_len/cp_size] -> [1, packed_len]
107+ cu_seqlens_full = packed_seq_params .cu_seqlens_q
108+ cu_seqlens_cp = cu_seqlens_full // cp_size
109109
110- # Calculate total packed length
111- total_packed_len = cu_seqlens_full [num_samples ].item ()
112- output_full = tensor .new_zeros (1 , total_packed_len )
110+ # Calculate total packed length
111+ total_packed_len = cu_seqlens_full [num_samples ].item ()
112+ output_full = tensor .new_zeros (1 , total_packed_len )
113113
114- # Reconstruct each sequence
115- for i in range (num_samples ):
116- start_full = cu_seqlens_full [i ].item ()
117- end_full = cu_seqlens_full [i + 1 ].item ()
118- seq_len = end_full - start_full
114+ # Reconstruct each sequence
115+ for i in range (num_samples ):
116+ start_full = cu_seqlens_full [i ].item ()
117+ end_full = cu_seqlens_full [i + 1 ].item ()
118+ seq_len = end_full - start_full
119+
120+ # Length of each chunk after CP split
121+ chunk_len = seq_len // cp_size
122+ half_chunk = chunk_len // 2
123+
124+ # Concatenate from each CP rank's output (load-balanced split)
125+ for j in range (cp_size ):
126+ o = output_list [j ][0 ]
127+ start_cp = cu_seqlens_cp [i ].item ()
128+
129+ # Get two half chunks (CP's load-balanced split)
130+ o0 = o [start_cp :start_cp + half_chunk ]
131+ o1 = o [start_cp + half_chunk :start_cp + chunk_len ]
132+
133+ # Place back to full sequence
134+ output_full [0 , start_full + j * half_chunk :start_full + (j + 1 ) * half_chunk ] = o0
135+ output_full [0 , end_full - (j + 1 ) * half_chunk :end_full - j * half_chunk ] = o1
136+ else :
137+ # non-padding_free mode: [batch_size, seq_len/cp_size] -> [batch_size, seq_len]
138+ # Each CP rank has chunks split with load-balanced pattern (2*cp_size chunks)
139+ batch_size = tensor .shape [0 ]
140+ seq_len_per_cp = tensor .shape [1 ]
141+ full_seq_len = seq_len_per_cp * cp_size
119142
120- # Length of each chunk after CP split
121- chunk_len = seq_len // cp_size
122- half_chunk = chunk_len // 2
143+ output_full = tensor .new_zeros (batch_size , full_seq_len )
123144
124- # Concatenate from each CP rank's output (load-balanced split)
125- for j in range (cp_size ):
126- o = output_list [j ][0 ]
127- start_cp = cu_seqlens_cp [i ].item ()
145+ # Each CP rank j holds chunks j and (2*cp_size - j - 1) from the original 2*cp_size split
146+ # Reconstruct the full sequence by placing chunks back in correct positions
147+ chunk_len = full_seq_len // (2 * cp_size )
128148
129- # Get two half chunks (CP's load-balanced split)
130- o0 = o [start_cp :start_cp + half_chunk ]
131- o1 = o [start_cp + half_chunk :start_cp + chunk_len ]
132-
133- # Place back to full sequence
134- output_full [0 , start_full + j * half_chunk :start_full + (j + 1 ) * half_chunk ] = o0
135- output_full [0 , end_full - (j + 1 ) * half_chunk :end_full - j * half_chunk ] = o1
149+ for j in range (cp_size ):
150+ o = output_list [j ] # [batch_size, seq_len_per_cp]
151+ # This rank holds 2 chunks: chunk j and chunk (2*cp_size - j - 1)
152+ half_len = seq_len_per_cp // 2
153+ o0 = o [:, :half_len ] # First half -> chunk j
154+ o1 = o [:, half_len :] # Second half -> chunk (2*cp_size - j - 1)
155+
156+ # Place chunk j at position j * chunk_len
157+ output_full [:, j * chunk_len :(j + 1 ) * chunk_len ] = o0
158+ # Place chunk (2*cp_size - j - 1) at position (2*cp_size - j - 1) * chunk_len
159+ reverse_idx = 2 * cp_size - j - 1
160+ output_full [:, reverse_idx * chunk_len :(reverse_idx + 1 ) * chunk_len ] = o1
136161
137162 return output_full