Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/disaggregation/base/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class KVSlice:
default_factory=list
) # Physical block IDs per layer group
is_last_slice: bool = False
mamba_state_index: Optional[int] = None


class SessionStatus(Enum):
Expand Down
193 changes: 192 additions & 1 deletion tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from typing import List
from typing import Dict, List, Optional, Tuple

from tensorrt_llm._torch.disaggregation.base.region import (
MemRegionGroup,
RegionMapperBase,
SpecRegion,
SpecRegionPair,
)
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
from tensorrt_llm._torch.disaggregation.resource.page import (
KVCachePageTable,
MambaLayerGroup,
PhysicalPool,
)


class MambaHeadMatchMapper(RegionMapperBase):
Expand Down Expand Up @@ -292,3 +298,188 @@ def _compute_tp_offsets(
else:
# peer has fewer ranks -> larger chunk; self selects sub-chunk
return 0, (self_tp_rank % ratio) * transfer_bytes


class MambaPolicy:
    """
    Dispatch mappers and build frags for Mamba state transfer.
    """

    @staticmethod
    def _mamba_tp(ri: RankInfo) -> Tuple[int, int]:
        """Return (mamba_effective_tp_size, mamba_effective_tp_rank).

        When attention_dp is enabled, mamba is not TP-sharded.
        """
        attn = ri.attention
        if attn and attn.enable_attention_dp:
            # Attention-DP: mamba state is replicated, so act as a single rank.
            return 1, 0
        return ri.tp_size, ri.tp_rank

    @staticmethod
    def _build_layer_ptrs(
        pool: PhysicalPool,
        layer_offsets: Dict[int, int],
        overlapping_layers: List[int],
        slot: int,
    ) -> List[int]:
        """Build per-layer pointers for a given pool (conv or ssm) and slot."""
        # Pool layout is [layer, slot, ...]: a layer spans num_slots * slot_bytes
        # bytes; within a layer, each slot spans slot_bytes.
        layer_stride = pool.num_slots * pool.slot_bytes
        slot_offset = slot * pool.slot_bytes
        return [
            pool.base_address + layer_offsets[glid] * layer_stride + slot_offset
            for glid in overlapping_layers
        ]

    @staticmethod
    def _select_mapper(
        *,
        is_conv: bool,
        tp_match: bool,
        transfer_layers: int,
        self_mlg: MambaLayerGroup,
        peer_mlg: MambaLayerGroup,
        self_pool: PhysicalPool,
        peer_pool: PhysicalPool,
        self_mamba_tp: int,
        peer_mamba_tp: int,
        self_mamba_tp_rank: int,
        peer_mamba_tp_rank: int,
    ) -> RegionMapperBase:
        """Select the appropriate mapper for a conv/ssm pool pair."""
        # Arguments common to every mapper kind.
        layer_kwargs = dict(
            transfer_layers=transfer_layers,
            src_layer_off=0,
            dst_layer_off=0,
        )
        if tp_match:
            # Same effective TP on both sides: whole slots map one-to-one.
            return MambaHeadMatchMapper(
                block_bytes_per_layer=self_pool.slot_bytes,
                **layer_kwargs,
            )
        # TP mismatch: pass the sharding description to the mapper.
        tp_kwargs = dict(
            self_tp_per_dp=self_mamba_tp,
            peer_tp_per_dp=peer_mamba_tp,
            self_tp_rank=self_mamba_tp_rank,
            peer_tp_rank=peer_mamba_tp_rank,
        )
        if is_conv:
            # Conv state: section-level granularity.
            return ConvStateMismatchMapper(
                self_section_bytes=self_mlg.conv_section_bytes,
                peer_section_bytes=peer_mlg.conv_section_bytes,
                **layer_kwargs,
                **tp_kwargs,
            )
        # SSM state: head-level granularity; head counts derive from slot size.
        return MambaHeadMismatchMapper(
            bytes_per_head=self_mlg.ssm_bytes_per_head,
            self_nheads=self_pool.slot_bytes // self_mlg.ssm_bytes_per_head,
            peer_nheads=peer_pool.slot_bytes // peer_mlg.ssm_bytes_per_head,
            **layer_kwargs,
            **tp_kwargs,
        )

    @staticmethod
    def build_mamba_frags(
        self_mlg: MambaLayerGroup,
        peer_mlg: MambaLayerGroup,
        src_slot: int,
        dst_slot: int,
        self_ri: RankInfo,
        peer_ri: RankInfo,
    ) -> Tuple[List[int], List[int], List[int]]:
        """Build (src_frags, dst_frags, kv_sizes) for mamba state transfer.

        Returns empty lists if there are no overlapping layers.
        """
        # Only layers hosted on both ranks can be transferred.
        overlapping_layers = sorted(
            set(self_mlg.mamba_layer_offsets) & set(peer_mlg.mamba_layer_offsets)
        )
        if not overlapping_layers:
            return [], [], []
        transfer_layers = len(overlapping_layers)

        self_tp, self_tp_rank = MambaPolicy._mamba_tp(self_ri)
        peer_tp, peer_tp_rank = MambaPolicy._mamba_tp(peer_ri)

        src_frags: List[int] = []
        dst_frags: List[int] = []
        kv_sizes: List[int] = []

        # Two pool pairs per layer group: conv state, then ssm state.
        pool_pairs = (
            (self_mlg.conv_states, peer_mlg.conv_states, True),
            (self_mlg.ssm_states, peer_mlg.ssm_states, False),
        )
        for self_pool, peer_pool, is_conv in pool_pairs:
            src_region = SpecRegion(
                memory=MemRegionGroup(
                    ptrs=MambaPolicy._build_layer_ptrs(
                        self_pool, self_mlg.mamba_layer_offsets, overlapping_layers, src_slot
                    ),
                    bytes_per_region=self_pool.slot_bytes,
                )
            )
            dst_region = SpecRegion(
                memory=MemRegionGroup(
                    ptrs=MambaPolicy._build_layer_ptrs(
                        peer_pool, peer_mlg.mamba_layer_offsets, overlapping_layers, dst_slot
                    ),
                    bytes_per_region=peer_pool.slot_bytes,
                )
            )

            mapper = MambaPolicy._select_mapper(
                is_conv=is_conv,
                tp_match=self_tp == peer_tp,
                transfer_layers=transfer_layers,
                self_mlg=self_mlg,
                peer_mlg=peer_mlg,
                self_pool=self_pool,
                peer_pool=peer_pool,
                self_mamba_tp=self_tp,
                peer_mamba_tp=peer_tp,
                self_mamba_tp_rank=self_tp_rank,
                peer_mamba_tp_rank=peer_tp_rank,
            )

            # map() may return one pair or a list of pairs; normalize to a list.
            mapped = mapper.map(src_region, dst_region)
            if not isinstance(mapped, list):
                mapped = [mapped]
            for pair in mapped:
                ptrs = pair.src.memory.ptrs
                src_frags.extend(ptrs)
                dst_frags.extend(pair.dst.memory.ptrs)
                kv_sizes.extend([pair.src.memory.bytes_per_region] * len(ptrs))

        return src_frags, dst_frags, kv_sizes

    @staticmethod
    def collect_frags(
        self_page_table: KVCachePageTable,
        peer_page_table: KVCachePageTable,
        src_slot: Optional[int],
        dst_slot: Optional[int],
        self_ri: RankInfo,
        peer_ri: RankInfo,
    ) -> Tuple[List[int], List[int], List[int]]:
        """Find mamba layer groups from page tables and build transfer frags.

        Returns (src_frags, dst_frags, kv_sizes) — all empty if not applicable.
        """

        def first_mamba_group(table: KVCachePageTable) -> Optional[MambaLayerGroup]:
            # Page tables hold mixed layer-group kinds; pick the mamba one.
            for lg in table.layer_groups:
                if isinstance(lg, MambaLayerGroup):
                    return lg
            return None

        self_mlg = first_mamba_group(self_page_table)
        peer_mlg = first_mamba_group(peer_page_table)
        # Nothing to transfer unless both sides have a mamba group and a slot.
        if self_mlg is None or peer_mlg is None:
            return [], [], []
        if src_slot is None or dst_slot is None:
            return [], [], []
        return MambaPolicy.build_mamba_frags(
            self_mlg=self_mlg,
            peer_mlg=peer_mlg,
            src_slot=src_slot,
            dst_slot=dst_slot,
            self_ri=self_ri,
            peer_ri=peer_ri,
        )
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/disaggregation/native/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tensorrt_llm._torch.disaggregation.native.mixers.attention.peer import AttentionPolicy
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
from tensorrt_llm._torch.disaggregation.resource.kv_extractor import KVRegionExtractorV1
from tensorrt_llm._torch.disaggregation.resource.page import AttentionLayerGroup
from tensorrt_llm._torch.disaggregation.resource.utils import (
PoolRole,
get_global_layer_ids,
Expand Down Expand Up @@ -132,6 +133,8 @@ def get_pool_mapping(self, peer_ri: RankInfo) -> Dict[LGPoolKey, LGPoolKey]:
kv_factor = self._ri.attention.kv_factor

for self_lg_idx, self_lg in enumerate(self_pt.layer_groups):
if not isinstance(self_lg, AttentionLayerGroup):
continue
for self_pi, self_pv in enumerate(self_lg.pool_views):
is_indexer = len(self_pv.buffer_entries) == 0
# For INDEXER (empty buffer_entries), use group-level IDs for step-1 lookup
Expand Down
17 changes: 17 additions & 0 deletions tensorrt_llm/_torch/disaggregation/native/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from tensorrt_llm._torch.disaggregation.native.auxiliary import AuxBuffer
from tensorrt_llm._torch.disaggregation.native.messenger import ZMQMessenger, decode_message
from tensorrt_llm._torch.disaggregation.native.mixers.ssm.peer import MambaPolicy
from tensorrt_llm._torch.disaggregation.native.peer import PeerRegistrar
from tensorrt_llm._torch.disaggregation.native.perf_logger import PerfTimer, perf_log_manager
from tensorrt_llm._torch.disaggregation.native.rank_info import RankInfo
Expand Down Expand Up @@ -70,6 +71,7 @@ class RecvReqInfo:
unique_rid: int
start_token_idx: Optional[int] = None
aux_slot: Optional[int] = None
mamba_state_index: Optional[int] = None

def to_bytes(self) -> bytes:
return msgpack.packb(asdict(self))
Expand Down Expand Up @@ -520,6 +522,20 @@ def _build_kv_write_meta(self, task: KVSendTask, req_info: RecvReqInfo) -> Write
frag_size = rp.src.memory.bytes_per_region # type: ignore[attr-defined]
kv_sizes.extend([frag_size] * len(rp.src.memory.ptrs)) # type: ignore[attr-defined]

# handle mamba fragments
m_src, m_dst, m_sizes = MambaPolicy.collect_frags(
self_page_table=extractor.page_table,
peer_page_table=peer_extractor.page_table,
src_slot=task._slice.mamba_state_index,
dst_slot=req_info.mamba_state_index,
self_ri=self._registrar.self_rank_info,
peer_ri=peer_ri,
)
if m_src:
src_frags.extend(m_src)
dst_frags.extend(m_dst)
kv_sizes.extend(m_sizes)

if timer:
timer.record_prepare_args_end(peer_ri.instance_rank)
timer.record_transfer_sizes(peer_ri.instance_rank, sum(kv_sizes), len(dst_frags))
Expand Down Expand Up @@ -975,6 +991,7 @@ def _build_recv_req_info(self, task: KVRecvTask) -> RecvReqInfo:
block_ids_per_layer_groups=task._kv_slice.block_ids_per_layer_groups,
unique_rid=task._unique_rid,
aux_slot=task._aux_slot,
mamba_state_index=task._kv_slice.mamba_state_index,
)

def dispatch_task(self, task: KVRecvTask):
Expand Down
69 changes: 67 additions & 2 deletions tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@
)
from tensorrt_llm._torch.disaggregation.resource.page import (
BUFFER_ENTRY_DTYPE,
AttentionLayerGroup,
KVCachePageTable,
LayerGroup,
LocalLayer,
MambaLayerGroup,
PhysicalPool,
PhysicalPoolGroup,
PoolView,
)
from tensorrt_llm._torch.disaggregation.resource.utils import get_physical_pool
from tensorrt_llm._torch.pyexecutor.mamba_cache_manager import (
MambaHybridCacheManager,
PythonMambaCacheManager,
)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import get_size_in_bytes
from tensorrt_llm.bindings import DataType
Expand Down Expand Up @@ -81,6 +87,56 @@ def extract(
# ---------------------------------------------------------------------------


def _build_layer_group_for_mamba(
    manager: MambaHybridCacheManager, pool_group_idx: int
) -> MambaLayerGroup:
    """Build a MambaLayerGroup describing the manager's conv/ssm state pools."""
    impl = manager._impl
    # Only the Python cache-manager exposes the tensors we need below.
    assert isinstance(impl, PythonMambaCacheManager), (
        "CppMambaCacheManager is not supported with Python transceiver, please set TRTLLM_USE_CPP_MAMBA=0"
    )

    conv_state = impl.mamba_cache.conv
    ssm_state = impl.mamba_cache.temporal

    def pool_for(state) -> PhysicalPool:
        # Slot dimension is dim 1; stride(1) * element_size covers one slot's
        # bytes, so pointers can be computed by layer/slot arithmetic.
        return PhysicalPool(
            base_address=state.data_ptr(),
            slot_bytes=state.stride(1) * state.element_size(),
            num_slots=state.shape[1],
        )

    # Per-section bytes for conv_state and per-head bytes for ssm_state.
    # conv_state layout: [x: d_inner/tp | B: ng*ds/tp | C: ng*ds/tp] x (d_conv-1)
    # ssm_state layout: (nheads/tp, head_dim, d_state)
    d_conv_m1 = conv_state.shape[3]
    conv_elem_size = conv_state.element_size()
    conv_section_bytes = [
        dim * d_conv_m1 * conv_elem_size for dim in impl.conv_section_dims
    ]

    head_dim, d_state = ssm_state.shape[3], ssm_state.shape[4]
    ssm_bytes_per_head = head_dim * d_state * ssm_state.element_size()

    # Normalize layer-offset keys/values to plain ints (they may be tensors).
    mamba_layer_offsets = {
        int(glid): int(lid) for glid, lid in impl.mamba_layer_offsets.items()
    }

    return MambaLayerGroup(
        pool_group_idx=pool_group_idx,
        mamba_layer_offsets=mamba_layer_offsets,
        conv_states=pool_for(conv_state),
        ssm_states=pool_for(ssm_state),
        conv_section_bytes=conv_section_bytes,
        ssm_bytes_per_head=ssm_bytes_per_head,
    )


def build_page_table(kv_cache_manager: KVCacheManager) -> KVCachePageTable:
"""Build a KVCachePageTable from a KVCacheManager (V1)."""
if kv_cache_manager.dtype == DataType.NVFP4:
Expand Down Expand Up @@ -164,14 +220,18 @@ def build_page_table(kv_cache_manager: KVCacheManager) -> KVCachePageTable:
for lid in local_layer_ids
]
layer_groups.append(
LayerGroup(
AttentionLayerGroup(
pool_group_idx=group_id,
kv_head_num_per_rank=num_kv_heads,
sliding_window_size=window_size,
local_layers=local_layers,
pool_views=pool_views,
)
)
if isinstance(kv_cache_manager, MambaHybridCacheManager):
mamba_layer_group_idx = len(pool_groups)
mamba_layer_group = _build_layer_group_for_mamba(kv_cache_manager, mamba_layer_group_idx)
layer_groups.append(mamba_layer_group)

return KVCachePageTable(
tokens_per_block=tokens_per_block,
Expand Down Expand Up @@ -352,7 +412,7 @@ def _role_str_to_enum(role: str) -> DataRole:
sliding_window_size = life_cycle.window_size

layer_groups.append(
LayerGroup(
AttentionLayerGroup(
pool_group_idx=storage_pg_to_list_idx[storage_pg_idx],
kv_head_num_per_rank=num_kv_heads,
sliding_window_size=sliding_window_size,
Expand All @@ -361,6 +421,11 @@ def _role_str_to_enum(role: str) -> DataRole:
)
)

if isinstance(manager, MambaHybridCacheManager):
mamba_layer_group_idx = len(pool_groups)
mamba_layer_group = _build_layer_group_for_mamba(manager, mamba_layer_group_idx)
layer_groups.append(mamba_layer_group)

return KVCachePageTable(
tokens_per_block=config.tokens_per_block,
layer_groups=layer_groups,
Expand Down
Loading
Loading