THD training in VLMs #1997
base: main
Conversation
Signed-off-by: Chen Cui <[email protected]>
📝 Walkthrough
This PR introduces batch-level sequence packing and context-parallel slicing support for vision-language models. Changes include a new …
Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ Passed checks (4 passed)
Actionable comments posted: 10
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py`:
- Around line 121-124: The method whose signature includes "packed_seq_params:
Optional['PackedSeqParams'] = None, *, loss_mask: Optional[Tensor] = None"
returns a tuple (outputs, loss_mask) but is annotated "-> Tensor"; update the
return type to reflect the actual return (for example "-> Tuple[Tensor,
Optional[Tensor]]"), import Tuple from typing as needed, and update the method
docstring to describe the tuple return (outputs, loss_mask); apply the same
change to the other overloaded/duplicate signature in the file that has the same
parameters.
In `@src/megatron/bridge/models/gemma/gemma3_provider.py`:
- Around line 370-381: The forward method is using `@lru_cache` but accepts an
unhashable torch.distributed.ProcessGroup and uses Optional[...] typing; change
the signature to use PEP 604 union (torch.distributed.ProcessGroup | None) and
avoid passing the ProcessGroup into the cached path: implement a small uncached
branch in forward that, when cp_group is not None, directly calls
super().forward(...) and self.rope_local.forward(...) and returns rope_local,
rope_global; otherwise use the cached logic (keep `@lru_cache` on a helper that
only accepts hashable args like max_seq_len, offset, packed_seq) so caching is
only used for hashable parameters. Ensure you reference the existing forward
method and rope_local/rope_global calls when making the change.
In `@src/megatron/bridge/models/ministral3/modeling_ministral3.py`:
- Around line 195-198: The function whose signature includes the parameters
packed_seq_params: Optional["PackedSeqParams"] = None and *, loss_mask:
Optional[Tensor] = None is annotated as returning -> Tensor but actually returns
a tuple (outputs, loss_mask); update the return type annotation to something
like -> Tuple[Tensor, Optional[Tensor]] (import Tuple if needed) and update the
docstring (the lines describing the return value) to reflect the tuple form;
apply the same fix to the other occurrence that mirrors lines 271-272 so both
methods consistently annotate and document returning (outputs, loss_mask).
In `@src/megatron/bridge/training/vlm_step.py`:
- Around line 425-430: The inline comment above the model output handling in
vlm_step.py has a typo: change "CPI'm" to "CP" in the comment that describes
tuple returns from VLM models with CP; update the comment near the model_output
handling (variables/functions: model_output, output_tensor, loss_mask,
model(**forward_args)) so it reads "from VLM models with CP" instead of "CPI'm".
- Around line 299-303: The attn variable is defined only when tokens_or_input is
not None but the subsequent if attn is not None: check is dedented causing a
potential UnboundLocalError; fix by ensuring attn is always defined before the
check (e.g., initialize attn = batch.get("attention_mask") or move the if attn
is not None: block inside the tokens_or_input branch) and then call
pad_or_truncate_attn_to_len(attn, target_len, seq_cap) and assign back to
batch["attention_mask"]; locate uses of attn, tokens_or_input, batch, and
pad_or_truncate_attn_to_len in vlm_step.py to make the adjustment.
- Around line 151-162: The current fallback for all-padding batches returns
tokens[:1] etc. and cu_seqlens as torch.tensor([0, seq_len]) which falsely
signals a non-empty sequence; update the all-padding branch in vlm_step.py (the
block using valid_sequences, tokens, labels, loss_mask, attention_mask,
position_ids and the torch.tensor(...) cu_seqlens/total_len) to return a
consistent zero-length packed batch: produce empty tensors for
tokens/labels/loss_mask/position_ids with correct dtypes/devices, keep
attention_mask shaped for the original batch if needed, and set cu_seqlens to
[0,0] and total_len to 0 (or otherwise signal length 0) so downstream loss code
sees truly empty content and does not assume seq_len.
In `@src/megatron/bridge/utils/common_utils.py`:
- Around line 319-333: The THD branch assumes inputs_embeds is present by
calling inputs_embeds.size(1) and indexing it; guard against inputs_embeds being
None before computing seq_len and calling tex.thd_get_partitioned_indices.
Update the block around packed_seq_params.qkv_format == "thd" (references:
packed_seq_params, tex.thd_get_partitioned_indices, inputs_embeds) to either
assert inputs_embeds is not None or compute seq_len from an alternative tensor
(e.g., attention_mask or labels) and only call thd_get_partitioned_indices and
index_select when inputs_embeds is non-None.
In `@tests/unit_tests/models/ministral3/test_ministral3_provider.py`:
- Around line 186-194: Test currently duplicates production logic by
reimplementing _get_llama_4_attn_scale; replace this helper with a direct call
to the production function (e.g., import and call the real
_get_llama_4_attn_scale used by the model) and keep the test only asserting
expected output values. Locate the helper method in
tests/unit_tests/models/ministral3/test_ministral3_provider.py, remove the
reimplementation, call the production _get_llama_4_attn_scale with the same
parameters (positions_ids, beta, max_position_embeddings, query_shape), and
assert the returned scaling tensor matches expected values.
- Around line 318-331: Add pytest markers and an explicit CUDA skip to the GPU
test: annotate the test_gpu_tensor_support function with `@pytest.mark.gpu` and
`@pytest.mark.skipif`(not torch.cuda.is_available(), reason="requires CUDA")
instead of returning early, and add a module-level pytestmark =
[pytest.mark.unit] to categorize the file; update imports if necessary to
include pytest at top of the test file so decorators resolve.
In `@tests/unit_tests/training/test_gpt_step.py`:
- Around line 229-246: The test assumes get_packed_seq_params populates
cu_seqlens_q_padded and cu_seqlens_kv_padded when cu_seqlens_unpadded is
missing, but the implementation leaves those padded fields as None; update the
failing tests (e.g., test_packed_seq_params_without_unpadded_fallback and the
subsequent similar test in the same file) to assert that
result.cu_seqlens_q_padded and result.cu_seqlens_kv_padded are None (instead of
comparing to expected_cu_seqlens), or alternatively change get_packed_seq_params
to always populate the *_padded fields—choose the test update approach to align
tests with current behavior.
🧹 Nitpick comments (7)
src/megatron/bridge/recipes/qwen_vl/data/energon/task_encoder.py (1)
22-22: Use Python 3.10+ type hint syntax. Per coding guidelines, prefer built-in generics and union syntax over typing equivalents:
- Use tuple instead of Tuple
- Use X | Y instead of Union[X, Y]
- Use dict instead of Dict
Also, __subflavor__ at line 192 lacks type parameters.
♻️ Proposed refactor for modern type hints
Line 22:
```diff
-from typing import Dict, List, Tuple, Union
+from typing import Any, Dict, List
```
Line 172:
```diff
-    __restore_key__: Tuple[Union[str, int, tuple], ...]
+    __restore_key__: tuple[str | int | tuple, ...]
```
Lines 190-192:
```diff
     __key__: List[str]
-    __restore_key__: List[Tuple[Union[str, int, tuple], ...]]
-    __subflavor__: Dict
+    __restore_key__: list[tuple[str | int | tuple, ...]]
+    __subflavor__: dict[str, Any] | None
```
As per coding guidelines: "Use 'X | Y' for union types instead of 'Union[X, Y]'" and "Use built-in generics (list, dict, tuple) instead of typing equivalents".
Also applies to: 172-172, 190-192
src/megatron/bridge/utils/common_utils.py (1)
276-284: Add explicit typing (and PEP 604 unions) for the new utility.
packed_seq_params, pg_collection, and the return type are unannotated, and several arguments are nullable. Adding explicit annotations improves tooling and aligns with repo typing rules.
As per coding guidelines, use type hints for function arguments and return types, and use 'T | None' for nullable types instead of 'Optional[T]'.
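For illustration, a minimal sketch of what the annotated signature could look like; the function name and exact parameters below are assumptions based on this review, not the actual API in common_utils.py:
```python
# Sketch only: names are hypothetical; the real utility in common_utils.py may differ.
import torch
from megatron.core.packed_seq_params import PackedSeqParams


def slice_batch_on_cp_rank(
    inputs_embeds: torch.Tensor | None,
    labels: torch.Tensor | None,
    packed_seq_params: PackedSeqParams | None,
    pg_collection: object | None = None,  # assumed: container exposing the CP process group
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Return (inputs_embeds, labels) sliced for the current context-parallel rank."""
    ...
```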
src/megatron/bridge/models/glm_vl/modeling_glm_45v.py (1)
153-166: Use a PEP 604 union for the new nullable parameter. To align with repo typing rules, annotate this as PackedSeqParams | None (string-quoted is fine if you keep TYPE_CHECKING).
As per coding guidelines, use 'T | None' for nullable types instead of 'Optional[T]'.
♻️ Suggested change
```diff
-        packed_seq_params: Optional["PackedSeqParams"] = None,
+        packed_seq_params: "PackedSeqParams | None" = None,
```
src/megatron/bridge/models/gemma/gemma3_provider.py (1)
371-377: Prefer PEP 604 unions in the new annotation. Use torch.distributed.ProcessGroup | None to match repo typing rules.
As per coding guidelines, use 'T | None' for nullable types instead of 'Optional[T]'.
♻️ Suggested change
```diff
-        cp_group: Optional[torch.distributed.ProcessGroup] = None,
+        cp_group: torch.distributed.ProcessGroup | None = None,
```
src/megatron/bridge/training/vlm_step.py (1)
138-149: Performance: Device-host synchronization in loop. The .item() call at line 143 triggers a CUDA synchronization for each sequence in the batch. For larger batch sizes, this can significantly impact performance.
Consider vectorized approach
```diff
-    for i in range(batch_size):
-        # Find first padding token or use full length
-        non_pad_mask = tokens[i] != pad_token_id
-        if non_pad_mask.any():
-            # Find the last non-padding token
-            last_valid_idx = non_pad_mask.nonzero(as_tuple=True)[0][-1].item() + 1
-        else:
-            # Empty sequence, skip
-            continue
-
-        seq_lengths.append(last_valid_idx)
-        valid_sequences.append(i)
+    # Vectorized: compute all sequence lengths at once
+    non_pad_mask = tokens != pad_token_id  # [batch_size, seq_len]
+    # Find last non-padding position per sequence
+    reversed_mask = non_pad_mask.flip(dims=[1])
+    first_nonpad_from_end = reversed_mask.int().argmax(dim=1)
+    has_valid = non_pad_mask.any(dim=1)
+    seq_lengths_tensor = seq_len - first_nonpad_from_end
+    seq_lengths_tensor = torch.where(has_valid, seq_lengths_tensor, torch.zeros_like(seq_lengths_tensor))
+
+    # Single sync to get all lengths
+    seq_lengths_list = seq_lengths_tensor.tolist()
+    valid_sequences = [i for i, length in enumerate(seq_lengths_list) if length > 0]
+    seq_lengths = [seq_lengths_list[i] for i in valid_sequences]
```
tests/unit_tests/training/test_vlm_step.py (1)
313-313: Address static analysis warnings for unused variables. Static analysis flagged multiple unused variables. While tests often unpack full tuples for clarity, consider:
- Prefix unused variables with underscore (e.g., _packed_labels)
- Use the caplog fixture at line 525 to verify the warning was logged
Example fixes
```diff
 # Line 313
-packed_tokens, packed_labels, packed_loss_mask, packed_attn, packed_pos, cu_seqlens, max_seqlen = result
+packed_tokens, _packed_labels, _packed_loss_mask, packed_attn, _packed_pos, cu_seqlens, max_seqlen = result

 # Line 525 - Use caplog to verify warning
 def test_packing_empty_batch_warning(self, caplog):
     """Test that all-padding batch returns dummy values with warning."""
+    import logging
+    caplog.set_level(logging.WARNING)
     tokens = torch.tensor([[0, 0, 0, 0]])  # All padding
     ...
     result = pack_batch_sequences(...)
+    assert "No valid sequences found" in caplog.text
```
Also applies to: 360-360, 457-457, 480-480, 525-525, 542-542, 592-592
tests/unit_tests/training/test_gpt_step.py (1)
15-25: Add pytest unit markers at module level. These are unit tests but lack pytest categorization. A module-level marker keeps it consistent with the test guidelines.
✅ Suggested change
```diff
 from functools import partial
 from unittest.mock import Mock, patch

 import modelopt.torch.distill as mtd
+import pytest
 import torch
 from megatron.core.packed_seq_params import PackedSeqParams
@@
 from megatron.bridge.training.losses import (
     create_masked_next_token_loss_function as _create_loss_function,
 )

+pytestmark = pytest.mark.unit
```
As per coding guidelines: Use "pytest.mark" to categorize tests (unit, integration, system).
```python
        packed_seq_params: Optional["PackedSeqParams"] = None,
        *,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:
```
Return type annotation is inconsistent with actual return value.
The method now returns a tuple (outputs, loss_mask) but the return type annotation is still -> Tensor. This inconsistency can cause type checking issues and confuse callers.
Proposed fix
```diff
-    ) -> Tensor:
+    ) -> tuple[Tensor, Tensor | None]:
```
Also update the docstring at lines 72-73:
```diff
-        Returns:
-            Tensor: Model output (e.g., logits or loss, depending on mode).
+        Returns:
+            tuple: (output_tensor, loss_mask) where output_tensor contains model output
+            and loss_mask is the CP-sliced mask for consistent loss computation.
```
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/gemma_vl/modeling_gemma3_vl.py` around lines 121 -
124, The method whose signature includes "packed_seq_params:
Optional['PackedSeqParams'] = None, *, loss_mask: Optional[Tensor] = None"
returns a tuple (outputs, loss_mask) but is annotated "-> Tensor"; update the
return type to reflect the actual return (for example "-> Tuple[Tensor,
Optional[Tensor]]"), import Tuple from typing as needed, and update the method
docstring to describe the tuple return (outputs, loss_mask); apply the same
change to the other overloaded/duplicate signature in the file that has the same
parameters.
```diff
     @lru_cache(maxsize=32)
-    def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
+    def forward(
+        self,
+        max_seq_len: int,
+        offset: int = 0,
+        packed_seq: bool = False,
+        cp_group: Optional[torch.distributed.ProcessGroup] = None,
+    ) -> Tensor:
         """Get global and local rope embedding"""
-        rope_global = super().forward(max_seq_len, offset, packed_seq)
-        rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq)
+        rope_global = super().forward(max_seq_len, offset, packed_seq, cp_group)
+        rope_local = self.rope_local.forward(max_seq_len, offset, packed_seq, cp_group)
         return rope_local, rope_global
```
🧩 Analysis chain
🏁 Script executed:
```shell
# First, check the actual file to understand context
head -n 400 src/megatron/bridge/models/gemma/gemma3_provider.py | tail -n 100
```
🌐 Web query:
torch.distributed.ProcessGroup hashable lru_cache Python
💡 Result:
functools.lru_cache can only cache calls whose positional/keyword arguments are hashable, because it uses the arguments as keys in an internal dict. If you pass a torch.distributed.distributed_c10d.ProcessGroup and it raises TypeError: unhashable type: 'ProcessGroup', that’s why. [1]
Practical patterns
1) Cache by a hashable key derived from the group
Use a wrapper that turns the group into a stable, hashable identifier (often process-local):
```python
import torch.distributed as dist
from functools import lru_cache

def _pg_key(pg):
    # process-local identity + backend name (both hashable)
    return (id(pg), dist.get_backend(pg))

@lru_cache(maxsize=None)
def expensive_for_group(pg_id_backend):
    pg_id, backend = pg_id_backend
    ...  # compute

def expensive(pg):
    return expensive_for_group(_pg_key(pg))
```
dist.get_backend(group) is a supported way to query a process group’s backend. [2]
2) Don’t use lru_cache: use an explicit dict keyed by id(pg)
If you need more control (cleanup, limits, etc.), do:
```python
_cache = {}

def expensive(pg):
    k = id(pg)
    if k not in _cache:
        _cache[k] = ...
    return _cache[k]
```
3) Cache by group-defining parameters instead of the object
If your “expensive” result is really determined by (say) (backend, world_size, ranks_tuple, options...), build a tuple key from those parameters and cache on that (tuples are hashable).
Sources
- Python functools.lru_cache requires args to be hashable. [1]
- PyTorch distributed APIs treat ProcessGroup as an opaque handle and provide get_backend(group) for it. [2]
[1] https://docs.python.org/3/library/functools.html
[2] https://docs.pytorch.org/docs/stable/distributed
🏁 Script executed:
```shell
# Search for other uses of ProcessGroup in the codebase
rg "ProcessGroup" -A 3 -B 3 src/ | head -50
```
🏁 Script executed:
```shell
# Check if there are other uses of lru_cache with ProcessGroup or similar parameters in the codebase
rg "@lru_cache" -A 10 src/megatron/bridge/ | grep -A 10 "ProcessGroup"
```
🏁 Script executed:
```shell
# Check if cp_group is ever passed as non-None in the codebase
rg "\.forward\(" src/megatron/bridge/models/gemma/ -A 2 -B 2 | grep -i "cp_group"
```
🏁 Script executed:
```shell
# Check if this method is called anywhere and how
rg "Gemma3RotaryEmbedding" -A 5 -B 5 src/
```
🏁 Script executed:
```shell
# Check the parent class RotaryEmbedding to understand the signature
rg "class RotaryEmbedding" -A 20 src/
```
🏁 Script executed:
```shell
# Find where RotaryEmbedding is imported from
rg "from.*RotaryEmbedding|import.*RotaryEmbedding" src/megatron/bridge/models/gemma/
```
🏁 Script executed:
```shell
# Check the imports at the top of gemma3_provider.py
head -50 src/megatron/bridge/models/gemma/gemma3_provider.py
```
🏁 Script executed:
```shell
# Search for RotaryEmbedding definition across the repository
fd "\.py$" src/ | xargs grep -l "class RotaryEmbedding" 2>/dev/null
```
🏁 Script executed:
```shell
# Search within megatron.core for RotaryEmbedding definition
find . -name "*.py" -type f -exec grep -l "class RotaryEmbedding" {} \; 2>/dev/null
```
🏁 Script executed:
```shell
# Check if the parent RotaryEmbedding is in external megatron.core (not in this repo)
# Let's check what calls Gemma3RotaryEmbedding.forward to see if cp_group is ever actually used
rg "rotary_pos_emb.*forward\(" src/megatron/bridge/ -A 3 -B 3
```
🏁 Script executed:
```shell
# Check if there are any actual calls to Gemma3RotaryEmbedding.forward with cp_group != None
rg "cp_group\s*=" src/megatron/bridge/models/gemma/ -B 5 -A 5
```
🏁 Script executed:
```shell
# Let's check the actual signature of the parent RotaryEmbedding.forward to see if it also accepts cp_group
# Since it's from megatron.core, let's see if there are any hints in usage
rg "super\(\)\.forward\(" src/megatron/bridge/models/gemma/ -B 3 -A 3
```
🏁 Script executed:
```shell
# Check the actual parent class to understand the full signature
# Since RotaryEmbedding is from megatron.core, let's see if there are tests or examples
rg "RotaryEmbedding" src/megatron/bridge/ -A 5 | head -60
```
🏁 Script executed:
```shell
# Check if the typing convention is violated - Optional vs | None
rg "Optional\[torch.distributed.ProcessGroup\]" src/megatron/bridge/
```
🏁 Script executed:
```shell
# Verify the exact line numbers and full context of the method
sed -n '365,385p' src/megatron/bridge/models/gemma/gemma3_provider.py
```
Fix @lru_cache incompatibility with ProcessGroup parameter and update type hint.
lru_cache requires hashable arguments, but torch.distributed.ProcessGroup is unhashable. Passing a non-None cp_group will raise TypeError at runtime. Additionally, the type hint uses Optional[...] instead of the PEP 604 union syntax required by coding guidelines.
Either: (1) remove the cp_group parameter if it's not needed, (2) bypass caching when cp_group is provided, or (3) derive a hashable key from the group before caching. Also change Optional[torch.distributed.ProcessGroup] to torch.distributed.ProcessGroup | None.
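A runnable sketch of options (2)/(3) using a stand-in base class; the class and helper names below are illustrative, not the actual code in gemma3_provider.py:
```python
from functools import lru_cache


class _RopeBase:
    """Stand-in for the upstream rotary-embedding base class (illustration only)."""

    def forward(self, max_seq_len, offset=0, packed_seq=False, cp_group=None):
        return ("rope", max_seq_len, offset, packed_seq, cp_group is not None)


class Gemma3RopeSketch(_RopeBase):
    def forward(self, max_seq_len, offset=0, packed_seq=False, cp_group=None):
        if cp_group is not None:
            # ProcessGroup is unhashable, so bypass the cache on this path.
            return super().forward(max_seq_len, offset, packed_seq, cp_group)
        return self._forward_cached(max_seq_len, offset, packed_seq)

    @lru_cache(maxsize=32)  # only hashable args ever reach the cache
    def _forward_cached(self, max_seq_len, offset=0, packed_seq=False):
        return super().forward(max_seq_len, offset, packed_seq)


rope = Gemma3RopeSketch()
rope.forward(128)                     # cached path
rope.forward(128, cp_group=object())  # uncached path, no TypeError
```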
🧰 Tools
🪛 Ruff (0.14.14)
370-370: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
(B019)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/gemma/gemma3_provider.py` around lines 370 - 381,
The forward method is using `@lru_cache` but accepts an unhashable
torch.distributed.ProcessGroup and uses Optional[...] typing; change the
signature to use PEP 604 union (torch.distributed.ProcessGroup | None) and avoid
passing the ProcessGroup into the cached path: implement a small uncached branch
in forward that, when cp_group is not None, directly calls super().forward(...)
and self.rope_local.forward(...) and returns rope_local, rope_global; otherwise
use the cached logic (keep `@lru_cache` on a helper that only accepts hashable
args like max_seq_len, offset, packed_seq) so caching is only used for hashable
parameters. Ensure you reference the existing forward method and
rope_local/rope_global calls when making the change.
```python
        packed_seq_params: Optional["PackedSeqParams"] = None,
        *,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:
```
Return type annotation inconsistent with actual return value.
Same issue as in modeling_gemma3_vl.py - the method returns a tuple (outputs, loss_mask) but the annotation is -> Tensor.
Proposed fix
```diff
-    ) -> Tensor:
+    ) -> tuple[Tensor, Tensor | None]:
```
Also update the docstring at lines 212-213 to reflect the tuple return.
Also applies to: 271-272
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/ministral3/modeling_ministral3.py` around lines
195 - 198, The function whose signature includes the parameters
packed_seq_params: Optional["PackedSeqParams"] = None and *, loss_mask:
Optional[Tensor] = None is annotated as returning -> Tensor but actually returns
a tuple (outputs, loss_mask); update the return type annotation to something
like -> Tuple[Tensor, Optional[Tensor]] (import Tuple if needed) and update the
docstring (the lines describing the return value) to reflect the tuple form;
apply the same fix to the other occurrence that mirrors lines 271-272 so both
methods consistently annotate and document returning (outputs, loss_mask).
```python
    if len(valid_sequences) == 0:
        # No valid sequences, return dummy packed batch
        logger.warning("No valid sequences found in batch, skipping packing")
        return (
            tokens[:1],  # Return first sequence as-is
            labels[:1],
            loss_mask[:1],
            attention_mask,
            position_ids[:1],
            torch.tensor([0, seq_len], dtype=torch.int32, device=device),
            torch.tensor(seq_len, dtype=torch.int32, device=device),
        )
```
Edge case: All-padding batch returns potentially invalid data.
When all sequences are padding, the fallback returns tokens[:1] (which is all padding) with cu_seqlens=[0, seq_len]. This indicates a valid sequence of length seq_len, which is inconsistent with the actual empty content. Downstream loss computation might not handle this correctly.
Consider returning a zero-length packed batch or explicitly handling this case in loss computation:
if len(valid_sequences) == 0:
# No valid sequences, return dummy packed batch
logger.warning("No valid sequences found in batch, skipping packing")
+ # Return empty packed tensors with cu_seqlens indicating no content
return (
- tokens[:1], # Return first sequence as-is
- labels[:1],
- loss_mask[:1],
+ tokens[:, :0], # Empty tensor
+ labels[:, :0],
+ loss_mask[:, :0],
attention_mask,
- position_ids[:1],
- torch.tensor([0, seq_len], dtype=torch.int32, device=device),
- torch.tensor(seq_len, dtype=torch.int32, device=device),
+ position_ids[:, :0],
+ torch.tensor([0], dtype=torch.int32, device=device), # Empty cu_seqlens
+ torch.tensor(0, dtype=torch.int32, device=device),
)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/vlm_step.py` around lines 151 - 162, The current
fallback for all-padding batches returns tokens[:1] etc. and cu_seqlens as
torch.tensor([0, seq_len]) which falsely signals a non-empty sequence; update
the all-padding branch in vlm_step.py (the block using valid_sequences, tokens,
labels, loss_mask, attention_mask, position_ids and the torch.tensor(...)
cu_seqlens/total_len) to return a consistent zero-length packed batch: produce
empty tensors for tokens/labels/loss_mask/position_ids with correct
dtypes/devices, keep attention_mask shaped for the original batch if needed, and
set cu_seqlens to [0,0] and total_len to 0 (or otherwise signal length 0) so
downstream loss code sees truly empty content and does not assume seq_len.
```python
        # attention_mask if present
        attn = batch.get("attention_mask")
    if attn is not None:
        attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
        batch["attention_mask"] = attn  # type: ignore[assignment]
```
Bug: Variable scope issue with attn handling.
The attn variable is defined inside the if tokens_or_input is not None: block (line 300), but the if attn is not None: check at line 301 appears to be at the wrong indentation level (outside the block). This will cause an UnboundLocalError when tokens_or_input is None.
Proposed fix
```diff
         # attention_mask if present
         attn = batch.get("attention_mask")
-    if attn is not None:
-        attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
-        batch["attention_mask"] = attn  # type: ignore[assignment]
+        if attn is not None:
+            attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
+            batch["attention_mask"] = attn  # type: ignore[assignment]
```
+ batch["attention_mask"] = attn # type: ignore[assignment]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
        # attention_mask if present
        attn = batch.get("attention_mask")
        if attn is not None:
            attn = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)
            batch["attention_mask"] = attn  # type: ignore[assignment]
```
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/vlm_step.py` around lines 299 - 303, The attn
variable is defined only when tokens_or_input is not None but the subsequent if
attn is not None: check is dedented causing a potential UnboundLocalError; fix
by ensuring attn is always defined before the check (e.g., initialize attn =
batch.get("attention_mask") or move the if attn is not None: block inside the
tokens_or_input branch) and then call pad_or_truncate_attn_to_len(attn,
target_len, seq_cap) and assign back to batch["attention_mask"]; locate uses of
attn, tokens_or_input, batch, and pad_or_truncate_attn_to_len in vlm_step.py to
make the adjustment.
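For reference, a self-contained repro of the scoping bug described above and the simplest fix (bind attn before the branch); pad_or_truncate_attn_to_len is stubbed here and the real function in vlm_step.py behaves differently:
```python
def pad_or_truncate_attn_to_len(attn, target_len, seq_cap):
    return attn[:target_len]  # stub for illustration only


def buggy(batch, tokens_or_input, target_len=4, seq_cap=8):
    if tokens_or_input is not None:
        attn = batch.get("attention_mask")
    if attn is not None:  # UnboundLocalError when tokens_or_input is None
        batch["attention_mask"] = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)


def fixed(batch, tokens_or_input, target_len=4, seq_cap=8):
    attn = batch.get("attention_mask")  # always bound before the check
    if attn is not None:
        batch["attention_mask"] = pad_or_truncate_attn_to_len(attn, target_len, seq_cap)


fixed({"attention_mask": [1, 1, 1, 1, 0, 0]}, tokens_or_input=None)  # runs fine
```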
```python
        model_output = model(**forward_args)
        # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CPI'm
        if isinstance(model_output, tuple):
            output_tensor, loss_mask = model_output
        else:
            output_tensor = model_output
```
Minor: Typo in comment.
Line 426 has a typo: "CPI'm" should be "CP".
Fix typo
```diff
-        # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CPI'm
+        # Handle tuple return: (output_tensor, sliced_loss_mask) from VLM models with CP
```
🤖 Prompt for AI Agents
In `@src/megatron/bridge/training/vlm_step.py` around lines 425 - 430, The inline
comment above the model output handling in vlm_step.py has a typo: change
"CPI'm" to "CP" in the comment that describes tuple returns from VLM models with
CP; update the comment near the model_output handling (variables/functions:
model_output, output_tensor, loss_mask, model(**forward_args)) so it reads "from
VLM models with CP" instead of "CPI'm".
```python
    if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
        import transformer_engine_torch as tex

        cu_seqlens = packed_seq_params.cu_seqlens_q
        cu_seqlens_padded = (
            packed_seq_params.cu_seqlens_q_padded if packed_seq_params.cu_seqlens_q_padded is not None else cu_seqlens
        )
        seq_len = inputs_embeds.size(1)

        index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank)

        # Slice all tensors using THD indices
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.index_select(1, index)
        if labels is not None:
```
Guard THD slicing when inputs_embeds can be absent.
The THD path unconditionally reads inputs_embeds.size(1) and indexes it. If a caller supplies inputs_embeds=None, this will raise at runtime. Consider asserting non‑None (or deriving seq_len from another tensor) before computing THD indices.
🛠️ Possible guard
- seq_len = inputs_embeds.size(1)
+ if inputs_embeds is None:
+ raise ValueError("inputs_embeds is required for THD CP slicing")
+ seq_len = inputs_embeds.size(1)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    if packed_seq_params is not None and packed_seq_params.qkv_format == "thd":
        import transformer_engine_torch as tex

        cu_seqlens = packed_seq_params.cu_seqlens_q
        cu_seqlens_padded = (
            packed_seq_params.cu_seqlens_q_padded if packed_seq_params.cu_seqlens_q_padded is not None else cu_seqlens
        )
        if inputs_embeds is None:
            raise ValueError("inputs_embeds is required for THD CP slicing")
        seq_len = inputs_embeds.size(1)

        index = tex.thd_get_partitioned_indices(cu_seqlens_padded, seq_len, cp_size, cp_rank)

        # Slice all tensors using THD indices
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.index_select(1, index)
        if labels is not None:
```
🤖 Prompt for AI Agents
In `@src/megatron/bridge/utils/common_utils.py` around lines 319 - 333, The THD
branch assumes inputs_embeds is present by calling inputs_embeds.size(1) and
indexing it; guard against inputs_embeds being None before computing seq_len and
calling tex.thd_get_partitioned_indices. Update the block around
packed_seq_params.qkv_format == "thd" (references: packed_seq_params,
tex.thd_get_partitioned_indices, inputs_embeds) to either assert inputs_embeds
is not None or compute seq_len from an alternative tensor (e.g., attention_mask
or labels) and only call thd_get_partitioned_indices and index_select when
inputs_embeds is non-None.
```python
    def _get_llama_4_attn_scale(
        self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
    ) -> torch.Tensor:
        """Reimplementation of the function for testing."""
        scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
        num_dims_to_add = len(query_shape) - 1
        for _ in range(num_dims_to_add):
            scaling = scaling.unsqueeze(-1)
        return scaling
```
Test duplicates production logic instead of calling it.
This helper re-implements _get_llama_4_attn_scale, so regressions in the real method could still pass. Prefer invoking the production implementation and keep only the expected-value checks in tests.
🔧 Suggested change
- def _get_llama_4_attn_scale(
- self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
- ) -> torch.Tensor:
- """Reimplementation of the function for testing."""
- scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
- num_dims_to_add = len(query_shape) - 1
- for _ in range(num_dims_to_add):
- scaling = scaling.unsqueeze(-1)
- return scaling
+ def _get_llama_4_attn_scale(
+ self, positions_ids: torch.Tensor, beta: float, max_position_embeddings: int, query_shape: tuple
+ ) -> torch.Tensor:
+ """Call the production implementation for testing."""
+ provider = Ministral3ModelProvider(
+ num_layers=26,
+ hidden_size=3072,
+ num_attention_heads=32,
+ )
+ return provider._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)
🤖 Prompt for AI Agents
In `@tests/unit_tests/models/ministral3/test_ministral3_provider.py` around lines
186 - 194, Test currently duplicates production logic by reimplementing
_get_llama_4_attn_scale; replace this helper with a direct call to the
production function (e.g., import and call the real _get_llama_4_attn_scale used
by the model) and keep the test only asserting expected output values. Locate
the helper method in
tests/unit_tests/models/ministral3/test_ministral3_provider.py, remove the
reimplementation, call the production _get_llama_4_attn_scale with the same
parameters (positions_ids, beta, max_position_embeddings, query_shape), and
assert the returned scaling tensor matches expected values.
```python
    def test_gpu_tensor_support(self):
        """Test that the function works with GPU tensors if available."""
        if not torch.cuda.is_available():
            return  # Skip test if no GPU

        positions_ids = torch.arange(8, device="cuda")
        beta = 0.1
        max_position_embeddings = 1024
        query_shape = (8, 4, 32)

        scaling = self._get_llama_4_attn_scale(positions_ids, beta, max_position_embeddings, query_shape)

        assert scaling.device.type == "cuda"
        assert scaling.shape == (8, 1, 1)
```
🧩 Analysis chain
🏁 Script executed:
```shell
# Check if the file exists and read the test method
find . -path "*tests/unit_tests/models/ministral3/test_ministral3_provider.py" -type f
```
🏁 Script executed:
```shell
# List files in the ministral3 test directory if it exists
fd "test_ministral3" tests/
```
🏁 Script executed:
```shell
# Search more broadly for the test file
fd "test_ministral3_provider" .
```
🏁 Script executed:
```shell
# Read the test file to verify the current state and imports
wc -l ./tests/unit_tests/models/ministral3/test_ministral3_provider.py
```
🏁 Script executed:
```shell
# Check the imports at the top of the file
head -30 ./tests/unit_tests/models/ministral3/test_ministral3_provider.py
```
🏁 Script executed:
```shell
# Read around the test method (lines 310-340)
sed -n '310,340p' ./tests/unit_tests/models/ministral3/test_ministral3_provider.py
```
🏁 Script executed:
```shell
# Check if pytest is imported anywhere in the file
grep -n "import pytest" ./tests/unit_tests/models/ministral3/test_ministral3_provider.py
```
🏁 Script executed:
```shell
# Check for pytestmark at module level
grep -n "pytestmark" ./tests/unit_tests/models/ministral3/test_ministral3_provider.py
```
🏁 Script executed:
```shell
# Check for any existing pytest.mark usage in the file
grep -n "@pytest.mark" ./tests/unit_tests/models/ministral3/test_ministral3_provider.py
```
🏁 Script executed:
```shell
# Read the full imports section to understand what's imported
sed -n '1,25p' ./tests/unit_tests/models/ministral3/test_ministral3_provider.py
```
Add pytest markers and explicit GPU requirement for the CUDA test.
The file lacks unit test categorization and the GPU test uses an early return instead of a pytest skip marker, which doesn't document the hardware requirement. Mark the test with @pytest.mark.gpu and use @pytest.mark.skipif() to explicitly document the CUDA dependency, and add module-level pytestmark for unit test categorization.
✅ Suggested change
```diff
+import pytest
 import torch

 from megatron.bridge.models.ministral3.ministral3_provider import (
@@ -17,6 +18,8 @@
     Ministral3ModelProvider14B,
 )

+pytestmark = pytest.mark.unit
+
 class TestMinistral3ModelProvider:
     """Test cases for Ministral3ModelProvider base class."""
@@ -315,8 +318,9 @@ class TestMinistral3ModelProvider:
         result_4d = query_4d * scaling_4d.to(query_4d.dtype)
         assert result_4d.shape == query_4d.shape

+    @pytest.mark.gpu
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
     def test_gpu_tensor_support(self):
         """Test that the function works with GPU tensors if available."""
-        if not torch.cuda.is_available():
-            return  # Skip test if no GPU
```
In `@tests/unit_tests/models/ministral3/test_ministral3_provider.py` around lines
318 - 331, Add pytest markers and an explicit CUDA skip to the GPU test:
annotate the test_gpu_tensor_support function with `@pytest.mark.gpu` and
`@pytest.mark.skipif`(not torch.cuda.is_available(), reason="requires CUDA")
instead of returning early, and add a module-level pytestmark =
[pytest.mark.unit] to categorize the file; update imports if necessary to
include pytest at top of the test file so decorators resolve.
```python
    def test_packed_seq_params_without_unpadded_fallback(self):
        """Test fallback to cu_seqlens when cu_seqlens_unpadded is not provided."""
        batch = {
            "cu_seqlens": torch.tensor([[0, 5, 10, 15, -1]], dtype=torch.int32),
            "max_seqlen": torch.tensor([[8]], dtype=torch.int32),
        }

        result = get_packed_seq_params(batch)

        expected_cu_seqlens = torch.tensor([0, 5, 10, 15], dtype=torch.int32)

        # Without unpadded, q/kv should use padded values
        assert torch.equal(result.cu_seqlens_q, expected_cu_seqlens)
        assert torch.equal(result.cu_seqlens_kv, expected_cu_seqlens)

        # Padded fields should match q/kv
        assert torch.equal(result.cu_seqlens_q_padded, expected_cu_seqlens)
        assert torch.equal(result.cu_seqlens_kv_padded, expected_cu_seqlens)
```
Assertions for *_padded in no-unpadded cases don’t match current behavior.
get_packed_seq_params only sets cu_seqlens_*_padded when cu_seqlens_unpadded is present; otherwise these fields are typically None. These tests will fail (or codify a new requirement) unless you update the implementation to always populate the padded fields.
🔧 Suggested change (align tests with current behavior)
- # Padded fields should match q/kv
- assert torch.equal(result.cu_seqlens_q_padded, expected_cu_seqlens)
- assert torch.equal(result.cu_seqlens_kv_padded, expected_cu_seqlens)
+ # No unpadded input → padded fields are unset
+ assert result.cu_seqlens_q_padded is None
+ assert result.cu_seqlens_kv_padded is None
@@
- assert torch.equal(result.cu_seqlens_q_padded, expected)
+ assert result.cu_seqlens_q_padded is None
Also applies to: 248-260
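A rough sketch of the behavior these assertions are being aligned with, i.e. the *_padded fields only populated when an unpadded variant exists; this is an illustration, not the real get_packed_seq_params:
```python
# Illustrative stand-in only; the field handling mirrors the description above.
import torch
from megatron.core.packed_seq_params import PackedSeqParams


def sketch_get_packed_seq_params(batch: dict) -> PackedSeqParams:
    cu_seqlens = batch["cu_seqlens"][0]
    cu_seqlens = cu_seqlens[: int((cu_seqlens >= 0).sum().item())]  # drop -1 padding entries
    unpadded = batch.get("cu_seqlens_unpadded")
    if unpadded is not None:
        # Only this branch fills the *_padded fields.
        return PackedSeqParams(
            qkv_format="thd",
            cu_seqlens_q=unpadded[0],
            cu_seqlens_kv=unpadded[0],
            cu_seqlens_q_padded=cu_seqlens,
            cu_seqlens_kv_padded=cu_seqlens,
        )
    # No unpadded input: the padded fields keep their default (None).
    return PackedSeqParams(qkv_format="thd", cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens)
```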
🤖 Prompt for AI Agents
In `@tests/unit_tests/training/test_gpt_step.py` around lines 229 - 246, The test
assumes get_packed_seq_params populates cu_seqlens_q_padded and
cu_seqlens_kv_padded when cu_seqlens_unpadded is missing, but the implementation
leaves those padded fields as None; update the failing tests (e.g.,
test_packed_seq_params_without_unpadded_fallback and the subsequent similar test
in the same file) to assert that result.cu_seqlens_q_padded and
result.cu_seqlens_kv_padded are None (instead of comparing to
expected_cu_seqlens), or alternatively change get_packed_seq_params to always
populate the *_padded fields—choose the test update approach to align tests with
current behavior.
Signed-off-by: Chen Cui <[email protected]>
What does this PR do?
Support THD Training in VLMs
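For readers unfamiliar with the format, here is a small, repo-independent sketch of what THD ("total tokens, heads, head dim") sequence packing means: variable-length sequences are concatenated along the token dimension and their boundaries are tracked with cumulative lengths (cu_seqlens), which is what PackedSeqParams carries. The helper below is illustrative only, not code from this PR:
```python
import torch


def pack_thd(sequences: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Concatenate variable-length sequences; return (packed_tokens, cu_seqlens)."""
    lengths = torch.tensor([len(s) for s in sequences], dtype=torch.int32)
    cu_seqlens = torch.zeros(len(sequences) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    return torch.cat(sequences), cu_seqlens


packed, cu_seqlens = pack_thd([torch.arange(5), torch.arange(3), torch.arange(7)])
# packed.shape == (15,) and cu_seqlens == tensor([0, 5, 8, 15]); attention kernels
# use these offsets so tokens never attend across sequence boundaries.
```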
TODOs
- GLM 4.5 V (after this PR)
- Qwen 2.5 VL (after this PR)
- Qwen 3 VL (WIP in [WIP] Support qwen3-vl for THD format and CP #1943)
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Additional Information
Gemma 3 VL:
Loss and grad norm are both matching closely
Ministral 3:
Not sure why grad norm is different between cp1 and cp2. It might be a display issue. Will resolve in next PR. Loss is matching very closely.
More plots to come
Summary by CodeRabbit
New Features
Refactor
Tests