
Update inference code to reflect Megatron-LM API changes#2263

Open
santhnm2 wants to merge 1 commit into NVIDIA-NeMo:main from santhnm2:inference_api_update

Conversation

@santhnm2
Contributor

@santhnm2 santhnm2 commented Feb 6, 2026

What does this PR do?

Updates the inference code to reflect Megatron-LM API changes (in particular, the removal of InferenceWrapperConfig).

Changelog

  • Add specific line-by-line info of high-level changes in this PR.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you have read and followed the Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex, etc.)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)

Summary by CodeRabbit

Release Notes

  • Refactor

    • Streamlined inference wrapper initialization by removing explicit configuration object dependencies
    • Updated inference function parameter naming and handling for improved API consistency
    • Simplified configuration management throughout the inference engine initialization process
    • Adjusted internal class inheritance hierarchy for improved component architecture
  • Tests

    • Updated test fixtures, assertions, and mock configurations to align with refactored inference components

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Feb 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 6, 2026

📝 Walkthrough


This PR refactors VLM inference modules by replacing CommonInferenceParams with SamplingParams throughout the inference pipeline and simplifying inference wrapper construction by removing explicit InferenceWrapperConfig parameter passing. A new initialization pathway is added to VLMEngine accepting a TextGenerationController.
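For orientation, here is a rough sketch of what a call site might look like after this refactor. It is not code from the PR: the controller constructor arguments, the SamplingParams fields, and the availability of model, tokenizer, and image objects are assumptions made purely for illustration.

# Hedged sketch: the wrapper now takes only the model, the engine takes a
# controller, and generation takes SamplingParams instead of CommonInferenceParams.
from megatron.core.inference.sampling_params import SamplingParams

from megatron.bridge.inference.vlm.qwenvl_inference_wrapper import QwenVLInferenceWrapper
from megatron.bridge.inference.vlm.vlm_engine import VLMEngine
from megatron.bridge.inference.vlm.vlm_inference_controller import VLMTextGenerationController

wrapper = QwenVLInferenceWrapper(model)       # no InferenceWrapperConfig anymore
controller = VLMTextGenerationController(     # constructor arguments are assumed
    inference_wrapped_model=wrapper,
    tokenizer=tokenizer,
)
engine = VLMEngine(controller, max_batch_size=4, random_seed=None)

results = engine.generate(
    ["Describe the image."],
    [image],
    sampling_params=SamplingParams(temperature=0.7, top_p=0.9),
)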

Changes

  • Core Inference API Updates
    Files: src/megatron/bridge/inference/vlm/base.py, src/megatron/bridge/inference/vlm/vlm_engine.py
    Replaced CommonInferenceParams with SamplingParams in generate() method signatures and imports. Updated documentation and default instantiation behavior accordingly.
  • Wrapper Simplification
    Files: src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py
    Removed InferenceWrapperConfig dependency from wrapper construction; __init__ now accepts only the model parameter instead of both model and config.
  • Controller Base Class Update
    Files: src/megatron/bridge/inference/vlm/vlm_inference_controller.py
    Changed VLMTextGenerationController to inherit from TextGenerationController instead of SimpleTextGenerationController; updated the corresponding import path.
  • VLMEngine Initialization
    Files: src/megatron/bridge/inference/vlm/vlm_engine.py
    Added a new __init__ method accepting TextGenerationController, optional max_batch_size, and optional random_seed. Wires the controller's wrapped model and initializes the scheduler with deterministic seed support.
  • Test Fixtures and Configuration
    Files: tests/unit_tests/inference/vlm/conftest.py, tests/unit_tests/inference/vlm/test_qwenvl_inference_wrapper.py
    Removed the mock_inference_wrapper_config fixture; updated the wrapper test fixture to depend only on mock_model.
  • Parameter Migration Tests
    Files: tests/unit_tests/inference/vlm/test_base.py, tests/unit_tests/inference/vlm/test_vlm_engine.py
    Updated tests to use SamplingParams instead of CommonInferenceParams. Removed config assertions from setup tests; adjusted mock configuration references to use context properties.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • Megatron-Bridge#2071 — Directly related PR modifying the same VLM inference modules (base.py, qwenvl_inference_wrapper.py, vlm_engine.py) with matching pattern of replacing CommonInferenceParams with SamplingParams and removing InferenceWrapperConfig.

Suggested reviewers

  • meatybobby
🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 30.77%, which is insufficient; the required threshold is 80.00%. Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.
  • Test Results For Major Changes (⚠️ Warning): PR contains major breaking API changes across critical inference modules without explicit test results documentation in the PR description or commit message. Resolution: Update the PR description to document passing unit tests, regression testing results, and a performance impact statement once CI executes successfully.

✅ Passed checks (2 passed)

  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): The title clearly and specifically describes the main change: updating inference code to reflect Megatron-LM API changes, which aligns with the substantial refactoring across multiple files to accommodate the removal of InferenceWrapperConfig and the replacement of CommonInferenceParams with SamplingParams.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/megatron/bridge/inference/vlm/vlm_engine.py (1)

50-51: ⚠️ Potential issue | 🟡 Minor

Seed guard is always true given random_seed or 1234.

self.random_seed is always truthy (either the user's value or 1234), so this if never skips seeding. If the or 1234 is intentional, this guard is dead code. If it's not, see the prior comment about the default.

🤖 Fix all issues with AI agents
In `@src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py`:
- Around line 34-35: The __init__ method in QwenVLInferenceWrapper currently
accepts an untyped model parameter; add the declared type hint to match the
documented type by changing the signature of __init__ to accept model:
"Qwen2VLModel" and leave the super().__init__(model) call unchanged so tooling
and linters recognize the model type (refer to the __init__ method and the model
parameter in QwenVLInferenceWrapper).

In `@src/megatron/bridge/inference/vlm/vlm_engine.py`:
- Around line 29-38: The __init__ currently uses "self.random_seed = random_seed
or 1234", which coerces None and 0 to 1234 and causes deterministic seeding;
change it to preserve None and allow 0 by assigning "self.random_seed =
random_seed" (no fallback) and update any checks (e.g., in generate or other
methods that use self.random_seed) from "if self.random_seed:" to "if
self.random_seed is not None:" so seeding only occurs when an explicit seed is
provided and 0 is treated as a valid seed.
🧹 Nitpick comments (3)
src/megatron/bridge/inference/vlm/base.py (1)

165-165: Consider using SamplingParams | None instead of Optional[SamplingParams].

The coding guidelines prefer T | None over Optional[T] for nullable types (Python 3.10+). This applies to line 165 and the function signature.

As per coding guidelines, "Use 'T | None' for nullable types instead of 'Optional[T]'".

♻️ Proposed fix
 def generate(
     wrapped_model: AbstractModelInferenceWrapper,
     tokenizer,
     image_processor,
     prompts: List[str],
     images: List[Union[Image, List[Image]]],
     processor=None,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
-    sampling_params: Optional[SamplingParams] = None,
+    sampling_params: SamplingParams | None = None,
 ) -> dict:

Also applies to: 202-203

tests/unit_tests/inference/vlm/test_vlm_engine.py (1)

33-36: Consider adding a test case that passes sampling_params to generate.

The current test calls engine.generate(["prompt"], ["image"]) without sampling_params. Since the sampling_params pathway is the primary change in this PR, a test verifying it is forwarded to scheduler.add_request would strengthen coverage.
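
A minimal sketch of such a test, assuming an engine fixture wired like the existing one and a scheduler that can be replaced with a mock; the fixture name, the keyword-argument call shape of add_request, and how far the generation loop runs against a MagicMock are all assumptions:

from unittest.mock import MagicMock

from megatron.core.inference.sampling_params import SamplingParams


def test_generate_forwards_sampling_params(engine):
    # Swap in a mock scheduler so the call can be inspected without running inference.
    engine.scheduler = MagicMock()
    sampling_params = SamplingParams(temperature=0.5)

    engine.generate(["prompt"], ["image"], sampling_params=sampling_params)

    # The params object should reach scheduler.add_request unchanged.
    _, kwargs = engine.scheduler.add_request.call_args
    assert kwargs.get("sampling_params") is sampling_params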

tests/unit_tests/inference/vlm/test_base.py (1)

346-346: Consider moving the SamplingParams import to the top of the file.

The inline import on line 346 works but is inconsistent with the rest of the file, where imports are at the module level. Since SamplingParams is used by production code that is already imported in this test's scope, a top-level import is cleaner.

♻️ Proposed fix

Add to the top-level imports:

from megatron.core.inference.sampling_params import SamplingParams

Then remove the inline import at line 346.

Comment on lines +34 to +35
def __init__(self, model):
    super().__init__(model)
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, let's examine the actual file and verify the content at lines 34-35
cat -n src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py | head -50

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2079


🏁 Script executed:

# Check the parent class AbstractModelInferenceWrapper to verify __init__ signature
find . -name "*.py" -type f | xargs rg -l "class AbstractModelInferenceWrapper" | head -3

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Look for InferenceWrapperConfig usage in the file to understand the changes
rg -n "InferenceWrapperConfig" src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Find the AbstractModelInferenceWrapper class definition
rg -n "class AbstractModelInferenceWrapper" --type py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Get the signature of AbstractModelInferenceWrapper.__init__
rg -A 5 "def __init__" megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py 2>/dev/null | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Let's search for this parent class in the repository
find . -path "*/megatron/core/inference*" -name "*.py" -type f | head -10

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check if there's a type hint missing for model parameter in __init__
rg -B 2 -A 2 "def __init__\(self, model\)" src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 143


🏁 Script executed:

# Look for other inference wrapper classes to check type hint patterns
find . -path "*/inference*" -name "*wrapper.py" -type f | head -10

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 200


🏁 Script executed:

# Check other similar wrapper classes for __init__ type hints
rg -B 2 -A 2 "class.*InferenceWrapper.*:" src/megatron/bridge/ -A 8 | grep -A 8 "def __init__"

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check the exact parent class definition to understand the interface
fd -e py | xargs rg -l "class AbstractModelInferenceWrapper" 2>/dev/null

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 55


🏁 Script executed:

# Let's look at the imports to understand the typing - check if there's type hints elsewhere
rg -n "from typing import" src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py -A 2

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 128


🏁 Script executed:

# Check if the model parameter should have a type hint by looking at similar __init__ methods
rg "def __init__\(self, model.*\)" src/megatron/bridge/ -A 1 | head -30

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 893


🏁 Script executed:

# Check if there are other __init__ methods in the codebase with type hints to establish the pattern
rg "def __init__\(self, model.*:" src/megatron/bridge/inference/ -A 1 | head -20

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 257


🏁 Script executed:

# Look at the parent class to understand if it has type hints
rg -B 5 -A 10 "class AbstractModelInferenceWrapper" . --type py

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 53


🏁 Script executed:

# Check test file to understand expected usage
cat -n tests/unit_tests/inference/vlm/test_qwenvl_inference_wrapper.py | head -50

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 2467


🏁 Script executed:

# Verify if model parameter type hint is required by checking similar files
find src/megatron/bridge/inference -name "*.py" -exec grep -l "def __init__" {} \; | xargs grep -A 1 "def __init__" | head -40

Repository: NVIDIA-NeMo/Megatron-Bridge

Length of output: 1222


Add type hint for model parameter.

The constructor simplification is sound, but the model parameter lacks a type hint. Per coding guidelines, add the type hint matching the documented type:

def __init__(self, model: "Qwen2VLModel"):
    super().__init__(model)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/inference/vlm/qwenvl_inference_wrapper.py` around lines
34 - 35, The __init__ method in QwenVLInferenceWrapper currently accepts an
untyped model parameter; add the declared type hint to match the documented type
by changing the signature of __init__ to accept model: "Qwen2VLModel" and leave
the super().__init__(model) call unchanged so tooling and linters recognize the
model type (refer to the __init__ method and the model parameter in
QwenVLInferenceWrapper).
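
If importing Qwen2VLModel at module scope would be heavy or create a circular import, one way to keep the string annotation resolvable for type checkers is a TYPE_CHECKING guard, as sketched below. The import path for Qwen2VLModel is a placeholder, and the parent class is taken from the review's analysis rather than from this diff.

from typing import TYPE_CHECKING

from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)

if TYPE_CHECKING:
    # Placeholder path: point this at wherever Qwen2VLModel actually lives.
    from megatron.bridge.models.qwen_vl import Qwen2VLModel


class QwenVLInferenceWrapper(AbstractModelInferenceWrapper):
    def __init__(self, model: "Qwen2VLModel"):
        super().__init__(model)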

Comment on lines +29 to +38
def __init__(
    self, text_generation_controller: TextGenerationController,
    max_batch_size: Optional[int] = None,
    random_seed: Optional[int] = None,
):
    self.controller = text_generation_controller
    self.inference_wrapped_model = self.controller.inference_wrapped_model
    self.config = self.inference_wrapped_model.config
    self.random_seed = random_seed or 1234
    self.scheduler = Scheduler(max_batch_size=max_batch_size)
Contributor


⚠️ Potential issue | 🟠 Major

random_seed or 1234 silently coerces None and 0 to 1234, making inference always deterministically seeded.

Line 37: self.random_seed = random_seed or 1234 means:

  • Passing random_seed=None (the default from base.py) → seeds with 1234 instead of not seeding.
  • Passing random_seed=0 → seeds with 1234 instead of 0.

Combined with line 50 (if self.random_seed: — always truthy), every call to generate will set a deterministic seed. This is inconsistent with base.py line 164 where random_seed defaults to None, implying "no fixed seed."

Proposed fix
-        self.random_seed = random_seed or 1234
+        self.random_seed = random_seed
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Original:

def __init__(
    self, text_generation_controller: TextGenerationController,
    max_batch_size: Optional[int] = None,
    random_seed: Optional[int] = None,
):
    self.controller = text_generation_controller
    self.inference_wrapped_model = self.controller.inference_wrapped_model
    self.config = self.inference_wrapped_model.config
    self.random_seed = random_seed or 1234
    self.scheduler = Scheduler(max_batch_size=max_batch_size)

Suggested:

def __init__(
    self, text_generation_controller: TextGenerationController,
    max_batch_size: Optional[int] = None,
    random_seed: Optional[int] = None,
):
    self.controller = text_generation_controller
    self.inference_wrapped_model = self.controller.inference_wrapped_model
    self.config = self.inference_wrapped_model.config
    self.random_seed = random_seed
    self.scheduler = Scheduler(max_batch_size=max_batch_size)
🤖 Prompt for AI Agents
In `@src/megatron/bridge/inference/vlm/vlm_engine.py` around lines 29 - 38, The
__init__ currently uses "self.random_seed = random_seed or 1234", which coerces
None and 0 to 1234 and causes deterministic seeding; change it to preserve None
and allow 0 by assigning "self.random_seed = random_seed" (no fallback) and
update any checks (e.g., in generate or other methods that use self.random_seed)
from "if self.random_seed:" to "if self.random_seed is not None:" so seeding
only occurs when an explicit seed is provided and 0 is treated as a valid seed.

