
feat: register post-processed tensors with NIXL using full module tree traversal#169

Open
nv-hwoo wants to merge 5 commits into main from hwoo/move-nixl-registration

Conversation

@nv-hwoo
Contributor

@nv-hwoo nv-hwoo commented Mar 17, 2026

Problem

Currently we register the original weight tensors with NIXL before the process_weights_after_loading() call in the vLLM loader. This is a problem because NIXL registration holds (pins) references to the original weight tensors, so when kernel backends (e.g. FlashInfer) modify or replace model tensors during process_weights_after_loading(), the original weights cannot be freed after the new tensors are created/allocated.

This wastes a large amount of extra HBM and causes CUDA OOM with large models like Kimi k2.5 or DSV3. For example, when loading DSV3 TEP8 on 8x H200 GPUs, the current code uses an extra 50 GB per worker, leaving little or no room for kv caches. Likewise, when loading Kimi k2.5 TEP4 on 4x B200 GPUs (which should fit, as each GPU has ~178 GB of HBM), we hit CUDA OOM during process_weights_after_loading() because the pinned original weight tensors leave no free GPU memory.
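The pinning effect can be demonstrated in plain Python (a minimal sketch with stand-in classes; no torch or NIXL involved): a registry that holds strong references keeps the original objects alive even after the module swaps them out, which is exactly what registering before weight processing does.

```python
import gc
import weakref


class Tensor:
    """Stand-in for a large CUDA weight tensor (hypothetical; no torch needed)."""


class Module:
    """Stand-in for a model module whose weight gets replaced during processing."""


module = Module()
module.weight = Tensor()

# NIXL-style registration BEFORE processing: holds a strong reference (pins it)
registry = {"weight": module.weight}
old_weight = weakref.ref(module.weight)

# process_weights_after_loading() swaps in a newly allocated tensor
module.weight = Tensor()
gc.collect()
assert old_weight() is not None  # old tensor is still pinned by the registry

# registering AFTER processing means nothing pins the original weight
del registry["weight"]
gc.collect()
assert old_weight() is None      # original weight can now be freed
```

With real CUDA tensors the pinned object is tens of gigabytes of HBM rather than an empty Python object, which is why the ordering matters.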

Solution

Move the NIXL registration to after process_weights_after_loading() and register the final model tensors once all processing is done. This eliminates the pinned-memory and CUDA OOM issues described above.

Code Changes

  • Move NIXL tensor registration to after process_weights_after_loading() so that all final tensors (parameters, buffers, and bare tensor attributes like FP8 scales) are captured
  • Replace named_parameters() with _iter_module_tensors() which walks the full PyTorch module tree to discover parameters, buffers, and tensor attributes (e.g. weight_scale, _k_scale)
  • Skip non-contiguous tensors (e.g. transposed views like W_UK_T) since they are views over contiguous tensors already registered
  • Target path now runs process_weights_after_loading() on dummy data to establish the final tensor layout, then receives fully-processed weights via RDMA — no post-processing needed after transfer

Summary by CodeRabbit

  • New Features

    • Expanded tensor discovery to identify parameters, buffers, and custom tensor attributes
    • Refined weight processing sequence for improved tensor registration and handling
  • Documentation

    • Architecture guide updated to reflect enhanced tensor discovery and processing workflow

…e traversal

Move NIXL tensor registration to AFTER process_weights_after_loading()
and use _iter_module_tensors() to discover all CUDA tensors (parameters,
buffers, and bare tensor attributes like FP8 scales). Non-contiguous
tensors (e.g. transposed views like W_UK_T) are skipped as they are
views over contiguous tensors already in the module tree.

Also adds GPU memory stage logging and sets logger level to INFO.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hyunjae Woo <hwoo@nvidia.com>
@coderabbitai

coderabbitai bot commented Mar 17, 2026

Walkthrough

The changes refactor tensor collection and registration in MxModelLoader from collecting only CUDA parameters to a broader traversal capturing parameters, buffers, and tensor attributes. The public interface transitions from raw_tensors to tensors, and the loading flow now processes weights before registration and metadata publishing.

Changes

Cohort / File(s) Summary
Documentation
docs/ARCHITECTURE.md
Updated narrative to reflect new tensor discovery mechanism via _iter_module_tensors(), revised FP8 handling and NIXL integration flow, and clarified processing sequencing on both source and target paths. Updated Mermaid diagram to show weights processed before RDMA transfer.
Loader Implementation
modelexpress_client/python/modelexpress/vllm_loader.py
Replaced _collect_cuda_tensors with new _iter_module_tensors() and _collect_module_tensors() helpers to traverse module tree for parameters, buffers, and tensor attributes. Changed public API from raw_tensors property to tensors, renamed internal _raw_tensors to _tensors and global _raw_tensor_registry to _tensor_registry. Updated load flow to process weights before registration, adjusted target/source sequencing, and changed default logging level from DEBUG to INFO.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Poem

🐰 A rabbit hops through tensor trees,
Gathering parameters, buffers, with ease,
Raw tensors now called by their proper name,
Processing before transfer—a structural game!
From CUDA-only to all-encompassing view,
Our model loader's reborn, refreshed and new! ✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: registering post-processed tensors with NIXL using full module tree traversal, which matches the core objectives of the PR.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
modelexpress_client/python/modelexpress/vllm_loader.py (2)

157-160: Consider logging skipped attributes at debug level.

The bare except Exception: continue silently swallows errors when inspecting module attributes. While catching broadly here is reasonable (arbitrary getattr calls can fail in many ways), logging at debug level would aid troubleshooting without adding noise in production.

         try:
             attr_val = getattr(module, attr_name, None)
         except Exception:
+            logger.debug(f"Could not inspect attribute '{attr_name}' on module")
             continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelexpress_client/python/modelexpress/vllm_loader.py` around lines 157 -
160, The try/except around getattr(module, attr_name, None) silently swallows
failures; change the except to capture the exception (e.g., except Exception as
e) and emit a debug-level log including the module name and attr_name and
exception details (use the existing logger or create one if absent) before
continuing, keeping the same control flow so attr_val is skipped on error.

664-666: Minor: tensors collected twice in target path.

In the target path, _collect_module_tensors(model) is called in _receive_from_peer (line 516) and again here (line 664). Since the model doesn't change between calls, this is functionally correct but slightly redundant. Consider passing tensors as a parameter if performance becomes a concern.
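One minimal way to apply this suggestion, sketched outside the real loader (Loader and the method names are illustrative, not the actual vllm_loader.py API): cache the collected tensors on the instance so the target path only traverses the module tree once.

```python
from typing import Dict, Optional


class Loader:
    """Sketch of the suggested refactor: collect once, reuse everywhere."""

    def __init__(self, model: Dict[str, object]) -> None:
        self.model = model
        self._tensors: Optional[Dict[str, object]] = None
        self.traversals = 0  # instrumentation for this sketch only

    def _collect_module_tensors(self) -> Dict[str, object]:
        # stands in for the expensive full module-tree walk
        self.traversals += 1
        return dict(self.model)

    def _tensors_cached(self) -> Dict[str, object]:
        # lazy single collection shared by both call sites
        if self._tensors is None:
            self._tensors = self._collect_module_tensors()
        return self._tensors

    def receive_from_peer(self) -> Dict[str, object]:
        return self._tensors_cached()

    def register_with_nixl(self) -> Dict[str, object]:
        # reuses the cached collection instead of traversing again
        return self._tensors_cached()


loader = Loader({"weight": object()})
loader.receive_from_peer()
loader.register_with_nixl()
print(loader.traversals)  # → 1
```

As the review notes, the double collection is functionally correct today, so this is purely a cost optimization if the traversal ever becomes measurable.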

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelexpress_client/python/modelexpress/vllm_loader.py` around lines 664 -
666, The code calls _collect_module_tensors(model) twice (in _receive_from_peer
and later when registering) which is redundant; modify either _receive_from_peer
or the registration call to reuse the already-collected tensors by passing that
tensors object into the later routine (or assign it to self._tensors once and
remove the second _collect_module_tensors call), update the signature of the
registration function that currently calls _collect_module_tensors(model) to
accept a tensors parameter (or read self._tensors), and ensure callers
(including _receive_from_peer) pass the collected tensors to the registration
path so _collect_module_tensors is invoked only once for model.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a32bfbf5-0540-4f15-bc3e-68615da7c365

📥 Commits

Reviewing files that changed from the base of the PR and between 63643c7 and 721635d.

📒 Files selected for processing (2)
  • docs/ARCHITECTURE.md
  • modelexpress_client/python/modelexpress/vllm_loader.py

Signed-off-by: Hyunjae Woo <hwoo@nvidia.com>
@codecov
Copy link

codecov bot commented Mar 17, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

