
[model, refactor] refactor: Centralize provider_bridge config mapping in base class for VLM models #2250

Open
yaoyu-33 wants to merge 40 commits into main from feature/provider-bridge-refactor-3

Conversation


@yaoyu-33 yaoyu-33 commented Feb 5, 2026

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Changelog

  • Add specific line by line info of high level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

  • Bug Fixes

    • Removed generation_config propagation from model provider bridges to prevent configuration conflicts.
  • New Features

    • Added squared_relu activation function support for enhanced model conversion.
  • Refactor

    • Standardized model bridge registration patterns across provider implementations.
    • Refactored provider inheritance hierarchies for improved consistency.
    • Enhanced model-to-provider configuration mapping for Nemotron, Qwen, and Gemma model variants.
  • Tests

    • Updated test coverage to reflect provider and bridge architecture changes.

yaoyu-33 and others added 30 commits January 23, 2026 09:29
This refactoring centralizes model-specific configurations within the
provider_bridge method of each model bridge.

Changes:
- Add MoE-related field mappings to base class CONFIG_MAPPING (see the sketch after this commit message):
  - num_experts -> num_moe_experts
  - num_experts_per_tok -> moe_router_topk
  - moe_intermediate_size -> moe_ffn_hidden_size

- Refactor LlamaBridge:
  - Use MEGATRON_DEFAULTS and HF_DEFAULTS class attributes
  - Override provider_bridge only for RoPE scaling (Llama 3.1/3.2)

- Refactor Qwen2Bridge:
  - Use MEGATRON_DEFAULTS (add_qkv_bias=True) and HF_DEFAULTS
  - No provider_bridge override needed

- Refactor Qwen3Bridge:
  - Use MEGATRON_DEFAULTS (qk_layernorm=True) and HF_DEFAULTS
  - No provider_bridge override needed

- Refactor Qwen3MoEBridge:
  - Use MEGATRON_DEFAULTS with MoE settings and HF_DEFAULTS
  - No provider_bridge override needed

- Update tests to expect GPTModelProvider instead of model-specific providers
- Add verification scripts for both Llama and Qwen bridges

Verified on remote server:
- Qwen/Qwen2-0.5B: PASS
- Qwen/Qwen2-7B: PASS
- Qwen/Qwen3-0.6B: PASS
- Qwen/Qwen3-1.7B: PASS
- Qwen/Qwen3-30B-A3B: PASS
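
A minimal sketch of the centralized pattern this commit message describes, assuming a base MegatronModelBridge with CONFIG_MAPPING, MEGATRON_DEFAULTS, and HF_DEFAULTS class attributes; only the three MoE mappings and the Qwen2 add_qkv_bias default are taken from the commit message, everything else is illustrative:

# Hedged sketch, not the actual Megatron-Bridge source: shows how a shared
# CONFIG_MAPPING plus per-bridge default dicts can replace per-model
# provider_bridge overrides.
class MegatronModelBridge:
    # (HF config attribute, Megatron provider field) pairs shared by all bridges
    CONFIG_MAPPING = [
        ("hidden_size", "hidden_size"),                    # illustrative entry
        ("num_experts", "num_moe_experts"),                # added in this commit
        ("num_experts_per_tok", "moe_router_topk"),        # added in this commit
        ("moe_intermediate_size", "moe_ffn_hidden_size"),  # added in this commit
    ]
    MEGATRON_DEFAULTS: dict = {}  # per-bridge Megatron-side defaults
    HF_DEFAULTS: dict = {}        # fallbacks when the HF config omits a field

    def hf_config_to_provider_kwargs(self, hf_config) -> dict:
        # Start from the bridge's Megatron defaults, then overlay mapped HF values.
        kwargs = dict(self.MEGATRON_DEFAULTS)
        for hf_name, megatron_name in self.CONFIG_MAPPING:
            value = getattr(hf_config, hf_name, self.HF_DEFAULTS.get(hf_name))
            if value is not None:
                kwargs[megatron_name] = value
        return kwargs


class Qwen2Bridge(MegatronModelBridge):
    # Per the commit message: only defaults differ, no provider_bridge override needed.
    MEGATRON_DEFAULTS = {"add_qkv_bias": True}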
…dels

- Add MLAModelProvider as unified base for Multi-Latent Attention models
- Refactor DeepSeek V2/V3 bridges to use MLAModelProvider
- Refactor Kimi K2 bridge to use MLAModelProvider
- Move model-specific defaults from providers to MEGATRON_DEFAULTS in bridges
- Add model_type parameter to @register_bridge decorator for auto HF config
- Simplify provider files to deprecated backward-compatible aliases

Verified: DeepSeek-V2-Lite, DeepSeek-V2, DeepSeek-V3, Moonlight-16B, Kimi-K2
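
A hedged sketch of the registration pattern described above; the PR only confirms that @register_bridge gained provider and model_type parameters, so the decorator signature, the bridge class name, and the defaults shown here are assumptions (the decorator is stubbed locally to keep the sketch self-contained):

def register_bridge(provider=None, model_type=None):
    # Stub of the decorator named in the commit message; the real one lives in
    # Megatron-Bridge and, per this PR, also accepts provider and model_type.
    def wrap(cls):
        cls.provider, cls.model_type = provider, model_type
        return cls
    return wrap


class MLAModelProvider:  # stand-in for the unified MLA provider added above
    pass


@register_bridge(provider=MLAModelProvider, model_type="deepseek_v3")
class DeepSeekV3Bridge:  # hypothetical name, for illustration only
    MEGATRON_DEFAULTS = {"qk_layernorm": True}  # invented default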
- Register GemmaModelProvider, Gemma2ModelProvider, Gemma3ModelProvider via decorator
- Add MEGATRON_DEFAULTS to Gemma/Gemma2 bridges for explicit config defaults
- Add gelu_pytorch_tanh -> fast_gelu to ACTIVATION_MAPPING in model_bridge.py
- Add verification script for Gemma provider refactoring

Verified: gemma-2b, gemma-7b, gemma-2-2b, gemma-2-9b, gemma-2-27b,
         gemma-3-4b-it, gemma-3-12b-it, gemma-3-27b-it
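
A self-contained sketch of the two ACTIVATION_MAPPING additions this PR describes (gelu_pytorch_tanh here, relu2 in a later commit); the local fast_gelu/squared_relu definitions and the dict shape are assumptions, only the key names and target activations come from the commit messages:

import torch


def fast_gelu(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fast_gelu the bridge maps to; tanh-approximate GELU,
    # which is what HF's "gelu_pytorch_tanh" denotes.
    return torch.nn.functional.gelu(x, approximate="tanh")


def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # Same definition quoted in the review discussion further down.
    return torch.pow(torch.nn.functional.relu(x), 2)


ACTIVATION_MAPPING = {
    # ...existing entries elided...
    "gelu_pytorch_tanh": fast_gelu,  # added for Gemma in this commit
    "relu2": squared_relu,           # added for NemotronH later in this PR
}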
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
OLMoE HF config doesn't have head_dim attribute, so kv_channels was
left as None. This fix calculates it as hidden_size // num_attention_heads
(2048 // 16 = 128 for OLMoE-1B-7B).

This follows the pattern used by MistralBridge and NemotronHBridge.
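
A minimal sketch of the fallback described above; the helper name is hypothetical and the real change lives in the OLMoE bridge's config handling:

def infer_kv_channels(hf_config) -> int:
    # OLMoE's HF config has no head_dim, so derive it from the hidden size:
    # 2048 // 16 = 128 for OLMoE-1B-7B, matching the Mistral/NemotronH pattern.
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is None:
        head_dim = hf_config.hidden_size // hf_config.num_attention_heads
    return head_dim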
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
yaoyu-33 and others added 10 commits February 4, 2026 19:34
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
# Conflicts:
#	src/megatron/bridge/models/conversion/model_bridge.py
- Remove generation_config from provider_kwargs in model_bridge.py
- Remove generation_config from test fixtures and mocks
- Remove test_provider_bridge_generation_config tests from:
  - test_qwen3_bridge.py
  - test_qwen3_moe_bridge.py
  - test_qwen3_next_bridge.py
  - test_qwen25_vl_bridge.py
- Remove test_provide_with_generation_config from test_gpt_provider.py
- Remove unused GenerationConfig imports from test files
- Remove test_provider_bridge_generation_config from gemma, gemma2, gemma3 bridges
- Remove test_provider_bridge_generation_config from mistral bridge
- Remove test_provider_bridge_generation_config from olmoe bridge
- Remove generation_config assertion from llama bridge test
- Add squared_relu import and 'relu2' to ACTIVATION_MAPPING in model_bridge.py
- Add hidden_act='relu2' to NemotronH test config dicts (required for base class)
- Remove generation_config from Nemotron and NemotronH test fixtures
- NemotronH: Use Mock(spec=[]) and MambaModelProvider
- NemotronH: Fix num_moe_experts assertion to accept 0 or None
- Nemotron: Use GPTModelProvider instead of NemotronModelProvider
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

copy-pr-bot bot commented Feb 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Feb 5, 2026

📝 Walkthrough

Walkthrough

This PR systematically removes generation_config propagation from HF pretrained models through provider bridges, refactors provider initialization to use base class helpers, adds provider and model_type parameters to bridge registrations, introduces squared_relu activation function usage, realigns some VL providers to inherit from GPTModelProvider, and removes corresponding test coverage across multiple model implementations.

Changes

Cohort / File(s) Summary
Activation Mapping & Common Changes
src/megatron/bridge/models/conversion/model_bridge.py, src/megatron/bridge/models/deepseek/common.py
Added squared_relu to activation mapping; removed generation_config propagation from HF configs.
Gemma3 VL Bridge
src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py
Extended bridge registration with provider and model_type parameters; refactored provider initialization using hf_config_to_provider_kwargs helper; added Gemma3-specific feature configurations (window_size, rotary_base, softmax_scale, rope_scaling_factor, VL tokens).
Nemotron Bridge Family
src/megatron/bridge/models/nemotron/nemotron_bridge.py, src/megatron/bridge/models/nemotronh/nemotron_h_bridge.py, src/megatron/bridge/models/nemotronh/__init__.py
Updated bridge registrations with provider and model_type; switched Nemotron to GPTModelProvider; replaced Nemotron-H provider from NemotronHModelProvider to MambaModelProvider; added Mamba-specific config mappings; introduced squared_relu activation; added NemotronHBridge to public exports.
Nemotron VL Bridge
src/megatron/bridge/models/nemotron_vl/nemotron_vl_bridge.py, src/megatron/bridge/models/nemotron_vl/nemotron_vl_provider.py
Extended bridge registration with provider/model_type; refactored provider initialization via hf_config_to_provider_kwargs; added VL-specific overrides (activation_func, softmax settings); removed generation_config field.
GPT Provider Base
src/megatron/bridge/models/gpt_provider.py
Removed generation_config field from GPTModelProvider.
Qwen3 Next Bridge
src/megatron/bridge/models/qwen/qwen3_next_bridge.py
Removed generation_config argument from provider initialization.
Qwen VL Bridges & Providers
src/megatron/bridge/models/qwen_vl/qwen25_vl_bridge.py, src/megatron/bridge/models/qwen_vl/qwen25_vl_provider.py, src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py, src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py, src/megatron/bridge/models/qwen_vl/__init__.py
Updated bridge registrations with provider/model_type; refactored Qwen25/Qwen3 VL provider initialization; realigned Qwen25VLModelProvider and Qwen3VLModelProvider inheritance to GPTModelProvider (from Qwen2ModelProvider and Qwen3ModelProvider respectively); removed VL-specific fields (head_dim, qk_layernorm); applied VL-specific attribute configurations post-instantiation; updated module import source for Qwen25VLModelProvider.
Llama & Mistral Bridges
src/megatron/bridge/models/llama_nemotron/llama_nemotron_bridge.py, src/megatron/bridge/models/mistral/mistral_bridge.py
Removed generation_config from provider initialization arguments.
Test Coverage Removals - Generation Config
tests/unit_tests/models/gemma/test_gemma_bridge.py, tests/unit_tests/models/gemma/test_gemma2_bridge.py, tests/unit_tests/models/gemma/test_gemma3_bridge.py, tests/unit_tests/models/llama/test_llama_bridge.py, tests/unit_tests/models/mistral/test_mistral_model_bridge.py, tests/unit_tests/models/olmoe/test_olmoe_bridge.py, tests/unit_tests/models/qwen/test_qwen3_bridge.py, tests/unit_tests/models/qwen/test_qwen3_moe_bridge.py, tests/unit_tests/models/qwen/test_qwen3_next_bridge.py, tests/unit_tests/models/test_gpt_provider.py
Removed test_provider_bridge_generation_config and related test methods validating generation_config propagation through provider bridges.
Test Coverage Removals - Nemotron H
tests/functional_tests/models/nemotronh/test_nemotron_h_provider.py
Entire test file deleted; previously contained provider comparison tests.
Test Updates - Provider Type Changes
tests/unit_tests/models/nemotron/test_nemotron_bridge.py, tests/unit_tests/models/nemotronh/test_nemotron_h_bridge.py
Updated provider type expectations from NemotronModelProvider/NemotronHModelProvider to GPTModelProvider/MambaModelProvider; removed GenerationConfig mocking and setup.
Test Updates - Config & Fixtures
tests/unit_tests/models/gemma_vl/test_gemma3_vl_bridge.py, tests/unit_tests/models/nemotron_vl/test_nemotron_vl_bridge.py, tests/unit_tests/models/qwen_vl/test_qwen25_vl_bridge.py, tests/unit_tests/models/qwen_vl/test_qwen25_vl_provider.py
Updated mock configs with MLA-specific fields and hidden_act: "relu2"; replaced dtype-handling patch-based tests with hardcoded bf16 assertions; updated provider inheritance expectations; refined fixture initialization.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • Megatron-Bridge#2052: Directly related as this PR reverses the generation_config injection introduced in PR #2052 and updates the same MegatronModelBridge activation mapping.
  • Megatron-Bridge#1914: Related through overlapping changes to nemotron/nemotronh bridge and provider code paths, with modifications to structures introduced in that PR.

Suggested labels

Run CICD

Suggested reviewers

  • ananthsub
  • ko3n1g
  • yashaswikarnati
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
  • Docstring Coverage — ⚠️ Warning: Docstring coverage is 76.47%, which is insufficient; the required threshold is 80.00%. Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.
  • Test Results For Major Changes — ⚠️ Warning: PR contains significant provider architecture refactoring but lacks documented test results, metrics, or regression validation evidence. Resolution: Add a test results section with documented validation across modified model variants and confirmation that no regressions were introduced.
✅ Passed checks (2 passed)
  • Description Check — ✅ Passed: Check skipped - CodeRabbit's high-level summary is enabled.
  • Title check — ✅ Passed: The title '[model, refactor] refactor: Centralize provider_bridge config mapping in base class for VLM models' clearly identifies the main change: centralizing provider_bridge configuration mapping in the base class, specifically for VLM models.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py`:
- Around line 79-83: The default image_token_id used when loading hf_config is
incorrect: update the assignment for provider.image_token_id in
gemma3_vl_bridge.py to use Gemma3VLModelProvider's default of 262144 instead of
Qwen's 151655; specifically, change the getattr call on hf_config for
"image_token_id" so its fallback value is 262144 to match Gemma3VLModelProvider
(see provider.image_token_id and Gemma3VLModelProvider).

In `@src/megatron/bridge/models/nemotronh/nemotron_h_bridge.py`:
- Around line 87-88: The guard that checks MoE-specific defaults uses
`hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts > 0`
which will raise if `n_routed_experts` is present but None; update the condition
in `nemotron_h_bridge.py` to explicitly check for non-None (e.g.,
`getattr(hf_config, "n_routed_experts", None) is not None and
hf_config.n_routed_experts > 0` or test `hf_config.n_routed_experts is not None
and hf_config.n_routed_experts > 0`) so the MoE branch only runs when
`n_routed_experts` is a valid number. Ensure you modify the condition where
`hasattr(hf_config, "n_routed_experts")` is currently used.
- Around line 57-70: The CONFIG_MAPPING extension concatenates lists directly
causing Ruff RUF005; instead unpack MegatronModelBridge.CONFIG_MAPPING into a
new list so use iterable unpacking when building CONFIG_MAPPING (replace
MegatronModelBridge.CONFIG_MAPPING + [...] with
[*MegatronModelBridge.CONFIG_MAPPING, (...)]), keeping the same tuple entries
(e.g., "mamba_head_dim", "mamba_num_heads", "n_groups", "ssm_state_size",
"hybrid_override_pattern", "residual_in_fp32", "use_bias", "layer_norm_epsilon",
"moe_shared_expert_intermediate_size") to remove the linter warning while
preserving behavior.

In `@src/megatron/bridge/models/qwen_vl/qwen3_vl_bridge.py`:
- Around line 77-104: The assignment to provider.mrope_section can raise
TypeError when text_config.rope_scaling is None; update the provider_bridge
construction to safely read mrope_section by using a safe fallback like
(getattr(text_config, "rope_scaling") or {}).get("mrope_section", [24,20,20]) or
by checking if text_config.rope_scaling is truthy before calling .get; modify
the line that sets provider.mrope_section accordingly so it never calls .get on
None (reference symbols: provider.mrope_section, text_config.rope_scaling,
hf_config_to_provider_kwargs/provider_bridge).
🧹 Nitpick comments (5)
src/megatron/bridge/models/nemotron/nemotron_bridge.py (2)

29-31: Duplicate squared_relu definition - consider importing from existing location.

This function is already defined in:

  • src/megatron/bridge/models/nemotron/nemotron_provider.py (lines 26-28)
  • megatron.core.activations (as imported in model_bridge.py)

Importing from megatron.core.activations would reduce duplication and maintain consistency with the base class's ACTIVATION_MAPPING.

♻️ Proposed fix
-def squared_relu(x):
-    """Squared ReLU activation function."""
-    return torch.pow(torch.nn.functional.relu(x), 2)
+from megatron.core.activations import squared_relu

52-55: Consider using iterable unpacking for CONFIG_MAPPING extension.

Static analysis (RUF005) suggests using unpacking syntax for better readability.

♻️ Proposed fix
-    CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
+    CONFIG_MAPPING = [
+        *MegatronModelBridge.CONFIG_MAPPING,
         # Nemotron uses norm_eps instead of rms_norm_eps
         ("norm_eps", "layernorm_epsilon"),
     ]
src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py (1)

68-72: Hardcoded bf16 dtype may limit hardware compatibility.

The hardcoded bf16=True overrides dtype detection from HF config. This will fail on hardware without bfloat16 support (e.g., older GPUs). If this is intentional to match a specific baseline, consider documenting the hardware requirement or adding a fallback.
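
One possible shape for such a fallback, as a sketch only (torch.cuda.is_bf16_supported is a standard PyTorch API; whether the bridge should auto-detect rather than hardcode the dtype is the open question above):

import torch


def select_compute_dtype() -> dict:
    # Prefer bf16 when the GPU supports it, otherwise fall back to fp16.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    return {
        "bf16": use_bf16,
        "fp16": not use_bf16,
        "params_dtype": torch.bfloat16 if use_bf16 else torch.float16,
    }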

src/megatron/bridge/models/qwen_vl/qwen3_vl_provider.py (2)

36-43: Outdated docstring references incorrect parent class.

The docstring states "Inherits language model configuration from Qwen3ModelProvider" but the class now inherits from GPTModelProvider. Update the docstring to reflect the actual inheritance.

📝 Proposed fix
 class Qwen3VLModelProvider(GPTModelProvider):
     """
     Base model provider for Qwen 3 VL Models.
-    Inherits language model configuration from Qwen3ModelProvider.
+    Inherits language model configuration from GPTModelProvider.
 
     Note: num_query_groups in parent class corresponds to num_key_value_heads in HF config.
     Default value of 8 is used for GQA (Grouped Query Attention).
     """

136-157: Outdated docstring in Qwen3VLMoEModelProvider.

Similar to Qwen3VLModelProvider, the docstring references are now outdated. The class inherits from GPTModelProvider, not Qwen3MoEModelProvider.

📝 Proposed fix
 class Qwen3VLMoEModelProvider(GPTModelProvider):
     """
     Base model provider for Qwen 3 VL MoE (Mixture of Experts) Models.
 
     This provider inherits directly from GPTModelProvider following the
     provider_bridge refactoring pattern. It includes:
-    - Qwen3 MoE-specific LLM defaults (RMSNorm, gated linear unit, QK layernorm, MoE config)
+    - Qwen3-VL MoE-specific LLM defaults (RMSNorm, gated linear unit, QK layernorm, MoE config)
     - VL-specific configurations (vision_config, token IDs, mrope)
 
-    The Qwen3VLMoEBridge leverages Qwen3MoEBridge for HF config mapping,
-    then applies VL-specific overrides.
+    The Qwen3VLMoEBridge uses the centralized CONFIG_MAPPING for HF config mapping,
+    then applies MoE and VL-specific overrides.

Comment on lines +79 to +83
        provider.bos_token_id = getattr(hf_config, "bos_token_id", 0)
        provider.eos_token_id = getattr(hf_config, "eos_token_id", 1)
        provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 255999)
        provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 256000)
        provider.image_token_id = getattr(hf_config, "image_token_id", 151655)


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find src/megatron/bridge/models/gemma_vl -type f -name "*.py" | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 290


🏁 Script executed:

cat -n src/megatron/bridge/models/gemma_vl/gemma3_vl_bridge.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 8594


🏁 Script executed:

rg -n "151655|image_token_id" src/megatron/bridge/models/

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2744


🏁 Script executed:

# Check if there are any Qwen model files or references
find src/megatron/bridge/models -type f -name "*qwen*" -o -name "*Qwen*"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 599


🏁 Script executed:

# Check the gemma3_vl_provider to see how image_token_id is used
cat -n src/megatron/bridge/models/gemma_vl/gemma3_vl_provider.py | head -150

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 3368


Fix image_token_id default to match Gemma3 VL provider definition.

Line 83 uses the default image_token_id=151655, which is Qwen's tokenizer vocabulary value. The Gemma3VLModelProvider defines the correct default as 262144 (line 47 of gemma3_vl_provider.py). This mismatch will cause incorrect token assignments when the HuggingFace config lacks an image_token_id attribute. Change line 83 to:

provider.image_token_id = getattr(hf_config, "image_token_id", 262144)

Comment on lines +57 to +70
    # Extend CONFIG_MAPPING with Nemotron-H/Mamba-specific fields
    CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
        # Mamba-specific fields
        ("mamba_head_dim", "mamba_head_dim"),
        ("mamba_num_heads", "mamba_num_heads"),
        ("n_groups", "mamba_num_groups"),
        ("ssm_state_size", "mamba_state_dim"),
        ("hybrid_override_pattern", "hybrid_override_pattern"),
        ("residual_in_fp32", "fp32_residual_connection"),
        ("use_bias", "add_bias_linear"),
        ("layer_norm_epsilon", "layernorm_epsilon"),
        # MoE-specific fields (already in base but with different HF names)
        ("moe_shared_expert_intermediate_size", "moe_shared_expert_intermediate_size"),
    ]


⚠️ Potential issue | 🟡 Minor

Fix Ruff RUF005 by using iterable unpacking.

This avoids the warning and keeps ruff clean.

♻️ Suggested update
-    CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
+    CONFIG_MAPPING = [
+        *MegatronModelBridge.CONFIG_MAPPING,
         # Mamba-specific fields
         ("mamba_head_dim", "mamba_head_dim"),
         ("mamba_num_heads", "mamba_num_heads"),
         ("n_groups", "mamba_num_groups"),
         ("ssm_state_size", "mamba_state_dim"),
         ("hybrid_override_pattern", "hybrid_override_pattern"),
         ("residual_in_fp32", "fp32_residual_connection"),
         ("use_bias", "add_bias_linear"),
         ("layer_norm_epsilon", "layernorm_epsilon"),
         # MoE-specific fields (already in base but with different HF names)
         ("moe_shared_expert_intermediate_size", "moe_shared_expert_intermediate_size"),
-    ]
+    ]
As per coding guidelines, "Use ruff for linting and formatting with commands: uv run ruff check --fix . and uv run ruff format .".
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 58-70: Consider iterable unpacking instead of concatenation

(RUF005)


Comment on lines +87 to +88
        # MoE-specific defaults (only if MoE is enabled)
        if hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts > 0:


⚠️ Potential issue | 🟡 Minor

Guard against n_routed_experts=None to avoid a TypeError.

If the attribute exists but is None, > 0 will raise.

🛡️ Suggested guard
-        if hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts > 0:
+        n_routed_experts = getattr(hf_config, "n_routed_experts", 0) or 0
+        if n_routed_experts > 0:

Comment on lines +77 to 104
provider_kwargs = self.hf_config_to_provider_kwargs(text_config)

# Set vision config dtype to match the language model dtype
# This ensures vision model parameters are initialized in the same dtype
vision_config = hf_config.vision_config
vision_config.torch_dtype = model_dtype

# Create the provider with text model configuration
provider = Qwen3VLModelProvider(
# Language model configuration from text_config
num_layers=text_config.num_hidden_layers,
hidden_size=text_config.hidden_size,
ffn_hidden_size=text_config.intermediate_size,
num_attention_heads=text_config.num_attention_heads,
num_query_groups=text_config.num_key_value_heads, # GQA configuration
head_dim=text_config.head_dim,
init_method_std=text_config.initializer_range,
layernorm_epsilon=text_config.rms_norm_eps,
gated_linear_unit=True, # Qwen3 uses gated linear units
make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(text_config.vocab_size),
rotary_base=text_config.rope_theta,
share_embeddings_and_output_weights=getattr(text_config, "tie_word_embeddings", False),
vocab_size=text_config.vocab_size,
seq_length=text_config.max_position_embeddings,
fp16=(model_dtype == torch.float16),
bf16=(model_dtype == torch.bfloat16),
params_dtype=model_dtype,
generation_config=hf_pretrained.generation_config,
# Qwen3 specific parameters
add_qkv_bias=text_config.attention_bias, # Qwen3 can have bias in QKV
qk_layernorm=True, # Qwen3 uses QK layernorm
# Vision configuration
vision_config=vision_config,
# Store the original HF text config for RoPE initialization
hf_text_config=text_config,
# Vision-Language token IDs
bos_token_id=getattr(text_config, "bos_token_id", 151643),
eos_token_id=getattr(text_config, "eos_token_id", 151645),
vision_start_token_id=getattr(hf_config, "vision_start_token_id", 151652),
vision_end_token_id=getattr(hf_config, "vision_end_token_id", 151653),
image_token_id=getattr(hf_config, "image_token_id", 151655),
video_token_id=getattr(hf_config, "video_token_id", 151656),
# MRoPE configuration for multimodal position embeddings
mrope_section=text_config.rope_scaling.get("mrope_section", [24, 20, 20]),
)
vision_config.torch_dtype = provider_kwargs.get("params_dtype", torch.float32)

provider = Qwen3VLModelProvider(**provider_kwargs)

# Qwen3-specific settings
provider.normalization = "RMSNorm"
provider.gated_linear_unit = True
provider.add_qkv_bias = text_config.attention_bias
provider.add_bias_linear = False
provider.qk_layernorm = True
provider.hidden_dropout = 0.0

# VL-specific overrides
provider.position_embedding_type = "mrope"
provider.vision_config = vision_config
provider.hf_text_config = text_config
provider.head_dim = text_config.head_dim
provider.bos_token_id = getattr(text_config, "bos_token_id", 151643)
provider.eos_token_id = getattr(text_config, "eos_token_id", 151645)
provider.vision_start_token_id = getattr(hf_config, "vision_start_token_id", 151652)
provider.vision_end_token_id = getattr(hf_config, "vision_end_token_id", 151653)
provider.image_token_id = getattr(hf_config, "image_token_id", 151655)
provider.video_token_id = getattr(hf_config, "video_token_id", 151656)
provider.mrope_section = text_config.rope_scaling.get("mrope_section", [24, 20, 20])



⚠️ Potential issue | 🟡 Minor

Provider construction follows the new pattern correctly.

The refactored provider_bridge method properly uses hf_config_to_provider_kwargs and applies Qwen3-specific overrides post-instantiation.

However, Line 103 may raise a TypeError if rope_scaling is None:

provider.mrope_section = text_config.rope_scaling.get("mrope_section", [24, 20, 20])

If text_config.rope_scaling is None, calling .get() will fail.

🐛 Proposed fix
-        provider.mrope_section = text_config.rope_scaling.get("mrope_section", [24, 20, 20])
+        provider.mrope_section = (text_config.rope_scaling or {}).get("mrope_section", [24, 20, 20])
