@cuichenx cuichenx commented Jan 20, 2026

What does this PR do ?

Support THD Training in VLMs

TODOs

Changelog

  • Add specific line by line info of high level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

Gemma 3 VL:
[loss and grad norm comparison plot]
Loss and grad norm are both matching closely

Ministral 3:
[loss and grad norm comparison plot]
Not sure why the grad norm differs between cp1 and cp2; it might be a display issue and will be resolved in the next PR. Loss matches very closely.

More plots to come

  • Related to # (issue)

Summary by CodeRabbit

  • New Features

    • Added batch-level sequence packing support for optimized dataset processing
    • Introduced context-parallel distributed training support for Gemma3, Ministral3, GLM, and Qwen vision-language models
  • Refactor

    • Updated model forward signatures to support packed sequence parameters across vision-language models
  • Tests

    • Added comprehensive test coverage for sequence packing utilities, distributed training configurations, and attention scaling algorithms


Signed-off-by: Chen Cui <[email protected]>
@cuichenx cuichenx changed the title from "[Draft] THD training in VLMs" to "THD training in VLMs" on Jan 29, 2026

coderabbitai bot commented Jan 29, 2026

📝 Walkthrough

This PR introduces batch-level sequence packing and context-parallel slicing support for vision-language models. Changes include a new pack_sequences_in_batch configuration flag, packed_seq_params parameter propagation through VLM forward methods, context-parallel group support in rope embeddings, dynamic attention scaling for packed sequences, and comprehensive utilities for CP-aware batch slicing and sequence packing logic.
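
For readers unfamiliar with THD-style packing, the toy sketch below illustrates the basic idea: the non-padding tokens of each sample in a batch are concatenated into one long sequence, and cu_seqlens (cumulative sequence lengths) records where each sample starts and ends. This is only an illustration; the helper name and right-padding convention are assumptions, not the repository's `pack_batch_sequences` implementation.

```python
# Illustrative toy packing helper (not the repo's pack_batch_sequences).
# Assumes right-padded batches where pad_token_id marks padding.
import torch


def toy_pack(tokens: torch.Tensor, pad_token_id: int = 0):
    packed, position_ids, cu = [], [], [0]
    for row in tokens:
        valid = row[row != pad_token_id]                   # keep only real tokens
        packed.append(valid)
        position_ids.append(torch.arange(valid.numel()))   # positions restart per sample
        cu.append(cu[-1] + valid.numel())                  # cumulative sequence lengths
    return torch.cat(packed), torch.cat(position_ids), torch.tensor(cu, dtype=torch.int32)


tokens = torch.tensor([[5, 6, 7, 0, 0],
                       [8, 9, 0, 0, 0]])
packed, pos, cu_seqlens = toy_pack(tokens)
print(packed.tolist())      # [5, 6, 7, 8, 9]
print(cu_seqlens.tolist())  # [0, 3, 5]
```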

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Dataset Configuration**<br>`src/megatron/bridge/data/vlm_datasets/hf_provider.py` | Added `pack_sequences_in_batch` boolean field to `HFDatasetConversationProvider` for configurable batch-level sequence packing. |
| **Rope Embeddings & Attention**<br>`src/megatron/bridge/models/gemma/gemma3_provider.py`, `src/megatron/bridge/models/ministral3/ministral3_provider.py` | Added `cp_group` parameter to the `Gemma3RotaryEmbedding` forward method; introduced dynamic dimensional alignment for Llama 4 attention scaling based on query tensor shape to support both packed and unpacked formats. |
| **VLM Model Signature Updates**<br>`src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py`, `src/megatron/bridge/models/glm_vl/modeling_glm_45v.py`, `src/megatron/bridge/models/qwen_vl/modeling_qwen25_vl.py`, `src/megatron/bridge/models/ministral3/modeling_ministral3.py` | Propagated the `packed_seq_params` parameter through VLM forward methods to language model calls; Gemma3VL and Ministral3 additionally implement CP-aware batch slicing and return tuples of (outputs, loss_mask). |
| **CP Configuration**<br>`src/megatron/bridge/recipes/gemma3_vl/gemma3_vl.py` | Set `cp_comm_type` to "a2a" when context-parallel size exceeds 1. |
| **Training Utilities**<br>`src/megatron/bridge/training/vlm_step.py`, `src/megatron/bridge/training/utils/packed_seq_utils.py` | Added `pack_batch_sequences` function for variable-length sequence packing with cumulative length tracking; enhanced `get_packed_seq_params` with runtime assertions and explicit argmin-based slicing logic; integrated packing into the `get_batch` workflow with CP-aware padding. |
| **Data Utilities**<br>`src/megatron/bridge/utils/common_utils.py`, `src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py` | Added `slice_batch_for_context_parallel` utility to partition batch tensors for CP using TE indices or Megatron's batch slicing; extended QwenVL task dataclasses with `__restore_key__` and `__subflavor__` fields and a `@stateless` decorator for sample encoding. |
| **Tests**<br>`tests/unit_tests/models/ministral3/test_ministral3_provider.py`, `tests/unit_tests/training/test_gpt_step.py`, `tests/unit_tests/training/test_vlm_step.py` | Added comprehensive test coverage for dynamic attention scaling, `cu_seqlens` padding boundary validation, and sequence packing behavior, including dtype preservation, `position_ids` handling, and both CPU and GPU paths. |
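
As a companion to the table above, the hedged sketch below shows roughly how such cu_seqlens values feed into Megatron Core's `PackedSeqParams` for THD-format attention. The field names match the tests referenced in this PR, but the construction is simplified; the real `get_packed_seq_params` adds padding, validation, and CP-aware handling of the `*_padded` fields that is not shown here.

```python
# Sketch only: packaging cu_seqlens into PackedSeqParams with qkv_format="thd".
import torch
from megatron.core.packed_seq_params import PackedSeqParams

cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)      # from the packing step
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())   # longest packed sample

packed_seq_params = PackedSeqParams(
    qkv_format="thd",                # total-tokens / heads / head-dim layout
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    cu_seqlens_q_padded=cu_seqlens,  # would differ once CP padding is applied
    cu_seqlens_kv_padded=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_kv=max_seqlen,
)
```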

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • #2011: Adds ProcessGroupCollection and CP-group threading through training and VLM code paths, directly complementing this PR's CP slicing and parameter propagation.

Suggested reviewers

  • yaoyu-33
  • meatybobby
🚥 Pre-merge checks

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title '[Draft] THD training in VLMs' is specific and accurately describes the main change: adding THD (packed-sequence token/head/dim format) training support for Vision Language Models. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 86.44%, which is sufficient. The required threshold is 80.00%. |
| Test Results For Major Changes | ✅ Passed | The PR introduces THD training support in VLMs, with comprehensive testing and convergence validation documented for the primary implementations. |




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 10

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py`:
- Around line 121-124: The method whose signature includes "packed_seq_params:
Optional['PackedSeqParams'] = None, *, loss_mask: Optional[Tensor] = None"
returns a tuple (outputs, loss_mask) but is annotated "-> Tensor"; update the
return type to reflect the actual return (for example "-> Tuple[Tensor,
Optional[Tensor]]"), import Tuple from typing as needed, and update the method
docstring to describe the tuple return (outputs, loss_mask); apply the same
change to the other overloaded/duplicate signature in the file that has the same
parameters.

In `@src/megatron/bridge/models/gemma/gemma3_provider.py`:
- Around line 370-381: The forward method is using `@lru_cache` but accepts an
unhashable torch.distributed.ProcessGroup and uses Optional[...] typing; change
the signature to use PEP 604 union (torch.distributed.ProcessGroup | None) and
avoid passing the ProcessGroup into the cached path: implement a small uncached
branch in forward that, when cp_group is not None, directly calls
super().forward(...) and self.rope_local.forward(...) and returns rope_local,
rope_global; otherwise use the cached logic (keep `@lru_cache` on a helper that
only accepts hashable args like max_seq_len, offset, packed_seq) so caching is
only used for hashable parameters. Ensure you reference the existing forward
method and rope_local/rope_global calls when making the change.

In `@src/megatron/bridge/models/ministral3/modeling_ministral3.py`:
- Around line 195-198: The function whose signature includes the parameters
packed_seq_params: Optional["PackedSeqParams"] = None and *, loss_mask:
Optional[Tensor] = None is annotated as returning -> Tensor but actually returns
a tuple (outputs, loss_mask); update the return type annotation to something
like -> Tuple[Tensor, Optional[Tensor]] (import Tuple if needed) and update the
docstring (the lines describing the return value) to reflect the tuple form;
apply the same fix to the other occurrence that mirrors lines 271-272 so both
methods consistently annotate and document returning (outputs, loss_mask).

In `@src/megatron/bridge/training/vlm_step.py`:
- Around line 425-430: The inline comment above the model output handling in
vlm_step.py has a typo: change "CPI'm" to "CP" in the comment that describes
tuple returns from VLM models with CP; update the comment near the model_output
handling (variables/functions: model_output, output_tensor, loss_mask,
model(**forward_args)) so it reads "from VLM models with CP" instead of "CPI'm".
- Around line 299-303: The attn variable is defined only when tokens_or_input is
not None but the subsequent if attn is not None: check is dedented causing a
potential UnboundLocalError; fix by ensuring attn is always defined before the
check (e.g., initialize attn = batch.get("attention_mask") or move the if attn
is not None: block inside the tokens_or_input branch) and then call
pad_or_truncate_attn_to_len(attn, target_len, seq_cap) and assign back to
batch["attention_mask"]; locate uses of attn, tokens_or_input, batch, and
pad_or_truncate_attn_to_len in vlm_step.py to make the adjustment.
- Around line 151-162: The current fallback for all-padding batches returns
tokens[:1] etc. and cu_seqlens as torch.tensor([0, seq_len]) which falsely
signals a non-empty sequence; update the all-padding branch in vlm_step.py (the
block using valid_sequences, tokens, labels, loss_mask, attention_mask,
position_ids and the torch.tensor(...) cu_seqlens/total_len) to return a
consistent zero-length packed batch: produce empty tensors for
tokens/labels/loss_mask/position_ids with correct dtypes/devices, keep
attention_mask shaped for the original batch if needed, and set cu_seqlens to
[0,0] and total_len to 0 (or otherwise signal length 0) so downstream loss code
sees truly empty content and does not assume seq_len.

In `@src/megatron/bridge/utils/common_utils.py`:
- Around line 319-333: The THD branch assumes inputs_embeds is present by
calling inputs_embeds.size(1) and indexing it; guard against inputs_embeds being
None before computing seq_len and calling tex.thd_get_partitioned_indices.
Update the block around packed_seq_params.qkv_format == "thd" (references:
packed_seq_params, tex.thd_get_partitioned_indices, inputs_embeds) to either
assert inputs_embeds is not None or compute seq_len from an alternative tensor
(e.g., attention_mask or labels) and only call thd_get_partitioned_indices and
index_select when inputs_embeds is non-None.

In `@tests/unit_tests/models/ministral3/test_ministral3_provider.py`:
- Around line 186-194: Test currently duplicates production logic by
reimplementing _get_llama_4_attn_scale; replace this helper with a direct call
to the production function (e.g., import and call the real
_get_llama_4_attn_scale used by the model) and keep the test only asserting
expected output values. Locate the helper method in
tests/unit_tests/models/ministral3/test_ministral3_provider.py, remove the
reimplementation, call the production _get_llama_4_attn_scale with the same
parameters (positions_ids, beta, max_position_embeddings, query_shape), and
assert the returned scaling tensor matches expected values.
- Around line 318-331: Add pytest markers and an explicit CUDA skip to the GPU
test: annotate the test_gpu_tensor_support function with `@pytest.mark.gpu` and
`@pytest.mark.skipif`(not torch.cuda.is_available(), reason="requires CUDA")
instead of returning early, and add a module-level pytestmark =
[pytest.mark.unit] to categorize the file; update imports if necessary to
include pytest at top of the test file so decorators resolve.

In `@tests/unit_tests/training/test_gpt_step.py`:
- Around line 229-246: The test assumes get_packed_seq_params populates
cu_seqlens_q_padded and cu_seqlens_kv_padded when cu_seqlens_unpadded is
missing, but the implementation leaves those padded fields as None; update the
failing tests (e.g., test_packed_seq_params_without_unpadded_fallback and the
subsequent similar test in the same file) to assert that
result.cu_seqlens_q_padded and result.cu_seqlens_kv_padded are None (instead of
comparing to expected_cu_seqlens), or alternatively change get_packed_seq_params
to always populate the *_padded fields—choose the test update approach to align
tests with current behavior.
🧹 Nitpick comments (7)
src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py (1)

22-22: Use Python 3.10+ type hint syntax.

Per coding guidelines, prefer built-in generics and union syntax over typing equivalents:

  • Use tuple instead of Tuple
  • Use X | Y instead of Union[X, Y]
  • Use dict instead of Dict

Also, __subflavor__ at line 192 lacks type parameters.

♻️ Proposed refactor for modern type hints

Line 22:

-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List

Line 172:

-    __restore_key__: Tuple[Union[str, int, tuple], ...]
+    __restore_key__: tuple[str | int | tuple, ...]

Lines 190-192:

     __key__: List[str]
-    __restore_key__: List[Tuple[Union[str, int, tuple], ...]]
-    __subflavor__: Dict
+    __restore_key__: list[tuple[str | int | tuple, ...]]
+    __subflavor__: dict[str, Any] | None

As per coding guidelines: "Use 'X | Y' for union types instead of 'Union[X, Y]'" and "Use built-in generics (list, dict, tuple) instead of typing equivalents".

Also applies to: 172-172, 190-192

src/megatron/bridge/utils/common_utils.py (1)

276-284: Add explicit typing (and PEP 604 unions) for the new utility.

packed_seq_params, pg_collection, and the return type are unannotated, and several args are nullable. Adding explicit annotations improves tooling and aligns with repo typing rules.

As per coding guidelines, use type hints for function arguments and return types, and use 'T | None' for nullable types instead of 'Optional[T]'.
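
For illustration only, a signature in the requested style might look like the sketch below; the parameter list is hypothetical and not the actual `slice_batch_for_context_parallel` API.

```python
# Hypothetical annotated signature showing built-in generics and PEP 604 unions;
# the real slice_batch_for_context_parallel arguments may differ.
import torch
from megatron.core.packed_seq_params import PackedSeqParams


def slice_batch_for_context_parallel(
    batch: dict[str, torch.Tensor],
    packed_seq_params: PackedSeqParams | None = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> dict[str, torch.Tensor]:
    ...
```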

src/megatron/bridge/models/glm_vl/modeling_glm_45v.py (1)

153-166: Use a PEP 604 union for the new nullable parameter.

To align with repo typing rules, annotate this as PackedSeqParams | None (string‑quoted is fine if you keep TYPE_CHECKING).

♻️ Suggested change
-        packed_seq_params: Optional["PackedSeqParams"] = None,
+        packed_seq_params: "PackedSeqParams | None" = None,
As per coding guidelines, use 'T | None' for nullable types instead of 'Optional[T]'.
src/megatron/bridge/models/gemma/gemma3_provider.py (1)

371-377: Prefer PEP 604 unions in the new annotation.

Use torch.distributed.ProcessGroup | None to match repo typing rules.

♻️ Suggested change
-        cp_group: Optional[torch.distributed.ProcessGroup] = None,
+        cp_group: torch.distributed.ProcessGroup | None = None,
As per coding guidelines, use 'T | None' for nullable types instead of 'Optional[T]'.
src/megatron/bridge/training/vlm_step.py (1)

138-149: Performance: Device-host synchronization in loop.

The .item() call at line 143 triggers a CUDA synchronization for each sequence in the batch. For larger batch sizes, this can significantly impact performance.

Consider vectorized approach
-    for i in range(batch_size):
-        # Find first padding token or use full length
-        non_pad_mask = tokens[i] != pad_token_id
-        if non_pad_mask.any():
-            # Find the last non-padding token
-            last_valid_idx = non_pad_mask.nonzero(as_tuple=True)[0][-1].item() + 1
-        else:
-            # Empty sequence, skip
-            continue
-
-        seq_lengths.append(last_valid_idx)
-        valid_sequences.append(i)
+    # Vectorized: compute all sequence lengths at once
+    non_pad_mask = tokens != pad_token_id  # [batch_size, seq_len]
+    # Find last non-padding position per sequence
+    reversed_mask = non_pad_mask.flip(dims=[1])
+    first_nonpad_from_end = reversed_mask.int().argmax(dim=1)
+    has_valid = non_pad_mask.any(dim=1)
+    seq_lengths_tensor = seq_len - first_nonpad_from_end
+    seq_lengths_tensor = torch.where(has_valid, seq_lengths_tensor, torch.zeros_like(seq_lengths_tensor))
+    
+    # Single sync to get all lengths
+    seq_lengths_list = seq_lengths_tensor.tolist()
+    valid_sequences = [i for i, length in enumerate(seq_lengths_list) if length > 0]
+    seq_lengths = [seq_lengths_list[i] for i in valid_sequences]
tests/unit_tests/training/test_vlm_step.py (1)

313-313: Address static analysis warnings for unused variables.

Static analysis flagged multiple unused variables. While tests often unpack full tuples for clarity, consider:

  1. Prefix unused variables with underscore (e.g., _packed_labels)
  2. Use the caplog fixture at line 525 to verify the warning was logged
Example fixes
# Line 313
-        packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result
+        packed_tokens, _packed_labels, _packed_loss_mask, packed_attn, _packed_pos, cu_seqlens, max_seqlen = result

# Line 525 - Use caplog to verify warning
     def test_packing_empty_batch_warning(self, caplog):
         """Test that all-padding batch returns dummy values with warning."""
+        import logging
+        caplog.set_level(logging.WARNING)
         tokens = torch.tensor([[0, 0, 0, 0]])  # All padding
         ...
         result = pack_batch_sequences(...)
+        assert "No valid sequences found" in caplog.text

Also applies to: 360-360, 457-457, 480-480, 525-525, 542-542, 592-592

tests/unit_tests/training/test_gpt_step.py (1)

15-25: Add pytest unit markers at module level.

These are unit tests but lack pytest categorization. A module-level marker keeps it consistent with the test guidelines.

✅ Suggested change
 from functools import partial
 from unittest.mock import Mock, patch
 
 import modelopt.torch.distill as mtd
+import pytest
 import torch
 from megatron.core.packed_seq_params import PackedSeqParams
 
@@
 from megatron.bridge.training.losses import (
     create_masked_next_token_loss_function as _create_loss_function,
 )
 
+pytestmark = pytest.mark.unit

As per coding guidelines: Use "pytest.mark" to categorize tests (unit, integration, system).

Comment on lines 121 to 124
        packed_seq_params: Optional["PackedSeqParams"] = None,
        *,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:

⚠️ Potential issue | 🟡 Minor

Return type annotation is inconsistent with actual return value.

The method now returns a tuple (outputs, loss_mask) but the return type annotation is still -> Tensor. This inconsistency can cause type checking issues and confuse callers.

Proposed fix
-    ) -> Tensor:
+    ) -> tuple[Tensor, Tensor | None]:

Also update the docstring at lines 72-73:

-    Returns:
-        Tensor: Model output (e.g., logits or loss, depending on mode).
+    Returns:
+        tuple: (output_tensor, loss_mask) where output_tensor contains model output
+               and loss_mask is the CP-sliced mask for consistent loss computation.

Also applies to: 186-187

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py` around lines 121 -
124, The method whose signature includes "packed_seq_params:
Optional['PackedSeqParams'] = None, *, loss_mask: Optional[Tensor] = None"
returns a tuple (outputs, loss_mask) but is annotated "-> Tensor"; update the
return type to reflect the actual return (for example "-> Tuple[Tensor,
Optional[Tensor]]"), import Tuple from typing as needed, and update the method
docstring to describe the tuple return (outputs, loss_mask); apply the same
change to the other overloaded/duplicate signature in the file that has the same
parameters.

Comment on lines 370 to 381
     @lru_cache(maxsize=32)
-    def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
+    def forward(
+        self,
+        max_seq_len: int,
+        offset: int = 0,
+        packed_seq: bool = False,
+        cp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ) -> Tensor:
         """Get global and local rope embedding"""
-        rope_global = super().forward(max_seq_len, offset, packed_seq)
-        rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
+        rope_global = super().forward(max_seq_len, offset, packed_seq, cp_group)
+        rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, cp_group)
         return rope_local, rope_global

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check the actual file to understand context
head -n 400 src/megatron/bridge/models/gemma/gemma3_provider.py | tail -n 100

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 3234


🌐 Web query:

torch.distributed.ProcessGroup hashable lru_cache Python

💡 Result:

functools.lru_cache can only cache calls whose positional/keyword arguments are hashable, because it uses the arguments as keys in an internal dict. If you pass a torch.distributed.distributed_c10d.ProcessGroup and it raises TypeError: unhashable type: 'ProcessGroup', that’s why. [1]

Practical patterns

1) Cache by a hashable key derived from the group

Use a wrapper that turns the group into a stable, hashable identifier (often process-local):

import torch.distributed as dist
from functools import lru_cache

def _pg_key(pg):
    # process-local identity + backend name (both hashable)
    return (id(pg), dist.get_backend(pg))

`@lru_cache`(maxsize=None)
def expensive_for_group(pg_id_backend):
    pg_id, backend = pg_id_backend
    ...  # compute

def expensive(pg):
    return expensive_for_group(_pg_key(pg))

dist.get_backend(group) is a supported way to query a process group’s backend. [2]

2) Don’t use lru_cache: use an explicit dict keyed by id(pg)

If you need more control (cleanup, limits, etc.), do:

_cache = {}

def expensive(pg):
    k = id(pg)
    if k not in _cache:
        _cache[k] = ...
    return _cache[k]

3) Cache by group-defining parameters instead of the object

If your “expensive” result is really determined by (say) (backend, world_size, ranks_tuple, options...), build a tuple key from those parameters and cache on that (tuples are hashable).


Sources

  • Python functools.lru_cache requires args to be hashable. [1]
  • PyTorch distributed APIs treat ProcessGroup as an opaque handle and provide get_backend(group) for it. [2]

[1] https://docs.python.org/3/library/functools.html
[2] https://docs.pytorch.org/docs/stable/distributed


🏁 Script executed:

# Search for other uses of ProcessGroup in the codebase
rg "ProcessGroup" -A 3 -B 3 src/ | head -50

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 4197


🏁 Script executed:

# Check if there are other uses of lru_cache with ProcessGroup or similar parameters in the codebase
rg "@lru_cache" -A 10 src/megatron/bridge/ | grep -A 10 "ProcessGroup"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1201


🏁 Script executed:

# Check if cp_group is ever passed as non-None in the codebase
rg "\.forward\(" src/megatron/bridge/models/gemma/ -A 2 -B 2 | grep -i "cp_group"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 343


🏁 Script executed:

# Check if this method is called anywhere and how
rg "Gemma3RotaryEmbedding" -A 5 -B 5 src/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2129


🏁 Script executed:

# Check the parent class RotaryEmbedding to understand the signature
rg "class RotaryEmbedding" -A 20 src/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Find where RotaryEmbedding is imported from
rg "from.*RotaryEmbedding|import.*RotaryEmbedding" src/megatron/bridge/models/gemma/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 210


🏁 Script executed:

# Check the imports at the top of gemma3_provider.py
head -50 src/megatron/bridge/models/gemma/gemma3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2201


🏁 Script executed:

# Search for RotaryEmbedding definition across the repository
fd "\.py$" src/ | xargs grep -l "class RotaryEmbedding" 2>/dev/null

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 55


🏁 Script executed:

# Search within megatron.core for RotaryEmbedding definition
find . -name "*.py" -type f -exec grep -l "class RotaryEmbedding" {} \; 2>/dev/null

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check if the parent RotaryEmbedding is in external megatron.core (not in this repo)
# Let's check what calls Gemma3RotaryEmbedding.forward to see if cp_group is ever actually used
rg "rotary_pos_emb.*forward\(" src/megatron/bridge/ -A 3 -B 3

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check if there are any actual calls to Gemma3RotaryEmbedding.forward with cp_group != None
rg "cp_group\s*=" src/megatron/bridge/models/gemma/ -B 5 -A 5

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Let's check the actual signature of the parent RotaryEmbedding.forward to see if it also accepts cp_group
# Since it's from megatron.core, let's see if there are any hints in usage
rg "super\(\)\.forward\(" src/megatron/bridge/models/gemma/ -B 3 -A 3

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 4267


🏁 Script executed:

# Check the actual parent class to understand the full signature
# Since RotaryEmbedding is from megatron.core, let's see if there are tests or examples
rg "RotaryEmbedding" src/megatron/bridge/ -A 5 | head -60

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 5616


🏁 Script executed:

# Check if the typing convention is violated - Optional vs | None
rg "Optional\[torch.distributed.ProcessGroup\]" src/megatron/bridge/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 664


🏁 Script executed:

# Verify the exact line numbers and full context of the method
sed -n '365,385p' src/megatron/bridge/models/gemma/gemma3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 721


Fix @lru_cache incompatibility with ProcessGroup parameter and update type hint.

lru_cache requires hashable arguments, but torch.distributed.ProcessGroup is unhashable. Passing a non-None cp_group will raise TypeError at runtime. Additionally, the type hint uses Optional[...] instead of the PEP 604 union syntax required by coding guidelines.

Either: (1) remove the cp_group parameter if it's not needed, (2) bypass caching when cp_group is provided, or (3) derive a hashable key from the group before caching. Also change Optional[torch.distributed.ProcessGroup] to torch.distributed.ProcessGroup | None.
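
For illustration, a minimal sketch of option (2) follows: keep `@lru_cache` on a helper that only takes hashable arguments and compute directly when a `ProcessGroup` is passed. The class and method names are stand-ins, not the final fix.

```python
# Sketch of option (2): cache only hashable arguments; bypass the cache when an
# unhashable ProcessGroup is supplied. Names are illustrative stand-ins.
from functools import lru_cache

import torch


class RotaryEmbeddingSketch:
    def forward(
        self,
        max_seq_len: int,
        offset: int = 0,
        packed_seq: bool = False,
        cp_group: torch.distributed.ProcessGroup | None = None,
    ):
        if cp_group is not None:
            # ProcessGroup is unhashable, so skip caching for this call.
            return self._compute(max_seq_len, offset, packed_seq, cp_group)
        return self._forward_cached(max_seq_len, offset, packed_seq)

    @lru_cache(maxsize=32)
    def _forward_cached(self, max_seq_len: int, offset: int, packed_seq: bool):
        return self._compute(max_seq_len, offset, packed_seq, cp_group=None)

    def _compute(self, max_seq_len, offset, packed_seq, cp_group):
        # Placeholder for the real rope_global / rope_local computation.
        raise NotImplementedError
```

Note that `@lru_cache` on an instance method still carries the B019 memory-retention caveat flagged by Ruff below; an explicit dict keyed by the hashable arguments avoids that as well.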

🧰 Tools
🪛 Ruff (0.14.14)

370-370: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks

(B019)

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/gemma/gemma3_provider.py` around lines 370 - 381,
The forward method is using `@lru_cache` but accepts an unhashable
torch.distributed.ProcessGroup and uses Optional[...] typing; change the
signature to use PEP 604 union (torch.distributed.ProcessGroup | None) and avoid
passing the ProcessGroup into the cached path: implement a small uncached branch
in forward that, when cp_group is not None, directly calls super().forward(...)
and self.rope_local.forward(...) and returns rope_local, rope_global; otherwise
use the cached logic (keep `@lru_cache` on a helper that only accepts hashable
args like max_seq_len, offset, packed_seq) so caching is only used for hashable
parameters. Ensure you reference the existing forward method and
rope_local/rope_global calls when making the change.

Comment on lines 195 to 198
        packed_seq_params: Optional["PackedSeqParams"] = None,
        *,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:

⚠️ Potential issue | 🟡 Minor

Return type annotation inconsistent with actual return value.

Same issue as in modeling_gemma3_vl.py - the method returns a tuple (outputs, loss_mask) but the annotation is -> Tensor.

Proposed fix
-    ) -> Tensor:
+    ) -> tuple[Tensor, Tensor | None]:

Also update the docstring at lines 212-213 to reflect the tuple return.

Also applies to: 271-272

🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/ministral3/modeling_ministral3.py` around lines
195 - 198, The function whose signature includes the parameters
packed_seq_params: Optional["PackedSeqParams"] = None and *, loss_mask:
Optional[Tensor] = None is annotated as returning -> Tensor but actually returns
a tuple (outputs, loss_mask); update the return type annotation to something
like -> Tuple[Tensor, Optional[Tensor]] (import Tuple if needed) and update the
docstring (the lines describing the return value) to reflect the tuple form;
apply the same fix to the other occurrence that mirrors lines 271-272 so both
methods consistently annotate and document returning (outputs, loss_mask).

Comment on lines 151 to 162
    if len(valid_sequences) == 0:
        # No valid sequences, return dummy packed batch
        logger.warning("No valid sequences found in batch, skipping packing")
        return (
            tokens[:1],  # Return first sequence as-is
            labels[:1],
            loss_mask[:1],
            attention_mask,
            position_ids[:1],
            torch.tensor([0, seq_len], dtype=torch.int32, device=device),
            torch.tensor(seq_len, dtype=torch.int32, device=device),
        )

⚠️ Potential issue | 🟡 Minor

Edge case: All-padding batch returns potentially invalid data.

When all sequences are padding, the fallback returns tokens[:1] (which is all padding) with cu_seqlens=[0, seq_len]. This indicates a valid sequence of length seq_len, which is inconsistent with the actual empty content. Downstream loss computation might not handle this correctly.

Consider returning a zero-length packed batch or explicitly handling this case in loss computation:

     if len(valid_sequences) == 0:
         # No valid sequences, return dummy packed batch
         logger.warning("No valid sequences found in batch, skipping packing")
+        # Return empty packed tensors with cu_seqlens indicating no content
         return (
-            tokens[:1],  # Return first sequence as-is
-            labels[:1],
-            loss_mask[:1],
+            tokens[:, :0],  # Empty tensor
+            labels[:, :0],
+            loss_mask[:, :0],
             attention_mask,
-            position_ids[:1],
-            torch.tensor([0, seq_len], dtype=torch.int32, device=device),
-            torch.tensor(seq_len, dtype=torch.int32, device=device),
+            position_ids[:, :0],
+            torch.tensor([0], dtype=torch.int32, device=device),  # Empty cu_seqlens
+            torch.tensor(0, dtype=torch.int32, device=device),
         )
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/vlm_step.py` around lines 151 - 162, The current
fallback for all-padding batches returns tokens[:1] etc. and cu_seqlens as
torch.tensor([0, seq_len]) which falsely signals a non-empty sequence; update
the all-padding branch in vlm_step.py (the block using valid_sequences, tokens,
labels, loss_mask, attention_mask, position_ids and the torch.tensor(...)
cu_seqlens/total_len) to return a consistent zero-length packed batch: produce
empty tensors for tokens/labels/loss_mask/position_ids with correct
dtypes/devices, keep attention_mask shaped for the original batch if needed, and
set cu_seqlens to [0,0] and total_len to 0 (or otherwise signal length 0) so
downstream loss code sees truly empty content and does not assume seq_len.

Comment on lines 299 to 303
                # attention_mask if present
                attn = batch.get("attention_mask")
            if attn is not None:
                attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
                batch["attention_mask"] = attn  # type: ignore[assignment]

⚠️ Potential issue | 🟠 Major

Bug: Variable scope issue with attn handling.

The attn variable is defined inside the if tokens_or_input is not None: block (line 300), but the if attn is not None: check at line 301 appears to be at the wrong indentation level (outside the block). This will cause an UnboundLocalError when tokens_or_input is None.

Proposed fix
                 # attention_mask if present
                 attn = batch.get("attention_mask")
-            if attn is not None:
-                attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
-                batch["attention_mask"] = attn  # type: ignore[assignment]
+                if attn is not None:
+                    attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
+                    batch["attention_mask"] = attn  # type: ignore[assignment]
📝 Committable suggestion


Suggested change
-                # attention_mask if present
-                attn = batch.get("attention_mask")
-            if attn is not None:
-                attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
-                batch["attention_mask"] = attn  # type: ignore[assignment]
+                # attention_mask if present
+                attn = batch.get("attention_mask")
+                if attn is not None:
+                    attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
+                    batch["attention_mask"] = attn  # type: ignore[assignment]
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/vlm_step.py` around lines 299 - 303, The attn
variable is defined only when tokens_or_input is not None but the subsequent if
attn is not None: check is dedented causing a potential UnboundLocalError; fix
by ensuring attn is always defined before the check (e.g., initialize attn =
batch.get("attention_mask") or move the if attn is not None: block inside the
tokens_or_input branch) and then call pad_or_truncate_attn_to_len(attn,
target_len, seq_cap) and assign back to batch["attention_mask"]; locate uses of
attn, tokens_or_input, batch, and pad_or_truncate_attn_to_len in vlm_step.py to
make the adjustment.

Comment on lines 425 to 430
            model_output = model(**forward_args)
            # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CPI'm
            if isinstance(model_output, tuple):
                output_tensor, loss_mask = model_output
            else:
                output_tensor = model_output

⚠️ Potential issue | 🟡 Minor

Minor: Typo in comment.

Line 426 has a typo: "CPI'm" should be "CP".

Fix typo
-            # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CPI'm
+            # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CP
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/vlm_step.py` around lines 425 - 430, The inline
comment above the model output handling in vlm_step.py has a typo: change
"CPI'm" to "CP" in the comment that describes tuple returns from VLM models with
CP; update the comment near the model_output handling (variables/functions:
model_output, output_tensor, loss_mask, model(**forward_args)) so it reads "from
VLM models with CP" instead of "CPI'm".

Comment on lines 319 to 333
    if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
        import transformer_engine_torch as tex

        cu_seqlens = packed_seq_params.cu_seqlens_q
        cu_seqlens_padded = (
            packed_seq_params.cu_seqlens_q_padded if packed_seq_params.cu_seqlens_q_padded is not None else cu_seqlens
        )
        seq_len = inputs_embeds.size(1)

        index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank)

        # Slice all tensors using THD indices
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.index_select(1, index)
        if labels is not None:

⚠️ Potential issue | 🟠 Major

Guard THD slicing when inputs_embeds can be absent.

The THD path unconditionally reads inputs_embeds.size(1) and indexes it. If a caller supplies inputs_embeds=None, this will raise at runtime. Consider asserting non‑None (or deriving seq_len from another tensor) before computing THD indices.

🛠️ Possible guard
-        seq_len = inputs_embeds.size(1)
+        if inputs_embeds is None:
+            raise ValueError("inputs_embeds is required for THD CP slicing")
+        seq_len = inputs_embeds.size(1)
📝 Committable suggestion


Suggested change
     if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
         import transformer_engine_torch as tex

         cu_seqlens = packed_seq_params.cu_seqlens_q
         cu_seqlens_padded = (
             packed_seq_params.cu_seqlens_q_padded if packed_seq_params.cu_seqlens_q_padded is not None else cu_seqlens
         )
+        if inputs_embeds is None:
+            raise ValueError("inputs_embeds is required for THD CP slicing")
         seq_len = inputs_embeds.size(1)

         index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank)

         # Slice all tensors using THD indices
         if inputs_embeds is not None:
             inputs_embeds = inputs_embeds.index_select(1, index)
         if labels is not None:
🤖 Prompt for AI Agents
In `@src/megatron/bridge/utils/common_utils.py` around lines 319 - 333, The THD
branch assumes inputs_embeds is present by calling inputs_embeds.size(1) and
indexing it; guard against inputs_embeds being None before computing seq_len and
calling tex.thd_get_partitioned_indices. Update the block around
packed_seq_params.qkv_format == "thd" (references: packed_seq_params,
tex.thd_get_partitioned_indices, inputs_embeds) to either assert inputs_embeds
is not None or compute seq_len from an alternative tensor (e.g., attention_mask
or labels) and only call thd_get_partitioned_indices and index_select when
inputs_embeds is non-None.

Comment on lines 186 to 194
    def _get_llama_4_attn_scale(
        self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
    ) -> torch.Tensor:
        """Reimplementation of the function for testing."""
        scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
        num_dims_to_add = len(query_shape) - 1
        for _ in range(num_dims_to_add):
            scaling = scaling.unsqueeze(-1)
        return scaling

⚠️ Potential issue | 🟠 Major

Test duplicates production logic instead of calling it.

This helper re-implements _get_llama_4_attn_scale, so regressions in the real method could still pass. Prefer invoking the production implementation and keep only the expected-value checks in tests.

🔧 Suggested change
-    def _get_llama_4_attn_scale(
-        self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
-    ) -> torch.Tensor:
-        """Reimplementation of the function for testing."""
-        scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
-        num_dims_to_add = len(query_shape) - 1
-        for _ in range(num_dims_to_add):
-            scaling = scaling.unsqueeze(-1)
-        return scaling
+    def _get_llama_4_attn_scale(
+        self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
+    ) -> torch.Tensor:
+        """Call the production implementation for testing."""
+        provider = Ministral3ModelProvider(
+            num_layers=26,
+            hidden_size=3072,
+            num_attention_heads=32,
+        )
+        return provider._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
🤖 Prompt for AI Agents
In `@tests/unit_tests/models/ministral3/test_ministral3_provider.py` around lines
186 - 194, Test currently duplicates production logic by reimplementing
_get_llama_4_attn_scale; replace this helper with a direct call to the
production function (e.g., import and call the real _get_llama_4_attn_scale used
by the model) and keep the test only asserting expected output values. Locate
the helper method in
tests/unit_tests/models/ministral3/test_ministral3_provider.py, remove the
reimplementation, call the production _get_llama_4_attn_scale with the same
parameters (positions_ids, beta, max_position_embeddings, query_shape), and
assert the returned scaling tensor matches expected values.

Comment on lines 318 to 331
    def test_gpu_tensor_support(self):
        """Test that the function works with GPU tensors if available."""
        if not torch.cuda.is_available():
            return  # Skip test if no GPU

        positions_ids = torch.arange(8, device="cuda")
        beta = 0.1
        max_position_embeddings = 1024
        query_shape = (8, 4, 32)

        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)

        assert scaling.device.type == "cuda"
        assert scaling.shape == (8, 1, 1)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# Check if the file exists and read the test method
find . -path "*tests/unit_tests/models/ministral3/test_ministral3_provider.py" -type f

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 135


🏁 Script executed:

# List files in the ministral3 test directory if it exists
fd "test_ministral3" tests/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 384


🏁 Script executed:

# Search more broadly for the test file
fd "test_ministral3_provider" .

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 135


🏁 Script executed:

# Read the test file to verify the current state and imports
wc -l ./tests/unit_tests/models/ministral3/test_ministral3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 139


🏁 Script executed:

# Check the imports at the top of the file
head -30 ./tests/unit_tests/models/ministral3/test_ministral3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1175


🏁 Script executed:

# Read around the test method (lines 310-340)
sed -n '310,340p' ./tests/unit_tests/models/ministral3/test_ministral3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1294


🏁 Script executed:

# Check if pytest is imported anywhere in the file
grep -n "import pytest" ./tests/unit_tests/models/ministral3/test_ministral3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check for pytestmark at module level
grep -n "pytestmark" ./tests/unit_tests/models/ministral3/test_ministral3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check for any existing pytest.mark usage in the file
grep -n "@pytest.mark" ./tests/unit_tests/models/ministral3/test_ministral3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Read the full imports section to understand what's imported
sed -n '1,25p' ./tests/unit_tests/models/ministral3/test_ministral3_provider.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 925


Add pytest markers and explicit GPU requirement for the CUDA test.

The file lacks unit test categorization and the GPU test uses an early return instead of a pytest skip marker, which doesn't document the hardware requirement. Mark the test with @pytest.mark.gpu and use @pytest.mark.skipif() to explicitly document the CUDA dependency, and add module-level pytestmark for unit test categorization.

✅ Suggested change
+import pytest
 import torch
 
 from megatron.bridge.models.ministral3.ministral3_provider import (
@@ -17,6 +18,8 @@
     Ministral3ModelProvider14B,
 )
 
+pytestmark = pytest.mark.unit
+
 
 class TestMinistral3ModelProvider:
     """Test cases for Ministral3ModelProvider base class."""
@@ -315,8 +318,9 @@ class TestMinistral3ModelProvider:
         result_4d = query_4d * scaling_4d.to(query_4d.dtype)
         assert result_4d.shape == query_4d.shape
 
+    `@pytest.mark.gpu`
+    `@pytest.mark.skipif`(not torch.cuda.is_available(), reason="Requires CUDA")
     def test_gpu_tensor_support(self):
         """Test that the function works with GPU tensors if available."""
-        if not torch.cuda.is_available():
-            return  # Skip test if no GPU
🤖 Prompt for AI Agents
In `@tests/unit_tests/models/ministral3/test_ministral3_provider.py` around lines
318 - 331, Add pytest markers and an explicit CUDA skip to the GPU test:
annotate the test_gpu_tensor_support function with `@pytest.mark.gpu` and
`@pytest.mark.skipif`(not torch.cuda.is_available(), reason="requires CUDA")
instead of returning early, and add a module-level pytestmark =
[pytest.mark.unit] to categorize the file; update imports if necessary to
include pytest at top of the test file so decorators resolve.

Comment on lines 229 to 246
    def test_packed_seq_params_without_unpadded_fallback(self):
        """Test fallback to cu_seqlens when cu_seqlens_unpadded is not provided."""
        batch = {
            "cu_seqlens": torch.tensor([[0, 5, 10, 15, -1]], dtype=torch.int32),
            "max_seqlen": torch.tensor([[8]], dtype=torch.int32),
        }

        result = get_packed_seq_params(batch)

        expected_cu_seqlens = torch.tensor([0, 5, 10, 15], dtype=torch.int32)

        # Without unpadded, q/kv should use padded values
        assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens)
        assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens)

        # Padded fields should match q/kv
        assert torch.equal(result.cu_seqlens_q_padded, expected_cu_seqlens)
        assert torch.equal(result.cu_seqlens_kv_padded, expected_cu_seqlens)

⚠️ Potential issue | 🟠 Major

Assertions for the *_padded fields in the no-unpadded cases don't match current behavior.

get_packed_seq_params only sets cu_seqlens_*_padded when cu_seqlens_unpadded is present; otherwise these fields are typically None. These tests will fail (or codify a new requirement) unless you update the implementation to always populate the padded fields.

🔧 Suggested change (align tests with current behavior)
-        # Padded fields should match q/kv
-        assert torch.equal(result.cu_seqlens_q_padded, expected_cu_seqlens)
-        assert torch.equal(result.cu_seqlens_kv_padded, expected_cu_seqlens)
+        # No unpadded input → padded fields are unset
+        assert result.cu_seqlens_q_padded is None
+        assert result.cu_seqlens_kv_padded is None
@@
-        assert torch.equal(result.cu_seqlens_q_padded, expected)
+        assert result.cu_seqlens_q_padded is None

Also applies to: 248-260

🤖 Prompt for AI Agents
In `@tests/unit_tests/training/test_gpt_step.py` around lines 229 - 246, The test
assumes get_packed_seq_params populates cu_seqlens_q_padded and
cu_seqlens_kv_padded when cu_seqlens_unpadded is missing, but the implementation
leaves those padded fields as None; update the failing tests (e.g.,
test_packed_seq_params_without_unpadded_fallback and the subsequent similar test
in the same file) to assert that result.cu_seqlens_q_padded and
result.cu_seqlens_kv_padded are None (instead of comparing to
expected_cu_seqlens), or alternatively change get_packed_seq_params to always
populate the *_padded fields—choose the test update approach to align tests with
current behavior.

Signed-off-by: Chen Cui <[email protected]>