[TRTLLM-11146][feat] Extend python cache transceiver to support nemotron#12150
bo-nv wants to merge 1 commit into NVIDIA:main from
Conversation
/bot run --disable-fail-fast

PR_Github #38698 [ run ] triggered by Bot. Commit:
📝 Walkthrough

This change introduces Mamba state transfer support to the disaggregated serving architecture. It adds a new

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant App as Application
    participant Executor as PyExecutor
    participant CacheMgr as MambaCacheManager
    participant Transceiver as PyNativeCacheTransceiver
    participant Policy as MambaPolicy
    participant Transfer as KVSendTask/KVRecvTask
    App->>Executor: init with disaggregated config
    Executor->>CacheMgr: create MambaHybridCacheManager
    CacheMgr->>CacheMgr: partition conv/ssm states across TP ranks
    App->>Transceiver: initialize for context rank
    Transceiver->>CacheMgr: get layer groups with mamba_state_index
    Transceiver->>Transceiver: build page tables with Mamba layers
    App->>Executor: run context phase
    Executor->>CacheMgr: allocate & write Mamba states (conv/ssm)
    App->>Transfer: prepare send to generation rank
    Transfer->>Policy: collect_frags for Mamba states
    Policy->>Policy: _mamba_tp (compute effective TP mapping)
    Policy->>Policy: _select_mapper (choose conv/ssm mapper)
    Policy->>Policy: build_mamba_frags (compute source/dest pointers)
    Policy-->>Transfer: return fragment pointers & sizes
    Transfer->>Transfer: aggregate KV + Mamba fragments
    Transfer->>Transceiver: transfer state via transceiver
    Transceiver->>Transceiver: recv & reconstruct states
    Transceiver->>CacheMgr: write received states to Mamba cache
    App->>Executor: run generation phase
    Executor->>CacheMgr: read Mamba states from cache
```
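The `build_mamba_frags` step in the diagram can be illustrated with a minimal sketch: given a per-rank section layout, compute contiguous `(offset, nbytes)` fragments for one state slot. The function name, the fragment representation, and the `[x | B | C]` layout here are assumptions for illustration, not the PR's actual API.

```python
def build_frags(section_dims, elem_size, slot_base):
    """Compute contiguous (byte offset, byte length) fragments for one slot.

    section_dims: per-section element counts on this rank, e.g. [x | B | C].
    elem_size:    bytes per element (e.g. 2 for fp16).
    slot_base:    byte offset of this slot in the state pool.
    """
    frags, offset = [], 0
    for dim in section_dims:
        nbytes = dim * elem_size
        frags.append((slot_base + offset, nbytes))
        offset += nbytes
    return frags

# Example: conv-state sections [128, 12, 12] in fp16, slot at byte 0
print(build_frags([128, 12, 12], 2, 0))  # -> [(0, 256), (256, 24), (280, 24)]
```

A real implementation would additionally remap these fragments between source and destination TP layouts; this sketch only shows the per-rank slicing step.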
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor: Update the copyright year to 2026.
The copyright header shows "2022-2024" but this file has meaningful modifications in 2026. As per coding guidelines, the year should reflect the latest modification.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py` at line 1, Update the file header copyright year range from "2022-2024" to include 2026 (e.g., "2022-2026") at the top of the file; modify the SPDX/FileCopyrightText comment line in tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py so the year range reflects the latest modification year.
🧹 Nitpick comments (2)
tests/unittest/disaggregated/test_mamba_transfer.py (1)
329-455: Run the shutdown path from a `finally` block. Any exception or assertion failure before line 450 skips both the cache-manager shutdowns and the transfer-worker shutdowns, which can leak GPU memory and background threads into later tests in the same pytest worker.
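The suggested pattern, as a minimal sketch. The helper names and factory callables here are stand-ins for the test's real managers and transceivers, not the actual code:

```python
def run_transfer_test(make_mgrs, make_tcs, body):
    # Initialize before the try so the finally block can always see these names.
    ctx_mgrs, gen_mgrs, ctx_tcs, gen_tcs = [], [], [], []
    try:
        ctx_mgrs, gen_mgrs = make_mgrs()
        ctx_tcs, gen_tcs = make_tcs()
        body()  # any exception or assertion failure still reaches the finally
    finally:
        # Always release cache managers and transfer workers, even on failure.
        for mgr in ctx_mgrs + gen_mgrs:
            mgr.shutdown()
        for tc in ctx_tcs + gen_tcs:
            worker = getattr(tc, "transfer_worker", None)
            if worker is not None:
                worker.shutdown()
```

Because the lists are bound before the `try`, a failure inside `make_mgrs`/`make_tcs` still leaves the `finally` clause with valid (possibly empty) lists to iterate.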
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/disaggregated/test_mamba_transfer.py` around lines 329 - 455, The test shutdown steps are not executed on exceptions; wrap the teardown in a finally block so resources are always released: move the cleanup that calls mgr.shutdown() for all ctx_mgrs and gen_mgrs and the transfer_worker.shutdown() for ctx_tcs/gen_tcs into a finally clause at the end of run_mamba_transfer_test, ensuring ctx_mgrs, gen_mgrs, ctx_tcs and gen_tcs are visible to the finally block (initialize them before try) and guard transfer_worker shutdown with hasattr/None checks as currently done.

tensorrt_llm/_torch/disaggregation/resource/page.py (1)
143-144: Consider extracting exception messages to constants (optional). Static analysis (TRY003) flags the inline exception messages. While functional, extracting these to a class-level constant or using a dedicated exception subclass would be cleaner.
Example refactor
```python
class InvalidLayerGroupError(ValueError):
    """Raised when a LayerGroup has neither kv nor mamba configuration."""
    pass

# Then use:
raise InvalidLayerGroupError(
    "LayerGroup must have either kv_head_num_per_rank or mamba_layer_offsets")
```

Also applies to: 173-174
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/disaggregation/resource/page.py` around lines 143 - 144, Replace the inline ValueError("Invalid layer group") raises in the LayerGroup handling code with either a dedicated exception class (e.g., InvalidLayerGroupError subclassing ValueError) or a class/module-level constant message (e.g., INVALID_LAYER_GROUP_MSG) and use that constant when raising; update both occurrences where LayerGroup validation currently raises ValueError (the else branch in the layer-group selection logic and the other similar raise later) to raise InvalidLayerGroupError("LayerGroup must have either kv_head_num_per_rank or mamba_layer_offsets") or raise ValueError(INVALID_LAYER_GROUP_MSG) so all messages are centralized and reusable.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/disaggregation/native/py_cache_transceiver.py`:
- Around line 156-164: The code currently creates a PyNativeCacheTransceiver
even when the MambaCacheManager is the C++ implementation which makes accessing
MambaHybridCacheManager.mamba_cache_index (used in _create_kv_slice) trigger its
internal `not self._use_cpp` assertion; fix by adding a guard: in
get_cache_transceiver() detect if the chosen kv_cache_manager is using C++
(check the manager instance or its `_use_cpp` flag / TRTLLM_USE_CPP_MAMBA env)
and do not instantiate PyNativeCacheTransceiver when C++ Mamba is active, or
alternatively add a validation in PyNativeCacheTransceiver.__init__ that
inspects the passed kv_cache_manager and raises/returns early if `_use_cpp` is
True so _create_kv_slice will never access mamba_cache_index on a C++-backed
manager.
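The `__init__`-validation alternative suggested above can be sketched as follows; the `_use_cpp` flag and the error wording follow the review comment and are not verified against the codebase:

```python
class PyNativeCacheTransceiverSketch:
    """Illustrative stand-in for PyNativeCacheTransceiver's constructor check."""

    def __init__(self, kv_cache_manager):
        # Reject C++-backed Mamba cache managers up front: _create_kv_slice
        # would later read mamba_cache_index, which the C++ path asserts on.
        if getattr(kv_cache_manager, "_use_cpp", False):
            raise ValueError(
                "Python cache transceiver requires the Python Mamba cache "
                "manager; C++-backed managers do not expose mamba_cache_index")
        self.kv_cache_manager = kv_cache_manager
```

Failing fast here turns a confusing assertion deep inside `_create_kv_slice` into a clear configuration error at construction time.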
In `@tensorrt_llm/_torch/disaggregation/resource/page.py`:
- Around line 132-142: The branch handling mamba mode calls
self.conv_states.to_dict() and self.ssm_states.to_dict() without guarding for
None even though those attributes are Optional; update the to_dict() branch
where mamba_layer_offsets is not None to either validate/invariant in
__post_init__ that conv_states and ssm_states are non-None or (preferred) add
defensive None checks: replace direct .to_dict() calls with conditional
expressions that return None or empty dict if conv_states or ssm_states is None
(e.g., use conv_states.to_dict() if conv_states is not None else None),
referencing the mamba_layer_offsets branch and the conv_states/ssm_states
attributes to locate the change.
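The defensive variant described above might look like the sketch below. The field names follow the review comment and the PR's `MambaLayerGroup` naming, but the exact dataclass shape is an assumption:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class MambaLayerGroup:
    """Illustrative stand-in; fields follow the review comment."""
    mamba_layer_offsets: Optional[list] = None
    conv_states: Optional[Any] = None   # objects exposing .to_dict()
    ssm_states: Optional[Any] = None

    def to_dict(self):
        if self.mamba_layer_offsets is not None:
            return {
                "mamba_layer_offsets": self.mamba_layer_offsets,
                # Defensive: these fields are Optional, so guard before
                # calling .to_dict() instead of assuming they are set.
                "conv_states": (self.conv_states.to_dict()
                                if self.conv_states is not None else None),
                "ssm_states": (self.ssm_states.to_dict()
                               if self.ssm_states is not None else None),
            }
        return {}
```

The alternative mentioned in the comment, a `__post_init__` invariant that both states are non-None whenever `mamba_layer_offsets` is set, trades the silent `None` output for an early, explicit failure.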
In `@tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py`:
- Around line 396-439: The bug is that _prepare_mamba_cache_blocks (triggered by
add_dummy_requests) can rewrite self.state_indices but
get_state_indices(request_ids, is_padding) only returns a computed mapping
without persisting it, while update_mamba_states still reads self.state_indices
causing mis-synced slot assignments. Fix by making get_state_indices persist the
computed mapping back into self.state_indices when request_ids and is_padding
are provided: compute result as currently done, then assign self.state_indices =
torch.tensor(result, dtype=self.state_indices.dtype,
device=self.state_indices.device) (or keep list form only if callers expect a
list), and then return the value; ensure you use self.mamba_cache_index for
lookups and preserve existing tensor dtype/device to avoid device/type
mismatches.
- Around line 221-227: Before publishing conv_section_dims, assert that each
section is exactly divisible by tp_size instead of only checking the totals:
verify d_inner, n_groups * d_state, conv_dim, and nheads are divisible by
tp_size; if any are not, raise a clear error (or assert) indicating which value
and its expected divisibility. Update the block that computes d_inner_local,
ng_ds_local, conv_dim, and nheads (which then sets self.conv_section_dims) to
perform these divisibility checks prior to integer division so conv_section_dims
accurately reflects the local slot layout.
---
Outside diff comments:
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Line 1: Update the file header copyright year range from "2022-2024" to
include 2026 (e.g., "2022-2026") at the top of the file; modify the
SPDX/FileCopyrightText comment line in
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py so the year range reflects the
latest modification year.
---
Nitpick comments:
In `@tensorrt_llm/_torch/disaggregation/resource/page.py`:
- Around line 143-144: Replace the inline ValueError("Invalid layer group")
raises in the LayerGroup handling code with either a dedicated exception class
(e.g., InvalidLayerGroupError subclassing ValueError) or a class/module-level
constant message (e.g., INVALID_LAYER_GROUP_MSG) and use that constant when
raising; update both occurrences where LayerGroup validation currently raises
ValueError (the else branch in the layer-group selection logic and the other
similar raise later) to raise InvalidLayerGroupError("LayerGroup must have
either kv_head_num_per_rank or mamba_layer_offsets") or raise
ValueError(INVALID_LAYER_GROUP_MSG) so all messages are centralized and
reusable.
In `@tests/unittest/disaggregated/test_mamba_transfer.py`:
- Around line 329-455: The test shutdown steps are not executed on exceptions;
wrap the teardown in a finally block so resources are always released: move the
cleanup that calls mgr.shutdown() for all ctx_mgrs and gen_mgrs and the
transfer_worker.shutdown() for ctx_tcs/gen_tcs into a finally clause at the end
of run_mamba_transfer_test, ensuring ctx_mgrs, gen_mgrs, ctx_tcs and gen_tcs are
visible to the finally block (initialize them before try) and guard
transfer_worker shutdown with hasattr/None checks as currently done.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0c8dd4c9-4b86-4b04-b8ba-c22058c737d6
📒 Files selected for processing (14)
- tensorrt_llm/_torch/disaggregation/base/transfer.py
- tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py
- tensorrt_llm/_torch/disaggregation/native/py_cache_transceiver.py
- tensorrt_llm/_torch/disaggregation/native/transfer.py
- tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py
- tensorrt_llm/_torch/disaggregation/resource/page.py
- tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py
- tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
- tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tests/integration/defs/accuracy/test_disaggregated_serving.py
- tests/integration/test_lists/test-db/l0_a10.yml
- tests/integration/test_lists/test-db/l0_dgx_b200.yml
- tests/unittest/disaggregated/test_mamba_transfer.py
```python
d_inner_local = d_inner // tp_size
ng_ds_local = n_groups * d_state // tp_size
conv_dim = conv_dim // tp_size
nheads = nheads // tp_size

# Per-section dims for conv_state: [x | B | C]
self.conv_section_dims = [d_inner_local, ng_ds_local, ng_ds_local]
```
Assert per-section TP divisibility before publishing conv_section_dims.
`conv_dim % tp_size == 0` is weaker than what this new metadata needs. If only the total is divisible, `ng_ds_local = n_groups * d_state // tp_size` silently truncates and the section list no longer matches the local slot layout, which will mis-fragment conv-state transfers on TP-mismatched runs.
Suggested guard

```diff
 assert nheads % tp_size == 0, "nheads must be divisible by tp_size"
 assert conv_dim % tp_size == 0, "conv_dim must be divisible by tp_size"
+assert (n_groups * d_state) % tp_size == 0, (
+    "n_groups * d_state must be divisible by tp_size"
+)
 # partition conv_dim and nheads
 d_inner_local = d_inner // tp_size
 ng_ds_local = n_groups * d_state // tp_size
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py` around lines 221 -
227, Before publishing conv_section_dims, assert that each section is exactly
divisible by tp_size instead of only checking the totals: verify d_inner,
n_groups * d_state, conv_dim, and nheads are divisible by tp_size; if any are
not, raise a clear error (or assert) indicating which value and its expected
divisibility. Update the block that computes d_inner_local, ng_ds_local,
conv_dim, and nheads (which then sets self.conv_section_dims) to perform these
divisibility checks prior to integer division so conv_section_dims accurately
reflects the local slot layout.
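A concrete numeric illustration of the silent truncation, assuming `conv_dim = d_inner + 2 * n_groups * d_state` for the `[x | B | C]` layout (the shape values are invented for illustration, not taken from any real model config):

```python
d_inner, n_groups, d_state, tp_size = 1024, 1, 100, 8

conv_dim = d_inner + 2 * n_groups * d_state   # 1224
assert conv_dim % tp_size == 0                # the total check passes: 1224 % 8 == 0

conv_dim_local = conv_dim // tp_size          # 153 elements per rank
d_inner_local = d_inner // tp_size            # 128
ng_ds_local = n_groups * d_state // tp_size   # 100 // 8 == 12, silently truncated

sections = [d_inner_local, ng_ds_local, ng_ds_local]
# 128 + 12 + 12 == 152 != 153: the section list no longer tiles the local slot
print(sum(sections), conv_dim_local)          # -> 152 153
```

This is exactly the case the suggested per-section assert would reject up front instead of letting the mismatch surface as a mis-fragmented transfer.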
PR_Github #38698 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #38820 [ run ] triggered by Bot. Commit:

PR_Github #38820 [ run ] completed with state

/bot run --disable-fail-fast

937e336 to e20c425

/bot kill

/bot run --disable-fail-fast

PR_Github #38986 [ kill ] triggered by Bot. Commit:

PR_Github #38986 [ kill ] completed with state

/bot run

PR_Github #40184 [ run ] triggered by Bot. Commit:

/bot run

PR_Github #40210 [ run ] triggered by Bot. Commit:

PR_Github #40210 [ run ] completed with state

/bot run

PR_Github #40281 [ run ] triggered by Bot. Commit:

PR_Github #40281 [ run ] completed with state

/bot run

PR_Github #40336 [ run ] triggered by Bot. Commit:

4a4c54d to ed03704

/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #40345 [ run ] triggered by Bot. Commit:

PR_Github #40336 [ run ] completed with state

PR_Github #40345 [ run ] completed with state

/bot kill
…ron-super v3

Add SSM/Mamba state transfer support for disaggregated serving of hybrid Nemotron models.

Key changes:

- Add MambaLayerGroup to resource/page.py for SSM state layout description
- Add MambaPolicy and region mappers in native/mixers/ssm/peer.py for TP-aware Mamba state fragment building
- Extend RecvReqInfo and KVSlice with mamba_state_index for slot identification
- Update Sender/Receiver in native/transfer.py to handle mamba fragments
- Update KvCacheTransceiverV2 (transceiver.py) with mamba layer count exchange and MambaLayerGroup-aware KV slice creation
- Update get_unique_pool_memory_descs in resource/utils.py for mamba pools
- Adapt mamba_cache_manager, kv_extractor, py_executor_creator for disaggregated hybrid model support
- Add comprehensive unit tests in test_mamba_transfer.py

Signed-off-by: Bo Deng <deemod@nvidia.com>
ed03704 to 3774a6a
/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #40533 [ run ] triggered by Bot. Commit:

PR_Github #40533 [ run ] completed with state

/bot run

PR_Github #40537 [ run ] triggered by Bot. Commit:

PR_Github #40537 [ run ] completed with state

/bot run

PR_Github #40538 [ run ] triggered by Bot. Commit:

PR_Github #40538 [ run ] completed with state

/bot run

PR_Github #40542 [ run ] triggered by Bot. Commit:

PR_Github #40542 [ run ] completed with state
Summary by CodeRabbit
Release Notes
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
- [ ] PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- [ ] PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
- [ ] Test cases are provided for new code paths (see test instructions).
- [ ] Any new dependencies have been scanned for license and vulnerabilities.
- [ ] CODEOWNERS updated if ownership changes.
- [ ] Documentation updated as needed.
- [ ] Update tava architecture diagram if there is a significant design change in PR.
- [ ] The reviewers assigned automatically/manually are appropriate for the PR.
- [ ] Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

To see a list of available CI bot commands, please comment `/bot help`.