@@ -186,11 +186,6 @@ def prepare(self) -> None:
186186 assert self .request_ids is not None
187187 block_ids_per_seq = self .kv_cache_manager .get_batch_cache_indices (
188188 self .request_ids )
189- paged_kv_indices = torch .tensor (
190- [x for block_ids in block_ids_per_seq for x in block_ids ],
191- dtype = torch .int32 )
192- self ._paged_kv_indices [:paged_kv_indices .size (0 )].copy_ (
193- paged_kv_indices , non_blocking = True )
194189
195190 # number of tokens in the kv cache for each sequence in the batch
196191 cached_token_lens = torch .tensor (
@@ -212,13 +207,26 @@ def prepare(self) -> None:
212207 1 ])
213208
214209 # number of cache blocks used by each sequence in the cache
215- self .num_blocks = [len (block_ids ) for block_ids in block_ids_per_seq ]
210+ # NOTE: do not use len(block_ids) - that will give you a number
211+ # that can be too big if using chunked prefill/kv cache reuse
212+ # since we allocate all blocks ahead of time.
213+ num_blocks = ((kv_lens + self .page_size - 1 ) // self .page_size )
214+ self .num_blocks = num_blocks .tolist ()
216215 self .num_context_blocks = sum (self .num_blocks [:self .num_contexts ])
217216 self .num_generation_blocks = sum (self .num_blocks [self .num_contexts :])
218217
218+ paged_kv_indices_list = []
219+ for i , block_ids in enumerate (block_ids_per_seq ):
220+ paged_kv_indices_list .extend (block_ids [:self .num_blocks [i ]])
221+
222+ paged_kv_indices = torch .tensor (paged_kv_indices_list ,
223+ dtype = torch .int32 )
224+
225+ self ._paged_kv_indices [:paged_kv_indices .size (0 )].copy_ (
226+ paged_kv_indices , non_blocking = True )
227+
219228 # number of tokens in the last cache block used by each sequence
220- paged_kv_last_page_len = kv_lens - (torch .Tensor (
221- self .num_blocks ).int ().cuda (non_blocking = True ) - 1 ) * self .page_size
229+ paged_kv_last_page_len = kv_lens - (num_blocks - 1 ) * self .page_size
222230 self ._paged_kv_last_page_len [:paged_kv_last_page_len .size (0 )].copy_ (
223231 paged_kv_last_page_len , non_blocking = True )
224232
0 commit comments