@@ -54,7 +54,10 @@ def __init__(
         # Used in `generate` to keep tally of how many tokens the cache has seen

         self._seen_tokens = torch.zeros([max_batch_size], dtype=torch.int32, device=device)
-        default_block_size = 16
+        self.slots = torch.zeros([max_cache_len * max_batch_size], dtype=torch.int32, device=device)
+        torch._dynamo.mark_static_address(self._seen_tokens)
+        torch._dynamo.mark_static_address(self.slots)
+        default_block_size = 16 if max_cache_len <= 64 else 64
         self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * max_batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
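For intuition, a quick worked example of the sizing this constructor sets up (all numbers below are made up, not taken from the patch):

# Made-up sizes, purely to illustrate the bookkeeping configured above.
max_cache_len, max_batch_size = 1000, 4
block_size = 64                                  # default when max_cache_len > 64
blocks_per_seq = 1000 // 64 + (1000 % 64 != 0)   # 15 + True -> 16 blocks per sequence
num_blocks = blocks_per_seq * max_batch_size     # 64 physical blocks in total
# block_tables holds one row of block ids per sequence (-1 = not yet assigned), and
# slots holds max_cache_len * max_batch_size flat slot indices, one per (sequence, position).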
@@ -69,7 +72,6 @@ def __init__(
         else:
             head_size = config.hidden_size // config.num_attention_heads
         self.head_size = head_size
-        self.max_seq_len = 0

         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
@@ -87,6 +89,8 @@ def __init__(
         for i in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
            new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
@@ -116,79 +120,50 @@ def reshape_and_cache(
                 slots,
             )

-    def update_for_prefill(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        batch_size: int,
-        input_lens: torch.Tensor,
-    ):
-        if layer_idx == 0:
-            all_block_indices = []
-            all_slot_offsets = []
-            num_blocks = (input_lens + self.block_size - 1) // self.block_size
-            for i in range(batch_size):
-                nb = num_blocks[i]
-                block_table = self.free_blocks.nonzero().view(-1)[0:nb]
-                self.block_tables[i][0:nb] = block_table
-                self.free_blocks[block_table] = 0
-                slots_range = torch.arange(input_lens[i], device=self.device)
-                block_indices = slots_range // self.block_size
-                slot_offsets = slots_range % self.block_size
-                all_block_indices.append(self.block_tables[i][block_indices])
-                all_slot_offsets.append(slot_offsets)
-
-            all_block_indices = torch.cat(all_block_indices)
-            all_slot_offsets = torch.cat(all_slot_offsets)
-            self.slots = all_block_indices * self.block_size + all_slot_offsets
-        # Update the cache
-        self.reshape_and_cache(
-            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
-        )
-
-        # Update the number of seen tokens
-        if layer_idx == self.num_hidden_layers - 1:
-            self._seen_tokens = self._seen_tokens + input_lens
-            self.max_seq_len, _ = self._seen_tokens.max(dim=0)
-
-    def update_for_decode(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        batch_size: int,
-    ):
-        if layer_idx == 0:
-            start_block_idx = self._seen_tokens // self.block_size
-            slot_offset_in_block = (self._seen_tokens) % self.block_size
-            self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
-            for i in range(batch_size):
-                if slot_offset_in_block[i] == 0:
-                    # need a new block:
-                    b_idx = start_block_idx[i]
-                    if self.block_tables[i][b_idx] == -1:
-                        # need a free block
-                        self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
-                        self.free_blocks[self.block_tables[i][b_idx]] = 0
-                self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
-        # Update the cache
-        self.reshape_and_cache(
-            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
-        )
-
-        # Update the number of seen tokens
-        if layer_idx == self.num_hidden_layers - 1:
-            self._seen_tokens = self._seen_tokens + 1
-            self.max_seq_len = self.max_seq_len + 1
+    # Called outside the model forward
+    def alloc_slot_for_prefill(self, input_lens: torch.Tensor, batch_size: int):
+        all_block_indices = []
+        all_slot_offsets = []
+        num_blocks = (input_lens + self.block_size - 1) // self.block_size
+        for i in range(batch_size):
+            nb = num_blocks[i]
+            scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
+            block_table = torch.topk(scores, nb).indices
+            self.block_tables[i][0:nb] = block_table
+            self.free_blocks[block_table] = 0
+            slots_range = torch.arange(input_lens[i], device=self.device)
+            block_indices = slots_range // self.block_size
+            slot_offsets = slots_range % self.block_size
+            all_block_indices.append(self.block_tables[i][block_indices])
+            all_slot_offsets.append(slot_offsets)
+
+        all_block_indices = torch.cat(all_block_indices)
+        all_slot_offsets = torch.cat(all_slot_offsets).int()
+        # Use an in-place op to keep the same memory address and avoid recompilation
+        self.slots[: all_block_indices.shape[0]].copy_(all_block_indices * self.block_size + all_slot_offsets)
+
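The free-block selection above swaps the earlier `nonzero()` lookup for a fixed-size `topk` over descending scores, presumably so tensor shapes stay static under torch.compile. A standalone illustration with made-up values:

import torch

free_blocks = torch.tensor([0, 1, 1, 0, 1, 1, 1, 0], dtype=torch.int32)  # 1 = free, 0 = taken (made up)
nb = 3                                                                   # blocks needed by one sequence

# Free blocks keep strictly decreasing positive scores, taken blocks score 0,
# so topk returns the nb lowest-indexed free blocks without data-dependent shapes.
scores = free_blocks * torch.arange(free_blocks.shape[0], 0, -1)  # [0, 7, 6, 0, 4, 3, 2, 0]
block_table = torch.topk(scores, nb).indices                      # tensor([1, 2, 4])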
+    # Called outside the model forward
+    def alloc_slot_for_decode(self, batch_size: int):
+        start_block_idx = self._seen_tokens // self.block_size
+        slot_offset_in_block = self._seen_tokens % self.block_size
+        # Use an in-place op to keep the same memory address and avoid recompilation
+        self.slots.zero_()
+        for i in range(batch_size):
+            if slot_offset_in_block[i] == 0:
+                # need a new block
+                b_idx = start_block_idx[i]
+                if self.block_tables[i][b_idx] == -1:
+                    # Need a free block: score free blocks in descending order and take the first one
+                    scores = self.free_blocks * torch.arange(self.free_blocks.shape[0], 0, -1)
+                    self.block_tables[i][b_idx] = scores.argmax()
+                    self.free_blocks[self.block_tables[i][b_idx]] = 0
+            self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
 
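How these allocation hooks are driven is not shown in this hunk; a rough sketch of a caller, under the assumption that the generation loop also advances `_seen_tokens` now that `update()` no longer does (all names below are hypothetical):

def prepare_cache_for_step(cache, input_lens, is_prefill):
    # Hypothetical helper run once per generation step, before the (possibly
    # torch.compile'd) forward pass that calls cache.update() for every layer.
    batch_size = input_lens.shape[0]
    if is_prefill:
        # Reserve blocks and one slot per prompt token.
        cache.alloc_slot_for_prefill(input_lens, batch_size)
    else:
        # Reserve one slot per sequence for the next token.
        cache.alloc_slot_for_decode(batch_size)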
     def update(
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
         layer_idx: int,
-        attention_mask: torch.Tensor,
-        input_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -204,45 +179,46 @@ def update(
             A tuple containing the updated key and value states.
         """

-        batch_size = input_lens.shape[-1]
-        if self.get_seq_length() == 0:
-            # prefill
-            self.update_for_prefill(key_states, value_states, layer_idx, batch_size, input_lens)
-        else:
-            # decode
-            self.update_for_decode(key_states, value_states, layer_idx, batch_size)
+        self.reshape_and_cache(
+            key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
+        )

         return self.key_cache[layer_idx], self.value_cache[layer_idx]

-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+    def get_seq_length(self) -> int:
         """Returns the sequence length of the cached states that were seen by the model."""
-        return self.max_seq_len
+        return self._seen_tokens.max()

     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states."""
         return self.max_cache_len

     def reset(self):
         """Resets the cache values while preserving the objects"""
-        self._seen_tokens = torch.zeros([self.max_batch_size], dtype=torch.int32, device=self.device)
+        self._seen_tokens.zero_()
         self.block_tables.fill_(-1)
-        self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
-        self.max_seq_len = 0
+        self.free_blocks.fill_(1)

     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         origin_table = self.block_tables.clone()
         updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
-        mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
-        num_blocks = mask.cumsum(-1)[:, -1]
+        mask = torch.where(self.block_tables == -1, 0, 1)
+        num_blocks = mask.sum(-1)
         updated_table = torch.zeros_like(beam_idx)
         for i in range(beam_idx.shape[0]):
             nb = num_blocks[i]
             self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
             updated_table[i] = self.block_tables[i][nb - 1]
         for layer_idx in range(self.num_hidden_layers):
-            self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
-            self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
+            # updated_table cannot contain the whole block table, otherwise it will cause a core dump.
+            self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx].index_select(
+                0, updated_table[beam_idx]
+            )
+            self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx].index_select(
+                0, updated_table[beam_idx]
+            )
+
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
         for i in free_table:
             if not (self.block_tables == i).any():
@@ -252,7 +228,7 @@ def crop(self, maximum_length: int):
         """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
         negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""

-        max_seq_len = self.get_seq_length()
+        max_seq_len = self._seen_tokens.max()
         if maximum_length < 0:
             maximum_length = max_seq_len - abs(maximum_length)
 
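For reference, the sign convention handled above, with made-up numbers:

max_seq_len = 10     # made-up current longest cached sequence
maximum_length = -2  # negative means "drop the last 2 tokens"
if maximum_length < 0:
    maximum_length = max_seq_len - abs(maximum_length)  # -> 8 tokens kept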
@@ -264,7 +240,7 @@ def crop(self, maximum_length: int):
             num_blocks = (new_tokens + self.block_size - 1) // self.block_size
             self.block_tables[bs, num_blocks:] = -1
             self._seen_tokens[bs] = new_tokens
-        self.max_seq_len, _ = self._seen_tokens.max(dim=0)
+
         free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
         for i in free_table:
             if not (self.block_tables == i).any():