2525from .._block_radix_tree import Block , RootBlock , UselessBlockError
2626from .._common import (
2727 BAD_PAGE_INDEX ,
28+ DEFAULT_BEAM_INDEX ,
2829 GPU_LEVEL ,
2930 NDEBUG ,
3031 BeamIndex ,
4950 _SharedPageLock ,
5051 batched_lock_to_gpu ,
5152)
52- from .._storage ._config import BufferId
5353from .._storage_manager import StorageManager
5454from .._utils import (
5555 CachedCudaEvent ,
@@ -312,7 +312,7 @@ def beam_width(self, beam_width: BeamIndex) -> None:
312312 # Due to constraints of the current kernels, K/V data blocks and the correspondding quant scale blocks
313313 # share the same indices, so the output for DataRole.KEY_DATA and DataRole.KEY_BLOCK_SCALE are the same.
314314 def get_page_indices (
315- self , layer_group_id : LayerGroupId , beam_id : BeamIndex = BeamIndex ( 0 )
315+ self , layer_group_id : LayerGroupId , beam_id : BeamIndex = DEFAULT_BEAM_INDEX
316316 ) -> IndexSeq :
317317 indices = self ._page_indices [beam_id ][layer_group_id ]
318318 assert NDEBUG or all (
@@ -321,13 +321,31 @@ def get_page_indices(
321321 )
322322 return indices
323323
324- def get_all_page_indices (
325- self , beam_id : BeamIndex , buf_ids : Iterable [BufferId ]
326- ) -> Iterator [IndexSeq ]:
327- layer_to_lc_ids = self .manager ._storage ._layer_to_life_cycle_ids
328- for layer_id , _ in buf_ids :
329- lc = layer_to_lc_ids [layer_id ]
330- yield self ._page_indices [beam_id ][lc ]
324+ def get_aggregated_page_indices (
325+ self ,
326+ layer_group_id : LayerGroupId ,
327+ beam_id : BeamIndex = DEFAULT_BEAM_INDEX ,
328+ valid_only : bool = False ,
329+ ) -> Iterator [int ]:
330+ """
331+ Get the internal slot indices for the given layer group and beam.
332+ Each slot is a group of coalesced buffers in one memory pool group.
333+ This API exposes internal slot indices, mainly for efficient data transfer.
334+ For computation, use get_page_indices() instead.
335+
336+ Args:
337+ layer_group_id: Layer group to inspect.
338+ beam_id: Beam index to read. Defaults to DEFAULT_BEAM_INDEX.
339+
340+ Returns:
341+ Aggregated page index for each block, or BAD_PAGE_INDEX for invalid blocks.
342+ """
343+ for b in self ._blocks :
344+ if (holder := b .pages [beam_id ][layer_group_id ]) is None :
345+ if not valid_only :
346+ yield BAD_PAGE_INDEX
347+ else :
348+ yield holder .page .slot_id
331349
332350 # reserve space for next inference. Request new blocks from KVCacheManager if necessary.
333351 # if capacity is increased and beam_width > 1, blocks containing new tokens should be allocated for each beam.
@@ -608,7 +626,7 @@ def _commit_block(self, ordinal: BlockOrdinal, is_last: bool) -> None:
608626 )
609627 seq_block = self ._blocks [ordinal ]
610628 assert typed_len (seq_block .pages ) == 1 , "Must have 1 beam only"
611- beam_idx = BeamIndex ( 0 )
629+ beam_idx = DEFAULT_BEAM_INDEX
612630 beam_block = seq_block .pages [beam_idx ]
613631 tokens_per_block = self .tokens_per_block
614632 start = ordinal * tokens_per_block
@@ -756,7 +774,7 @@ def _get_tree_block(self, ordinal: BlockOrdinal) -> Block:
756774 assert self ._blocks [ordinal ].is_committed
757775 ret = unwrap_optional (self ._blocks [ordinal ].tree_block )
758776 if not NDEBUG :
759- for b in self ._block (ordinal , BeamIndex ( 0 ) ):
777+ for b in self ._block (ordinal , DEFAULT_BEAM_INDEX ):
760778 assert b is None or (isinstance (b .page , CommittedPage ) and b .page .block () is ret )
761779 return ret
762780
@@ -925,7 +943,7 @@ def check_no_page_stale(b: tuple[Block, int]):
925943 ],
926944 )
927945
928- beam_idx = BeamIndex ( 0 )
946+ beam_idx = DEFAULT_BEAM_INDEX
929947 for lc_idx , lc in life_cycles .items ():
930948 stale_start , stale_end = _KVCache ._get_stale_range (
931949 tokens_per_block , get_num_matched_tokens (matched ), lc
@@ -1011,7 +1029,7 @@ def _update_page_index(
10111029 return old
10121030
10131031 def _get_page_indices_ref (
1014- self , lc : LifeCycleId , beam_id : BeamIndex = BeamIndex ( 0 )
1032+ self , lc : LifeCycleId , beam_id : BeamIndex = DEFAULT_BEAM_INDEX
10151033 ) -> Iterator [int | None ]:
10161034 assert beam_id < self .beam_width
10171035 assert self .is_active
0 commit comments