@@ -161,28 +161,54 @@ def forward(
 
         # CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
         # This must happen AFTER vision-text merge so image token positions are correct
-        if self.config._pg_collection.cp.size() > 1:
-            # inputs_embeds is (T, B, D), need to transpose to (B, T, D) for get_batch_on_this_cp_rank
+        cp_size = self.config._pg_collection.cp.size()
+        if cp_size > 1:
+            cp_rank = self.config._pg_collection.cp.rank()
+
+            # (T, B, D) -> (B, T, D) for slicing
             if inputs_embeds is not None:
                 inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
 
-            cp_group = self.config._pg_collection.cp
-            cp_batch = get_batch_on_this_cp_rank(
-                {
-                    "decoder_input": inputs_embeds,
-                    "labels": labels,
-                    "loss_mask": loss_mask,
-                    "position_ids": position_ids,
-                    "attention_mask": attention_mask,
-                },
-                cp_group=cp_group,
-            )
-
-            inputs_embeds = cp_batch.get("decoder_input")
-            labels = cp_batch.get("labels")
-            loss_mask = cp_batch.get("loss_mask")
-            position_ids = cp_batch.get("position_ids")
-            attention_mask = cp_batch.get("attention_mask")
+            if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
+                import transformer_engine_torch as tex
+
+                cu_seqlens = packed_seq_params.cu_seqlens_q
+                cu_seqlens_padded = (
+                    packed_seq_params.cu_seqlens_q_padded
+                    if packed_seq_params.cu_seqlens_q_padded is not None
+                    else cu_seqlens
+                )
+                seq_len = inputs_embeds.size(1)
+
+                index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank)
+
+                # Slice all tensors using THD indices
+                if inputs_embeds is not None:
+                    inputs_embeds = inputs_embeds.index_select(1, index)
+                if labels is not None:
+                    labels = labels.index_select(1, index)
+                if loss_mask is not None:
+                    loss_mask = loss_mask.index_select(1, index)
+                if position_ids is not None:
+                    position_ids = position_ids.index_select(1, index)
+            else:
+                cp_group = self.config._pg_collection.cp
+                cp_batch = get_batch_on_this_cp_rank(
+                    {
+                        "decoder_input": inputs_embeds,
+                        "labels": labels,
+                        "loss_mask": loss_mask,
+                        "position_ids": position_ids,
+                        "attention_mask": attention_mask,
+                    },
+                    cp_group=cp_group,
+                )
+
+                inputs_embeds = cp_batch.get("decoder_input")
+                labels = cp_batch.get("labels")
+                loss_mask = cp_batch.get("loss_mask")
+                position_ids = cp_batch.get("position_ids")
+                attention_mask = cp_batch.get("attention_mask")
 
         # Transpose back to (T, B, D)
         if inputs_embeds is not None:
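For reference, the THD branch above delegates token selection to Transformer Engine's thd_get_partitioned_indices. Below is a minimal pure-Python sketch of the load-balanced split that call is expected to produce; thd_partition_indices_sketch is a hypothetical stand-in, and it assumes each padded sequence length is divisible by 2 * cp_size.

import torch

def thd_partition_indices_sketch(cu_seqlens_padded, seq_len, cp_size, cp_rank):
    """Illustrative stand-in for tex.thd_get_partitioned_indices.

    Each packed sequence is split into 2 * cp_size chunks; rank r keeps chunks
    r and (2 * cp_size - 1 - r) of every sequence.
    """
    starts = cu_seqlens_padded[:-1].tolist()
    ends = cu_seqlens_padded[1:].tolist()
    picked = []
    for start, end in zip(starts, ends):
        end = min(end, seq_len)
        chunk = (end - start) // (2 * cp_size)
        picked.append(torch.arange(start + cp_rank * chunk, start + (cp_rank + 1) * chunk))
        picked.append(torch.arange(start + (2 * cp_size - 1 - cp_rank) * chunk,
                                   start + (2 * cp_size - cp_rank) * chunk))
    return torch.cat(picked)

# Two packed sequences of 8 tokens each, cp_size = 2:
cu = torch.tensor([0, 8, 16])
print(thd_partition_indices_sketch(cu, 16, cp_size=2, cp_rank=0))  # tokens 0, 1, 6, 7, 8, 9, 14, 15
print(thd_partition_indices_sketch(cu, 16, cp_size=2, cp_rank=1))  # tokens 2, 3, 4, 5, 10, 11, 12, 13

Pairing chunk r with chunk 2 * cp_size - 1 - r gives every rank a mix of early (cheap) and late (expensive) tokens under a causal mask, which is the usual motivation for this layout.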
@@ -198,7 +224,8 @@ def forward(
             runtime_gather_output=runtime_gather_output,
             packed_seq_params=packed_seq_params,
         )
-        return outputs
+        # Return both outputs and the CP-sliced loss_mask for consistent loss computation
+        return (outputs, loss_mask)
 
     def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):
         """Freeze model modules.
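Since forward now returns a tuple, callers have to unpack the CP-sliced loss_mask and use it when reducing the loss. A hypothetical caller-side sketch with placeholder tensors, assuming the wrapped language model returns per-token losses when labels are supplied (as Megatron-style models typically do):

import torch

# Placeholder stand-ins for what `outputs, loss_mask = model(**batch)` would return
# on one context-parallel rank: per-token losses and the matching sliced mask.
per_token_loss = torch.tensor([[1.5, 0.7, 2.1, 0.0]])  # (batch, local_seq)
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])       # padded/ignored tokens zeroed

# Reduce with the sliced mask so this rank counts only its own tokens; numerator and
# denominator can be summed across CP ranks before the final division if needed.
loss = (per_token_loss.float() * loss_mask).sum() / loss_mask.sum().clamp(min=1)
print(loss)  # tensor(1.4333)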