3 changes: 3 additions & 0 deletions src/megatron/bridge/data/vlm_datasets/hf_provider.py
@@ -67,6 +67,9 @@ class HFDatasetConversationProvider(DatasetProvider):
# DataloaderConfig fields are inherited (num_workers, dataloader_type, etc.)
dataloader_type: Optional[Literal["single", "cyclic", "external"]] = "single"

# Enable batch-level online sequence packing (dataset-level packing is available in FinetuneDatasetProvider)
pack_sequences_in_batch: bool = False

def _get_maker(self) -> Callable[..., List[Dict[str, Any]]]:
registry: Dict[str, Callable[..., List[Dict[str, Any]]]] = {
"make_rdr_dataset": make_rdr_dataset,
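A hedged usage sketch of the new flag (the extra constructor arguments shown here are assumptions for illustration, not part of this change):

# Illustrative only: enable batch-level online packing on the conversation provider.
# `pack_sequences_in_batch` is the field added above; the remaining arguments are assumed.
from megatron.bridge.data.vlm_datasets.hf_provider import HFDatasetConversationProvider

provider = HFDatasetConversationProvider(
    dataloader_type="single",       # inherited DataloaderConfig field
    pack_sequences_in_batch=True,   # pack sequences online, per batch
)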
31 changes: 27 additions & 4 deletions src/megatron/bridge/models/gemma/gemma3_provider.py
@@ -367,11 +367,34 @@ def __init__(
**kwargs,
)

def forward(
self,
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
cp_group: torch.distributed.ProcessGroup | None = None,
) -> Tensor:
"""Get global and local rope embedding.

Note: Caching is bypassed when cp_group is provided since ProcessGroup is unhashable.
"""
# ProcessGroup is unhashable, so bypass caching when cp_group is provided
if cp_group is not None:
rope_global = super().forward(max_seq_len, offset, packed_seq, cp_group)
rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, cp_group)
return rope_local, rope_global
return self._forward_cached(max_seq_len, offset, packed_seq)

@lru_cache(maxsize=32)
def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
"""Get global and local rope embedding"""
rope_global = super().forward(max_seq_len, offset, packed_seq)
rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
def _forward_cached(
self,
max_seq_len: int,
offset: int = 0,
packed_seq: bool = False,
) -> Tensor:
"""Cached forward for hashable parameters only."""
rope_global = super().forward(max_seq_len, offset, packed_seq, None)
rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, None)
return rope_local, rope_global


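As a minimal standalone sketch of the caching pattern above (illustrative names, not the real Gemma3 classes): `functools.lru_cache` only works with hashable arguments, so calls that carry an unhashable `ProcessGroup` take an uncached path, while calls without a group go through a cached helper.

from functools import lru_cache


class DualRopeSketch:
    """Illustrative stand-in for the dual (local/global) RoPE wrapper."""

    def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False, cp_group=None):
        if cp_group is not None:
            # ProcessGroup is unhashable, so it cannot be part of an lru_cache key.
            return self._compute(max_seq_len, offset, packed_seq, cp_group)
        return self._forward_cached(max_seq_len, offset, packed_seq)

    @lru_cache(maxsize=32)
    def _forward_cached(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False):
        # All arguments (and self) are hashable here, so caching is safe.
        return self._compute(max_seq_len, offset, packed_seq, None)

    def _compute(self, max_seq_len, offset, packed_seq, cp_group):
        # Placeholder for the actual (rope_local, rope_global) computation.
        return ("rope_local", "rope_global")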
54 changes: 43 additions & 11 deletions src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py
@@ -14,22 +14,30 @@

import types
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
from megatron.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import MegatronModule
from torch import Tensor
from transformers import AutoModel, Gemma3Model

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
from megatron.bridge.utils.common_utils import (
hook_hf_module_setattr_for_tp_grad_sync,
slice_batch_for_context_parallel,
)
from megatron.bridge.utils.import_utils import safe_import_from


if TYPE_CHECKING:
from megatron.core.packed_seq_params import PackedSeqParams


TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm")


@@ -110,12 +118,16 @@ def forward(
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
) -> tuple[Tensor, Tensor | None]:
r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
Forward pass combining HuggingFace vision encoder with Megatron language model.

Returns:
tuple: (output_tensor, loss_mask) where output_tensor contains model output
and loss_mask is the CP-sliced mask for consistent loss computation.
"""
if self.pre_process:
if inputs_embeds is None:
@@ -134,7 +146,7 @@ def forward(
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

if inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = special_image_mask.sum(dim=1).item(dim=0)[0]
image_tokens_in_text = special_image_mask[:, :, 0].sum().item()
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
@@ -144,18 +156,38 @@ def forward(
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D)

# Apply sequence parallelism scatter if enabled
if self.config.sequence_parallel:
inputs_embeds = scatter_to_sequence_parallel_region(inputs_embeds)

# Compute attention mask on FULL sequence (before CP slicing)
# This is needed because image regions need bidirectional attention
attention_mask = self._compute_attention_mask(input_ids)

# CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
# This must happen AFTER vision-text merge so image token positions are correct
inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel(
inputs_embeds=inputs_embeds,
labels=labels,
loss_mask=loss_mask,
position_ids=position_ids,
attention_mask=attention_mask,
packed_seq_params=packed_seq_params,
pg_collection=self.config._pg_collection,
)

outputs = self.language_model.forward(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask, # (B, 1, T, T)
decoder_input=inputs_embeds, # (T, B, D)
labels=labels, # (B, T)
attention_mask=attention_mask,
decoder_input=inputs_embeds,
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs
# Return both outputs and the CP-sliced loss_mask for consistent loss computation
return (outputs, loss_mask)

def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool):
"""Freeze model modules.
Expand Down Expand Up @@ -191,7 +223,7 @@ def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_
def _compute_attention_mask(
self,
input_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Optional[torch.Tensor]:
if not self.pre_process:
return None
batch_size, seq_len = input_ids.shape
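For intuition about the CP slicing step, here is a hedged sketch of the load-balanced scheme Megatron-style context parallelism typically uses (an illustration of the idea only, not the body of `slice_batch_for_context_parallel`): the sequence dimension is cut into `2 * cp_size` chunks and rank `r` keeps chunks `r` and `2*cp_size - 1 - r`, so every rank gets a comparable share of the causal-attention work. The important ordering point in the diff is that this happens after the vision features are scattered into the text embeddings, so image token positions stay consistent across CP ranks.

import torch


def slice_seq_dim_for_cp(t: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 0) -> torch.Tensor:
    """Illustrative load-balanced CP slice; the real helper may differ in details."""
    if cp_size == 1:
        return t
    chunks = t.chunk(2 * cp_size, dim=seq_dim)
    # Pair the i-th chunk with its mirror so early and late positions are balanced.
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim)


# Example: a [seq_len=8, batch=1, hidden=4] embedding split across cp_size=2 ranks.
emb = torch.arange(8).view(8, 1, 1).expand(8, 1, 4)
print(slice_seq_dim_for_cp(emb, cp_size=2, cp_rank=0, seq_dim=0)[:, 0, 0])  # tensor([0, 1, 6, 7])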
8 changes: 7 additions & 1 deletion src/megatron/bridge/models/glm_vl/modeling_glm_45v.py
@@ -23,7 +23,7 @@
"""

import types
from typing import Optional
from typing import TYPE_CHECKING, Optional

import torch
import transformers
@@ -37,6 +37,10 @@
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync


if TYPE_CHECKING:
from megatron.core.packed_seq_params import PackedSeqParams


def is_transformers_min_version(version):
"""Check if minimum version of transformers is installed."""
try:
@@ -158,6 +162,7 @@ def forward(
video_grid_thw: Optional[torch.LongTensor] = None,
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
@@ -233,6 +238,7 @@ def forward(
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs

13 changes: 10 additions & 3 deletions src/megatron/bridge/models/ministral3/ministral3_provider.py
@@ -261,10 +261,15 @@ def __init__(
self.max_position_embeddings = self.config.seq_length

def _get_llama_4_attn_scale(
self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int
self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
) -> torch.Tensor:
scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
return scaling.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
# Add dimensions to match query shape: [seq_len] -> [seq_len, 1, 1] for packed or [seq_len, 1, 1, 1] for unpacked
# Query can be either [seq_len, num_heads, head_dim] (packed) or [seq_len, batch, num_heads, head_dim] (unpacked)
num_dims_to_add = len(query_shape) - 1
for _ in range(num_dims_to_add):
scaling = scaling.unsqueeze(-1)
return scaling

def forward(
self,
@@ -276,6 +281,8 @@ def forward(
**kwargs,
):
positions_ids = torch.arange(query.shape[0], device=query.device)
query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings).to(query.dtype)
query *= self._get_llama_4_attn_scale(positions_ids, self.beta, self.max_position_embeddings, query.shape).to(
query.dtype
)

return super().forward(query, key, value, attention_mask, attn_mask_type, **kwargs)
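A small self-contained check of the broadcasting change above: the per-position scale starts as a `[seq_len]` vector and gains one trailing singleton dimension per remaining query dimension, so the same code covers packed `[seq_len, heads, dim]` and unpacked `[seq_len, batch, heads, dim]` layouts (shapes below are made up for illustration).

import torch


def llama4_attn_scale(position_ids: torch.Tensor, beta: float, max_pos: int, query_shape: tuple) -> torch.Tensor:
    # Same reshape idea as the diff: broadcast [seq_len] against the rest of the query shape.
    scaling = 1 + beta * torch.log(1 + torch.floor(position_ids / max_pos))
    for _ in range(len(query_shape) - 1):
        scaling = scaling.unsqueeze(-1)
    return scaling


pos = torch.arange(8)
print(llama4_attn_scale(pos, 0.1, 4, (8, 16, 64)).shape)     # packed:   torch.Size([8, 1, 1])
print(llama4_attn_scale(pos, 0.1, 4, (8, 2, 16, 64)).shape)  # unpacked: torch.Size([8, 1, 1, 1])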
33 changes: 28 additions & 5 deletions src/megatron/bridge/models/ministral3/modeling_ministral3.py
@@ -24,14 +24,21 @@
"""

import types
from typing import Optional
from typing import TYPE_CHECKING, Optional

import torch
from megatron.core.transformer.module import MegatronModule
from torch import Tensor

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync
from megatron.bridge.utils.common_utils import (
hook_hf_module_setattr_for_tp_grad_sync,
slice_batch_for_context_parallel,
)


if TYPE_CHECKING:
from megatron.core.packed_seq_params import PackedSeqParams


# Import HuggingFace Mistral3 model classes with fallback
@@ -185,9 +192,10 @@ def forward(
labels: Optional[torch.Tensor] = None,
runtime_gather_output: Optional[bool] = None,
image_sizes: Optional[torch.Tensor] = None,
packed_seq_params: Optional["PackedSeqParams"] = None,
*,
loss_mask: Optional[Tensor] = None,
) -> Tensor:
) -> tuple[Tensor, Tensor | None]:
"""
Forward pass combining HuggingFace vision encoder with Megatron language model.

@@ -202,7 +210,8 @@
loss_mask: Mask for loss computation.

Returns:
Model output (logits or loss depending on mode).
tuple: (output_tensor, loss_mask) where output_tensor contains model output
and loss_mask is the CP-sliced mask for consistent loss computation.
"""
if self.pre_process:
if inputs_embeds is None:
@@ -237,6 +246,18 @@
# Transpose back to Megatron format [seq_len, batch, hidden]
inputs_embeds = inputs_embeds.transpose(1, 0).contiguous()

# CP slicing: slice embeddings, labels, loss_mask, position_ids, and attention_mask
# This must happen AFTER vision-text merge so image token positions are correct
inputs_embeds, labels, loss_mask, position_ids, attention_mask = slice_batch_for_context_parallel(
inputs_embeds=inputs_embeds,
labels=labels,
loss_mask=loss_mask,
position_ids=position_ids,
attention_mask=attention_mask,
packed_seq_params=packed_seq_params,
pg_collection=self.config._pg_collection,
)

# Forward through Megatron language model
outputs = self.language_model.forward(
input_ids=None,
@@ -246,8 +267,10 @@
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs
# Return both outputs and the CP-sliced loss_mask for consistent loss computation
return (outputs, loss_mask)

def freeze(
self,
1 change: 1 addition & 0 deletions src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py
@@ -213,6 +213,7 @@ def forward(
labels=labels,
loss_mask=loss_mask,
runtime_gather_output=runtime_gather_output,
packed_seq_params=packed_seq_params,
)
return outputs

2 changes: 2 additions & 0 deletions src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py
@@ -234,6 +234,8 @@ def _gemma3_vl_common(
model_cfg.freeze_vision_model = freeze_vision_model
model_cfg.freeze_vision_projection = freeze_vision_projection
model_cfg.seq_length = seq_length
if model_cfg.context_parallel_size > 1:
model_cfg.cp_comm_type = "a2a"

# Optimizer and scheduler - use finetune_lr if provided, otherwise use lr
effective_lr = finetune_lr if finetune_lr is not None else lr
12 changes: 8 additions & 4 deletions src/megatron/bridge/training/utils/packed_seq_utils.py
@@ -44,14 +44,18 @@ def get_packed_seq_params(batch: dict[str, torch.Tensor]) -> PackedSeqParams:
cu_seqlens_unpadded_argmin = batch.get("cu_seqlens_unpadded_argmin")

if cu_seqlens_argmin is not None:
cu_seqlens_padded = cu_seqlens_padded[: cu_seqlens_argmin.item()]
else:
argmin_idx = cu_seqlens_argmin.item()
assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1 # cu_seqlens padding is -1
cu_seqlens_padded = cu_seqlens_padded[:argmin_idx]
elif torch.min(cu_seqlens_padded) == -1:
cu_seqlens_padded = cu_seqlens_padded[: torch.argmin(cu_seqlens_padded)]

if cu_seqlens_unpadded is not None:
if cu_seqlens_unpadded_argmin is not None:
cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()]
else:
argmin_idx = cu_seqlens_unpadded_argmin.item()
assert argmin_idx == 0 or cu_seqlens_unpadded[argmin_idx] == -1 # cu_seqlens padding is -1
cu_seqlens_unpadded = cu_seqlens_unpadded[:argmin_idx]
elif torch.min(cu_seqlens_unpadded) == -1:
cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)]

max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None
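A tiny worked example of the trimming above, with made-up numbers: packed batches pad `cu_seqlens` with `-1` entries, and either the precomputed `*_argmin` index or a runtime `argmin` marks where the real prefix ends.

import torch

# Hypothetical padded cumulative sequence lengths for three packed sequences.
cu_seqlens_padded = torch.tensor([0, 128, 256, 320, -1, -1, -1])
cu_seqlens_argmin = torch.argmin(cu_seqlens_padded)  # index 4 here, the first -1 entry

argmin_idx = cu_seqlens_argmin.item()
assert argmin_idx == 0 or cu_seqlens_padded[argmin_idx] == -1  # cu_seqlens padding is -1
print(cu_seqlens_padded[:argmin_idx])  # tensor([  0, 128, 256, 320])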