@@ -49,7 +49,7 @@ def __init__(
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
         )
-        self.free_blocks = torch.arange(self.num_blocks, device=device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
         self.max_cache_len = max_cache_len
         self.num_kv_heads = config.num_key_value_heads
         self.num_hidden_layers = config.num_hidden_layers
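This hunk swaps the free-block bookkeeping from a shrinking queue of block ids (`torch.arange`) to a fixed-size flag tensor in which 1 marks a free block. A minimal sketch of the two representations, using an illustrative pool size that is not taken from the PR:

import torch

num_blocks, device = 8, "cpu"  # illustrative values, not taken from the class

# Old bookkeeping: a queue of free block ids, consumed from the front.
free_queue = torch.arange(num_blocks, device=device)

# New bookkeeping: one int32 flag per block, 1 = free, 0 = in use.
free_blocks = torch.ones([num_blocks], dtype=torch.int32, device=device)

# Allocating two blocks from the mask takes the first free indices ...
allocated = free_blocks.nonzero().view(-1)[0:2]
free_blocks[allocated] = 0
# ... and releasing them later is just flipping the flags back in place.
free_blocks[allocated] = 1

The flag tensor keeps a constant size, so blocks can be returned in any order without concatenating onto a growing free list.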
@@ -88,12 +88,10 @@ def update_for_prefill(
         all_slot_offsets = []
         num_blocks = (input_lens + self.block_size - 1) // self.block_size
         for i in range(batch_size):
-            for b_idx in range(num_blocks[i]):
-                if self.block_tables[i][b_idx] == -1:
-                    # need a free block
-                    self.block_tables[i][b_idx] = self.free_blocks[0]
-                    self.free_blocks = self.free_blocks[1:]
-
+            nb = num_blocks[i]
+            block_table = self.free_blocks.nonzero().view(-1)[0:nb]
+            self.block_tables[i][0:nb] = block_table
+            self.free_blocks[block_table] = 0
             slots_range = torch.arange(input_lens[i], device=key_states.device)
             block_indices = slots_range // self.block_size
             slot_offsets = slots_range % self.block_size
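For reference, the prefill path now grabs all blocks a prompt needs with a single `nonzero()` call instead of looping block by block. A small self-contained sketch with made-up sizes (two sequences, block size 4), not the PR's actual configuration:

import torch

block_size = 4
free_blocks = torch.ones([8], dtype=torch.int32)                  # illustrative pool of 8 blocks
block_tables = -1 * torch.ones([2, 4], dtype=torch.int32)         # 2 sequences, up to 4 blocks each
input_lens = torch.tensor([5, 3])                                 # prompt lengths per sequence

num_blocks = (input_lens + block_size - 1) // block_size          # ceil-divide: blocks per sequence
for i in range(input_lens.shape[0]):
    nb = num_blocks[i]
    # take the first nb free block ids in one shot
    new_blocks = free_blocks.nonzero().view(-1)[0:nb]
    block_tables[i][0:nb] = new_blocks
    free_blocks[new_blocks] = 0

print(block_tables)  # sequence 0 gets blocks [0, 1], sequence 1 gets block [2]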
@@ -103,7 +101,6 @@ def update_for_prefill(
         all_block_indices = torch.cat(all_block_indices)
         all_slot_offsets = torch.cat(all_slot_offsets)
         self.slots = all_block_indices * self.block_size + all_slot_offsets
-
         # Update the cache
         PagedAttention.reshape_and_cache(
             key_states,
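The `slots` computed here are flat indices into the paged KV pool: a token's physical block id times the block size plus its offset inside the block. A hedged, standalone illustration; the block-table lookup below is implied by the surrounding code, and the values are invented:

import torch

block_size = 4
block_table = torch.tensor([7, 2], dtype=torch.int32)  # logical block 0 -> physical 7, logical 1 -> physical 2
seq_len = 6

positions = torch.arange(seq_len)
logical_blocks = positions // block_size        # which block of the sequence each token falls in
offsets = positions % block_size                # position within that block
physical_blocks = block_table[logical_blocks]   # translate through the block table
slots = physical_blocks * block_size + offsets  # flat slot index into the paged KV pool

print(slots)  # tensor([28, 29, 30, 31,  8,  9])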
@@ -127,16 +124,16 @@ def update_for_decode(
     ):
         if layer_idx == 0:
             start_block_idx = self._seen_tokens // self.block_size
-            num_blocks = (self._seen_tokens + self.block_size) // self.block_size
             slot_offset_in_block = (self._seen_tokens) % self.block_size
             self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
             for i in range(batch_size):
-                for b_idx in range(start_block_idx[i], num_blocks[i]):
+                if slot_offset_in_block[i] == 0:
+                    # need a new block:
+                    b_idx = start_block_idx[i]
                     if self.block_tables[i][b_idx] == -1:
                         # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks[0]
-                        self.free_blocks = self.free_blocks[1:]
-
+                        self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
+                        self.free_blocks[self.block_tables[i][b_idx]] = 0
                 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
         # Update the cache
         PagedAttention.reshape_and_cache(
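During decode only one token per sequence is appended, so a new block can only be required when the write offset wraps back to 0; otherwise the current block still has room. A toy sketch of that condition with illustrative token counters:

import torch

block_size = 4
seen_tokens = torch.tensor([4, 6])  # tokens already cached for each sequence

start_block_idx = seen_tokens // block_size       # block the next token lands in
slot_offset_in_block = seen_tokens % block_size   # position inside that block

for i in range(seen_tokens.shape[0]):
    if slot_offset_in_block[i] == 0:
        # offset 0 means the previous block is full, so this step may need a fresh block
        print(f"sequence {i}: allocate block index {int(start_block_idx[i])}")
    else:
        print(f"sequence {i}: reuse block index {int(start_block_idx[i])} at offset {int(slot_offset_in_block[i])}")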
@@ -196,7 +193,7 @@ def reset(self):
         """Resets the cache values while preserving the objects"""
         self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
+        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
         self.max_seq_len = 0

     def reorder_cache(self, beam_idx: torch.LongTensor):
@@ -206,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
         updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
         mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
         num_blocks = mask.cumsum(-1)[:, -1]
-        updated_table = []
+        updated_table = torch.zeros_like(beam_idx)
         for i in range(beam_idx.shape[0]):
-            self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
-            updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
-        updated_table = torch.cat(tuple(updated_table), dim=0)
+            nb = num_blocks[i]
+            self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
+            updated_table[i] = self.block_tables[i][nb - 1]
         for layer_idx in range(self.num_hidden_layers):
             self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
             self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1

     def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
@@ -235,4 +234,6 @@ def crop(self, maximum_length: int):
             self._seen_tokens[bs] = new_tokens
         self.max_seq_len, _ = self._seen_tokens.max(dim=0)
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
-        self.free_blocks = torch.cat((self.free_blocks, free_table))
+        for i in free_table:
+            if not (self.block_tables == i).any():
+                self.free_blocks[i] = 1