Skip to content

Commit 53cb762

Browse files
authored
[None][feat] New KVCacheManagerV2 APIs for Transceiver (#11003)
Signed-off-by: Yao Yao <lowsfer@users.noreply.github.com>
1 parent 5ff244c commit 53cb762

File tree

18 files changed

+677
-199
lines changed

18 files changed

+677
-199
lines changed

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,3 @@ mistral-common==1.8.6
8080
torchao>=0.14.1
8181
cuda-core
8282
llist
83-
dynamic_path_manager

scripts/build_wheel.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,13 +439,14 @@ def build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=False):
439439
so_file.unlink()
440440

441441
# Build rawref
442-
print("-- Building kv_cache_manager_v2 rawref extension...")
442+
print("-- Building kv_cache_manager_v2 rawref extension...", end=" ")
443443
rawref_dir = kv_cache_mgr_dir / "rawref"
444444
build_run(f'"{venv_python}" setup.py build_ext --inplace', cwd=rawref_dir)
445+
print("Done")
445446

446447
if use_mypyc:
447448
# Build mypyc
448-
print("-- Building kv_cache_manager_v2 mypyc extensions...")
449+
print("-- Building kv_cache_manager_v2 mypyc extensions...", end=" ")
449450
# setup_mypyc.py is in kv_cache_manager_v2 but executed from runtime dir
450451
setup_mypyc = kv_cache_mgr_dir / "setup_mypyc.py"
451452
build_run(f'"{venv_python}" "{setup_mypyc}" build_ext --inplace',
@@ -456,6 +457,8 @@ def build_kv_cache_manager_v2(project_dir, venv_python, use_mypyc=False):
456457
raise RuntimeError(
457458
"Failed to build kv_cache_manager_v2: no shared library generated."
458459
)
460+
print("Done")
461+
print("-- Done building kv_cache_manager_v2.")
459462

460463

461464
def main(*,

tensorrt_llm/runtime/__init__.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,31 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import os
16+
import sys
17+
from contextlib import contextmanager
18+
from typing import Iterator
19+
20+
21+
# Duplicated from kv_cache_manager_v2._utils. We need this both inside and outside of
22+
# kv_cache_manager_v2 due to restriction of mypyc build process.
23+
@contextmanager
24+
def temporary_sys_path(path: str) -> Iterator[None]:
25+
already_in_path = path in sys.path
26+
if not already_in_path:
27+
sys.path.insert(0, path)
28+
try:
29+
yield
30+
finally:
31+
if not already_in_path:
32+
sys.path.remove(path)
1633

17-
from dynamic_path_manager import DynamicPathManager
1834

1935
# Add current directory to sys.path so kv_cache_manager_v2 can be imported as top-level package.
2036
# This is required because when kv_cache_manager_v2 is compiled with mypyc, it is compiled as
2137
# a top-level package (to avoid complex build paths), but at runtime it is used as a submodule.
2238
# The compiled extension might try to import its submodules using absolute imports based on its
2339
# compiled name.
24-
with DynamicPathManager(os.path.dirname(os.path.abspath(__file__)),
25-
clear_cache=False):
40+
with temporary_sys_path(os.path.dirname(os.path.abspath(__file__))):
2641
import kv_cache_manager_v2
2742

2843
from .enc_dec_model_runner import EncDecModelRunner

tensorrt_llm/runtime/kv_cache_manager_v2/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,16 @@
3636
HostCacheTierConfig,
3737
KVCacheManagerConfig,
3838
)
39-
from ._core import BeamIndex, KVCacheManager, _KVCache
39+
from ._core import (
40+
DEFAULT_BEAM_INDEX,
41+
AggregatedPageDesc,
42+
BeamIndex,
43+
BufferSlice,
44+
KVCacheManager,
45+
_KVCache,
46+
)
4047
from ._life_cycle_registry import LayerGroupId, LifeCycleId
48+
from ._storage import BufferId
4149

4250
__all__ = [
4351
"LifeCycleId",
@@ -47,6 +55,7 @@
4755
"KVCacheManager",
4856
"_KVCache",
4957
"BeamIndex",
58+
"DEFAULT_BEAM_INDEX",
5059
"LayerId",
5160
"Priority",
5261
"CacheLevel",
@@ -64,4 +73,7 @@
6473
"CacheTierConfig",
6574
"gen_multi_modal_tokens",
6675
"rawref",
76+
"BufferSlice",
77+
"AggregatedPageDesc",
78+
"BufferId",
6779
]

tensorrt_llm/runtime/kv_cache_manager_v2/__init__.pyi

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ from typing import (
2323
Final,
2424
Iterable,
2525
Iterator,
26+
NamedTuple,
2627
NewType,
2728
Protocol,
2829
Sequence,
@@ -33,6 +34,7 @@ from typing import (
3334

3435
# From _common.py
3536
NDEBUG: Final[int]
37+
DEFAULT_BEAM_INDEX: Final[BeamIndex]
3638

3739
class CacheTier(enum.IntEnum):
3840
GPU_MEM = 0
@@ -49,6 +51,7 @@ CudaStream = NewType("CudaStream", int)
4951
BeamIndex = NewType("BeamIndex", int)
5052
MemAddress = NewType("MemAddress", int)
5153
Priority = NewType("Priority", int)
54+
PoolGroupIndex = NewType("PoolGroupIndex", int)
5255

5356
# From _config.py
5457
DataRole = NewType("DataRole", str)
@@ -154,9 +157,12 @@ class _KVCache:
154157
@beam_width.setter
155158
def beam_width(self, beam_width: BeamIndex) -> None: ...
156159
def get_page_indices(self, layer_group_id: int, beam_id: BeamIndex = ...) -> IndexSeq: ...
157-
def get_all_page_indices(
158-
self, beam_id: BeamIndex, buf_ids: Iterable[tuple[LayerId, DataRole]]
159-
) -> Iterator[IndexSeq]: ...
160+
def get_aggregated_page_indices(
161+
self,
162+
layer_group_id: LayerGroupId,
163+
beam_id: BeamIndex = DEFAULT_BEAM_INDEX,
164+
valid_only: bool = False,
165+
) -> Iterator[int]: ...
160166
def resize(self, capacity: int | None, history_length: int | None = None) -> bool: ...
161167
@property
162168
def capacity(self) -> int: ...
@@ -183,6 +189,39 @@ class _KVCache:
183189
@property
184190
def tokens_per_block(self) -> int: ...
185191

192+
@dataclass(slots=True, frozen=True)
193+
class MemoryPoolDesc:
194+
base: MemAddress
195+
page_size: int
196+
197+
@dataclass(slots=True, frozen=True)
198+
class MemoryPoolGroupDesc:
199+
num_pages: int
200+
pools: Sequence[MemoryPoolDesc]
201+
202+
class BufferId(NamedTuple):
203+
layer_id: LayerId
204+
role: DataRole
205+
206+
@dataclass(slots=True, frozen=True)
207+
class BufferSlice:
208+
buffer_id: BufferId
209+
num_slices: int = 1
210+
slice_index: int = 1
211+
212+
@dataclass(slots=True, frozen=True)
213+
class AggregatedPageDesc:
214+
"""The data you need is located in the following byte ranges:
215+
216+
(base + stride * i + Range(0, size) for i in aggregated_page_indices)
217+
"""
218+
219+
base: MemAddress
220+
size: int
221+
stride: int
222+
layer_group_id: LayerGroupId
223+
buffers: Sequence[BufferSlice]
224+
186225
# From _core/_kv_cache_manager.py
187226
class KVCacheManager:
188227
def __init__(self, config: KVCacheManagerConfig) -> None: ...
@@ -200,14 +239,23 @@ class KVCacheManager:
200239
def resize(self, cache_level: CacheLevel, quota: int, best_efforts: bool = False) -> bool: ...
201240
def get_quota(self, cache_level: CacheLevel) -> int: ...
202241
@property
203-
def cache_tier_list(self) -> tuple[CacheTier, ...]: ...
242+
def cache_tier_list(self) -> Sequence[CacheTier]: ...
204243
@property
205244
def tokens_per_block(self) -> int: ...
206245
@property
207246
def allow_seq_rebasing(self) -> bool: ...
208247
@property
209248
def enable_partial_match(self) -> bool: ...
210-
def get_layer_group_id(self, layer_id: LayerId) -> int: ...
211249
@property
212-
def layer_grouping(self) -> tuple[tuple[LayerId, ...], ...]: ...
250+
def num_layers(self) -> int: ...
251+
@property
252+
def layer_ids(self) -> Iterator[LayerId]: ...
253+
def get_layer_group_id(self, layer_id: LayerId) -> LayerGroupId: ...
254+
@property
255+
def layer_grouping(self) -> Sequence[Sequence[LayerId]]: ...
256+
@property
257+
def all_buffer_ids(self) -> Iterator[BufferId]: ...
258+
def get_aggregated_pages(
259+
self, buffers: Iterable[BufferSlice]
260+
) -> Iterator[AggregatedPageDesc]: ...
213261
def clamp_max_seq_len_for_mem(self, batch_size: int, model_max_seq_len: int) -> int: ...

tensorrt_llm/runtime/kv_cache_manager_v2/_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class CacheTier(enum.IntEnum):
5656
CudaStream = NewType("CudaStream", int)
5757

5858
BeamIndex = NewType("BeamIndex", int)
59+
DEFAULT_BEAM_INDEX: Final[BeamIndex] = BeamIndex(0)
5960

6061
UserId = NewType("UserId", int)
6162

tensorrt_llm/runtime/kv_cache_manager_v2/_copy_engine.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,17 @@
2626
from typing import ClassVar, NamedTuple, Sequence, cast
2727

2828
import cuda.bindings.driver as drv
29-
from dynamic_path_manager import DynamicPathManager
3029

3130
from ._common import Address, CacheTier, CudaStream, MemAddress
32-
from ._utils import CachedCudaEvent, HomoTuple, HostMem, _unwrap, div_up, stream_wait_events
31+
from ._utils import (
32+
CachedCudaEvent,
33+
HomoTuple,
34+
HostMem,
35+
_unwrap,
36+
div_up,
37+
stream_wait_events,
38+
temporary_sys_path,
39+
)
3340

3441
if "tensorrt_llm" in sys.modules:
3542
from tensorrt_llm.bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( # noqa # type: ignore
@@ -50,7 +57,7 @@
5057
# fast path for dev, avoids importing the whole tensorrt_llm module
5158
spec = find_spec("kv_cache_manager_v2")
5259
assert spec is not None and spec.origin is not None
53-
with DynamicPathManager(str(Path(spec.origin).parent.parent.parent), clear_cache=False):
60+
with temporary_sys_path(str(Path(spec.origin).parent.parent.parent)):
5461
from bindings.internal.batch_manager.kv_cache_manager_v2_utils import ( # noqa
5562
DiskAddress,
5663
DiskToDiskTask,

tensorrt_llm/runtime/kv_cache_manager_v2/_core/__init__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from .._common import BeamIndex
16+
from .._common import DEFAULT_BEAM_INDEX, BeamIndex
1717
from ._kv_cache import _KVCache
18-
from ._kv_cache_manager import KVCacheManager
18+
from ._kv_cache_manager import AggregatedPageDesc, BufferSlice, KVCacheManager
1919

20-
__all__ = ["KVCacheManager", "_KVCache", "BeamIndex"]
20+
__all__ = [
21+
"KVCacheManager",
22+
"_KVCache",
23+
"BeamIndex",
24+
"DEFAULT_BEAM_INDEX",
25+
"BufferSlice",
26+
"AggregatedPageDesc",
27+
]

tensorrt_llm/runtime/kv_cache_manager_v2/_core/_kv_cache.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .._block_radix_tree import Block, RootBlock, UselessBlockError
2626
from .._common import (
2727
BAD_PAGE_INDEX,
28+
DEFAULT_BEAM_INDEX,
2829
GPU_LEVEL,
2930
NDEBUG,
3031
BeamIndex,
@@ -49,7 +50,6 @@
4950
_SharedPageLock,
5051
batched_lock_to_gpu,
5152
)
52-
from .._storage._config import BufferId
5353
from .._storage_manager import StorageManager
5454
from .._utils import (
5555
CachedCudaEvent,
@@ -312,7 +312,7 @@ def beam_width(self, beam_width: BeamIndex) -> None:
312312
# Due to constraints of the current kernels, K/V data blocks and the correspondding quant scale blocks
313313
# share the same indices, so the output for DataRole.KEY_DATA and DataRole.KEY_BLOCK_SCALE are the same.
314314
def get_page_indices(
315-
self, layer_group_id: LayerGroupId, beam_id: BeamIndex = BeamIndex(0)
315+
self, layer_group_id: LayerGroupId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
316316
) -> IndexSeq:
317317
indices = self._page_indices[beam_id][layer_group_id]
318318
assert NDEBUG or all(
@@ -321,13 +321,31 @@ def get_page_indices(
321321
)
322322
return indices
323323

324-
def get_all_page_indices(
325-
self, beam_id: BeamIndex, buf_ids: Iterable[BufferId]
326-
) -> Iterator[IndexSeq]:
327-
layer_to_lc_ids = self.manager._storage._layer_to_life_cycle_ids
328-
for layer_id, _ in buf_ids:
329-
lc = layer_to_lc_ids[layer_id]
330-
yield self._page_indices[beam_id][lc]
324+
def get_aggregated_page_indices(
325+
self,
326+
layer_group_id: LayerGroupId,
327+
beam_id: BeamIndex = DEFAULT_BEAM_INDEX,
328+
valid_only: bool = False,
329+
) -> Iterator[int]:
330+
"""
331+
Get the internal slot indices for the given layer group and beam.
332+
Each slot is a group of coalesced buffers in one memory pool group.
333+
This API exposes internal slot indices, mainly for efficient data transfer.
334+
For computation, use get_page_indices() instead.
335+
336+
Args:
337+
layer_group_id: Layer group to inspect.
338+
beam_id: Beam index to read. Defaults to DEFAULT_BEAM_INDEX.
valid_only: If True, skip invalid blocks instead of yielding BAD_PAGE_INDEX for them.
339+
340+
Yields:
341+
Aggregated page index for each block, or BAD_PAGE_INDEX for invalid blocks (invalid blocks are skipped entirely when valid_only is True).
342+
"""
343+
for b in self._blocks:
344+
if (holder := b.pages[beam_id][layer_group_id]) is None:
345+
if not valid_only:
346+
yield BAD_PAGE_INDEX
347+
else:
348+
yield holder.page.slot_id
331349

332350
# reserve space for next inference. Request new blocks from KVCacheManager if necessary.
333351
# if capacity is increased and beam_width > 1, blocks containing new tokens should be allocated for each beam.
@@ -608,7 +626,7 @@ def _commit_block(self, ordinal: BlockOrdinal, is_last: bool) -> None:
608626
)
609627
seq_block = self._blocks[ordinal]
610628
assert typed_len(seq_block.pages) == 1, "Must have 1 beam only"
611-
beam_idx = BeamIndex(0)
629+
beam_idx = DEFAULT_BEAM_INDEX
612630
beam_block = seq_block.pages[beam_idx]
613631
tokens_per_block = self.tokens_per_block
614632
start = ordinal * tokens_per_block
@@ -756,7 +774,7 @@ def _get_tree_block(self, ordinal: BlockOrdinal) -> Block:
756774
assert self._blocks[ordinal].is_committed
757775
ret = unwrap_optional(self._blocks[ordinal].tree_block)
758776
if not NDEBUG:
759-
for b in self._block(ordinal, BeamIndex(0)):
777+
for b in self._block(ordinal, DEFAULT_BEAM_INDEX):
760778
assert b is None or (isinstance(b.page, CommittedPage) and b.page.block() is ret)
761779
return ret
762780

@@ -925,7 +943,7 @@ def check_no_page_stale(b: tuple[Block, int]):
925943
],
926944
)
927945

928-
beam_idx = BeamIndex(0)
946+
beam_idx = DEFAULT_BEAM_INDEX
929947
for lc_idx, lc in life_cycles.items():
930948
stale_start, stale_end = _KVCache._get_stale_range(
931949
tokens_per_block, get_num_matched_tokens(matched), lc
@@ -1011,7 +1029,7 @@ def _update_page_index(
10111029
return old
10121030

10131031
def _get_page_indices_ref(
1014-
self, lc: LifeCycleId, beam_id: BeamIndex = BeamIndex(0)
1032+
self, lc: LifeCycleId, beam_id: BeamIndex = DEFAULT_BEAM_INDEX
10151033
) -> Iterator[int | None]:
10161034
assert beam_id < self.beam_width
10171035
assert self.is_active

0 commit comments

Comments
 (0)