feat: register post-processed tensors with NIXL using full module tree traversal#169
Conversation
Move NIXL tensor registration to AFTER `process_weights_after_loading()` and use `_iter_module_tensors()` to discover all CUDA tensors (parameters, buffers, and bare tensor attributes like FP8 scales). Non-contiguous tensors (e.g. transposed views like `W_UK_T`) are skipped, as they are views over contiguous tensors already in the module tree. Also adds GPU memory stage logging and sets the logger level to INFO.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Hyunjae Woo <hwoo@nvidia.com>
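A hypothetical sketch of a full module-tree traversal in the spirit of the `_iter_module_tensors()` described above; this is illustrative, not the actual ModelExpress implementation. The real code filters to CUDA tensors; that check is omitted here so the sketch runs anywhere.

```python
import torch
import torch.nn as nn


def iter_module_tensors(model: nn.Module):
    """Yield (name, tensor) for parameters, buffers, and bare tensor
    attributes (e.g. FP8 scales) across the whole module tree."""
    for module_name, module in model.named_modules():
        prefix = f"{module_name}." if module_name else ""
        for name, p in module.named_parameters(recurse=False):
            yield prefix + name, p
        for name, b in module.named_buffers(recurse=False):
            yield prefix + name, b
        # Bare tensor attributes live in __dict__ and are invisible to
        # named_parameters()/named_buffers().
        for attr_name, attr_val in vars(module).items():
            if isinstance(attr_val, torch.Tensor) and not isinstance(
                attr_val, nn.Parameter
            ):
                # Skip non-contiguous views (e.g. a transposed W_UK_T): they
                # alias contiguous tensors already present in the module tree.
                if attr_val.is_contiguous():
                    yield prefix + attr_name, attr_val
```

This relies on the fact that `nn.Module.__setattr__` routes `nn.Parameter` values to `_parameters` and registered buffers to `_buffers`, while plain tensors assigned as attributes land in the instance `__dict__`.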
Walkthrough: The changes refactor tensor collection and registration in `MxModelLoader` from collecting only CUDA parameters to a broader traversal capturing parameters, buffers, and tensor attributes.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed
🧹 Nitpick comments (2)
modelexpress_client/python/modelexpress/vllm_loader.py (2)
157-160: Consider logging skipped attributes at debug level.

The bare `except Exception: continue` silently swallows errors when inspecting module attributes. While catching broadly here is reasonable (arbitrary `getattr` calls can fail in many ways), logging at debug level would aid troubleshooting without adding noise in production.

```diff
 try:
     attr_val = getattr(module, attr_name, None)
 except Exception:
+    logger.debug(f"Could not inspect attribute '{attr_name}' on module")
     continue
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelexpress_client/python/modelexpress/vllm_loader.py` around lines 157 - 160, The try/except around getattr(module, attr_name, None) silently swallows failures; change the except to capture the exception (e.g., except Exception as e) and emit a debug-level log including the module name and attr_name and exception details (use the existing logger or create one if absent) before continuing, keeping the same control flow so attr_val is skipped on error.
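The fix suggested above could look like the following sketch; the `safe_getattr` helper name is hypothetical (in `vllm_loader.py` this logic would sit inline in the traversal loop).

```python
import logging

logger = logging.getLogger(__name__)


def safe_getattr(module, attr_name):
    """Return module.attr_name, or None (with a debug log) if inspection fails.

    Note: getattr's default only suppresses AttributeError; a property that
    raises e.g. RuntimeError still needs the except clause below.
    """
    try:
        return getattr(module, attr_name, None)
    except Exception as e:
        logger.debug(
            "Could not inspect attribute %r on %s: %s",
            attr_name, type(module).__name__, e,
        )
        return None
```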
664-666: Minor: tensors collected twice in target path.

In the target path, `_collect_module_tensors(model)` is called in `_receive_from_peer` (line 516) and again here (line 664). Since the model doesn't change between calls, this is functionally correct but slightly redundant. Consider passing tensors as a parameter if performance becomes a concern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelexpress_client/python/modelexpress/vllm_loader.py` around lines 664 - 666, The code calls _collect_module_tensors(model) twice (in _receive_from_peer and later when registering) which is redundant; modify either _receive_from_peer or the registration call to reuse the already-collected tensors by passing that tensors object into the later routine (or assign it to self._tensors once and remove the second _collect_module_tensors call), update the signature of the registration function that currently calls _collect_module_tensors(model) to accept a tensors parameter (or read self._tensors), and ensure callers (including _receive_from_peer) pass the collected tensors to the registration path so _collect_module_tensors is invoked only once for model.
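The reviewer's suggestion could be sketched as collecting the tensor map once in the receive path and threading it through to registration. Class and method names below are illustrative stand-ins for the loader's `_receive_from_peer` / NIXL registration code, not the real API.

```python
import torch
import torch.nn as nn


def collect_module_tensors(model: nn.Module) -> dict[str, torch.Tensor]:
    # Stand-in for the real _collect_module_tensors traversal.
    return dict(model.named_parameters())


class Loader:
    def __init__(self):
        self.collect_calls = 0
        self.registered: list[str] = []

    def receive_from_peer(self, model: nn.Module) -> None:
        self.collect_calls += 1
        tensors = collect_module_tensors(model)   # collected once...
        self._register_tensors(tensors)           # ...and reused here

    def _register_tensors(self, tensors: dict[str, torch.Tensor]) -> None:
        # Accepts the pre-collected map rather than calling the collector again.
        self.registered = sorted(tensors)
```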
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a32bfbf5-0540-4f15-bc3e-68615da7c365
📒 Files selected for processing (2)
- docs/ARCHITECTURE.md
- modelexpress_client/python/modelexpress/vllm_loader.py
Signed-off-by: Hyunjae Woo <hwoo@nvidia.com>
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Problem
Currently we register the original weight tensors with NIXL before the `process_weights_after_loading()` call in the vLLM loader. This is a problem: NIXL registration holds references to the original weight tensors (pinning them), so when kernel backends (e.g. FlashInfer) replace model tensors during `process_weights_after_loading()`, the original weights cannot be freed after the new tensors are allocated.

This wastes a large amount of HBM and causes CUDA OOM with large models like Kimi k2.5 or DSV3. For example, when loading DSV3 TEP8 on 8x H200 GPUs, the current code uses an extra ~50 GB per worker, leaving little or no room for KV caches. Likewise, when loading Kimi k2.5 TEP4 on 4x B200 GPUs (which should fit, as each GPU has ~178 GB HBM), we hit CUDA OOM during `process_weights_after_loading()` because no free GPU memory remains due to the pinned original weight tensors.

Solution
Move the NIXL registration to after `process_weights_after_loading()` and register the final model tensors once all processing is done. This eliminates the pinned-memory and CUDA OOM issues described above.

Code Changes
- Move NIXL tensor registration to after `process_weights_after_loading()` so that all final tensors (parameters, buffers, and bare tensor attributes like FP8 scales) are captured
- Replace `named_parameters()` with `_iter_module_tensors()`, which walks the full PyTorch module tree to discover parameters, buffers, and tensor attributes (e.g. `weight_scale`, `_k_scale`)

Summary by CodeRabbit
New Features
Documentation
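The load-order change at the heart of this PR can be sketched as follows. The function names are stand-ins for vLLM's `process_weights_after_loading()` and the NIXL registration call, not the real APIs; only the ordering matters.

```python
events: list[str] = []


def process_weights(model: dict) -> None:
    # Kernel backends may replace tensors here (e.g. FP8 requantization);
    # originals can only be freed if nothing has pinned them yet.
    events.append("process_weights_after_loading")


def nixl_register(model: dict) -> None:
    # Registration pins tensors, so it must see the FINAL post-processed set.
    events.append("nixl_register")


def load_model(model: dict) -> None:
    # Old order: nixl_register(model) ran first, pinning pre-processed weights.
    process_weights(model)
    nixl_register(model)


load_model({"w": 0})
```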