@@ -59,22 +59,30 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
5959 auto * blocks = sequence->kv_state ().mutable_kv_blocks ();
6060 auto * host_blocks = sequence->host_kv_state ().mutable_kv_blocks ();
6161
62- if (blocks->size () == 0 || host_blocks->size () > blocks->size ()) {
62+ if (host_blocks->size () >= blocks->size ()) {
63+ host_block_managers_[dp_rank]->deallocate (
64+ sequence->host_kv_state ().kv_blocks ());
65+ block_managers_[dp_rank]->deallocate (sequence->kv_state ().kv_blocks ());
66+ sequence->reset ();
6367 return ;
6468 }
6569
66- size_t cached_block_num =
70+ size_t cached_host_block_num =
6771 sequence->host_kv_state ().kv_cache_tokens_num () / options_.block_size ();
72+ size_t cached_device_block_num =
73+ sequence->kv_state ().kv_cache_tokens_num () / options_.block_size ();
6874
69- size_t needed_block_num =
70- sequence->num_tokens () / options_.block_size () - host_blocks->size ();
75+ size_t needed_block_num = cached_device_block_num > host_blocks->size ()
76+ ? cached_device_block_num - host_blocks->size ()
77+ : 0 ;
7178
79+ // allocate additional host blocks for copy
7280 if (needed_block_num != 0 ) {
7381 sequence->host_kv_state ().add_kv_blocks (
7482 host_block_managers_[dp_rank]->allocate (needed_block_num));
7583 }
7684
77- for (size_t i = cached_block_num ; i < host_blocks->size (); i++) {
85+ for (size_t i = cached_host_block_num ; i < host_blocks->size (); i++) {
7886 if (blocks->at (i).ref_count () != 2 ) {
7987 continue ;
8088 }
@@ -89,7 +97,6 @@ void HierarchyBlockManagerPool::deallocate(Sequence* sequence) {
8997 sequence->host_kv_state ().kv_blocks ());
9098
9199 block_managers_[dp_rank]->deallocate (sequence->kv_state ().kv_blocks ());
92- // release the blocks after prefix cache insertion
93100 sequence->reset ();
94101}
95102
0 commit comments