
Commit 6019f49

fix thd cp convergence issue

Signed-off-by: Chen Cui <chcui@nvidia.com>
1 parent: fac2cae

2 files changed: +58 additions, -25 deletions


src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py

Lines changed: 47 additions & 20 deletions
@@ -161,28 +161,54 @@ def forward(
 
         # CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
         # This must happen AFTER vision-text merge so image token positions are correct
-        if self.config._pg_collection.cp.size() > 1:
-            # inputs_embeds is (T, B, D), need to transpose to (B, T, D) for get_batch_on_this_cp_rank
+        cp_size = self.config._pg_collection.cp.size()
+        if cp_size > 1:
+            cp_rank = self.config._pg_collection.cp.rank()
+
+            # (T, B, D) -> (B, T, D) for slicing
             if inputs_embeds is not None:
                 inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
 
-            cp_group = self.config._pg_collection.cp
-            cp_batch = get_batch_on_this_cp_rank(
-                {
-                    "decoder_input": inputs_embeds,
-                    "labels": labels,
-                    "loss_mask": loss_mask,
-                    "position_ids": position_ids,
-                    "attention_mask": attention_mask,
-                },
-                cp_group=cp_group,
-            )
-
-            inputs_embeds = cp_batch.get("decoder_input")
-            labels = cp_batch.get("labels")
-            loss_mask = cp_batch.get("loss_mask")
-            position_ids = cp_batch.get("position_ids")
-            attention_mask = cp_batch.get("attention_mask")
+            if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
+                import transformer_engine_torch as tex
+
+                cu_seqlens = packed_seq_params.cu_seqlens_q
+                cu_seqlens_padded = (
+                    packed_seq_params.cu_seqlens_q_padded
+                    if packed_seq_params.cu_seqlens_q_padded is not None
+                    else cu_seqlens
+                )
+                seq_len = inputs_embeds.size(1)
+
+                index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank)
+
+                # Slice all tensors using THD indices
+                if inputs_embeds is not None:
+                    inputs_embeds = inputs_embeds.index_select(1, index)
+                if labels is not None:
+                    labels = labels.index_select(1, index)
+                if loss_mask is not None:
+                    loss_mask = loss_mask.index_select(1, index)
+                if position_ids is not None:
+                    position_ids = position_ids.index_select(1, index)
+            else:
+                cp_group = self.config._pg_collection.cp
+                cp_batch = get_batch_on_this_cp_rank(
+                    {
+                        "decoder_input": inputs_embeds,
+                        "labels": labels,
+                        "loss_mask": loss_mask,
+                        "position_ids": position_ids,
+                        "attention_mask": attention_mask,
+                    },
+                    cp_group=cp_group,
+                )
+
+                inputs_embeds = cp_batch.get("decoder_input")
+                labels = cp_batch.get("labels")
+                loss_mask = cp_batch.get("loss_mask")
+                position_ids = cp_batch.get("position_ids")
+                attention_mask = cp_batch.get("attention_mask")
 
         # Transpose back to (T, B, D)
         if inputs_embeds is not None:
@@ -198,7 +224,8 @@ def forward(
             runtime_gather_output=runtime_gather_output,
             packed_seq_params=packed_seq_params,
         )
-        return outputs
+        # Return both outputs and the CP-sliced loss_mask for consistent loss computation
+        return (outputs, loss_mask)
 
     def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):
         """Freeze model modules.

src/megatron/bridge/training/vlm_step.py

Lines changed: 11 additions & 5 deletions
@@ -19,7 +19,7 @@
 import torch
 from megatron.core.models.gpt import GPTModel
 from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
-from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config
+from megatron.core.utils import get_model_config
 
 from megatron.bridge.training.config import ConfigContainer
 from megatron.bridge.training.losses import (
@@ -345,11 +345,10 @@ def _ceil_to_mult(n: int, mult: int) -> int:
         cu_seqlens = None
         max_seqlen = None
 
-    cp_batch = get_batch_on_this_cp_rank({"loss_mask": batch.get("loss_mask")}, cp_group=pg_collection.cp)
     return (
         (batch.get("tokens") if batch.get("tokens") is not None else batch.get("input_ids")),
         batch.get("labels"),
-        cp_batch.get("loss_mask"),
+        batch.get("loss_mask"),  # Full packed loss_mask, will be CP-sliced by model
         batch.get("attention_mask"),
         batch.get("position_ids"),
         cu_seqlens,
@@ -379,6 +378,7 @@ def forward_step(
     use_mtp = (getattr(config, "mtp_num_layers", None) or 0) > 0
 
     timers("batch-generator", log_level=2).start()
+    pg_collection = get_pg_collection(model)
     with straggler_timer(bdata=True):
         (
             tokens,
@@ -389,14 +389,15 @@ def forward_step(
             cu_seqlens,
             max_seqlen,
             visual_inputs,
-        ) = get_batch(data_iterator, state.cfg, use_mtp, pg_collection=get_pg_collection(model))
+        ) = get_batch(data_iterator, state.cfg, use_mtp, pg_collection=pg_collection)
         timers("batch-generator").stop()
 
     forward_args = {
         "input_ids": tokens,
         "position_ids": position_ids,
         "attention_mask": attention_mask,
         "labels": labels,
+        "loss_mask": loss_mask,  # Pass full loss_mask so model can slice it consistently with labels
     }
 
     if visual_inputs is not None:
@@ -423,7 +424,12 @@
         loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss)
         return schedule_plan, loss_function
     else:
-        output_tensor = model(**forward_args)
+        model_output = model(**forward_args)
+        # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CP
+        if isinstance(model_output, tuple):
+            output_tensor, loss_mask = model_output
+        else:
+            output_tensor = model_output
 
         loss_function = _create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss)
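Note (again, not repository code): a toy sketch of why the CP-sliced mask matters for convergence. The masked_token_avg_loss name and signature are made up; it only shows that a masked token-average loss needs per-token losses and the loss mask to be index-aligned. If loss_mask stayed the full packed mask while the per-token losses came from a CP-sliced forward pass, the mask would weight the wrong positions and the averaged loss and gradients would drift, which is why forward_step rebinds loss_mask from the model's tuple return before building the loss function.

# Toy illustration, not the repository's loss code.
import torch


def masked_token_avg_loss(token_losses: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # token_losses and loss_mask must be index-aligned per token on this CP rank.
    loss_mask = loss_mask.float().reshape(-1)
    return (token_losses.reshape(-1) * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)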
