@@ -34,22 +34,23 @@ class IPEXPagedCache(Cache):
3434 def __init__ (
3535 self ,
3636 config : PretrainedConfig ,
37- batch_size : int ,
37+ max_batch_size : int ,
3838 max_cache_len : int ,
3939 device ,
4040 dtype = None ,
4141 layer_device_map = None ,
4242 ** kwargs ,
4343 ) -> None :
4444 super ().__init__ ()
45- self .batch_size = batch_size
45+ self .max_batch_size = max_batch_size
4646 # Used in `generate` to keep tally of how many tokens the cache has seen
47- self ._seen_tokens = torch .zeros ([batch_size ], dtype = torch .int32 , device = device )
47+
48+ self ._seen_tokens = torch .zeros ([max_batch_size ], dtype = torch .int32 , device = device )
4849 default_block_size = 16 if device .type == "cpu" else 64
4950 self .block_size = int (os .environ .get ("OI_PAGED_ATTN_BLOCK_SIZE" , str (default_block_size )))
50- self .num_blocks = (max_cache_len // self .block_size + (max_cache_len % self .block_size != 0 )) * batch_size
51+ self .num_blocks = (max_cache_len // self .block_size + (max_cache_len % self .block_size != 0 )) * max_batch_size
5152 self .block_tables = - 1 * torch .ones ([self .num_blocks ], dtype = torch .int32 , device = device ).reshape (
52- batch_size , - 1
53+ max_batch_size , - 1
5354 )
5455 self .free_blocks = torch .ones ([self .num_blocks ], dtype = torch .int32 , device = device )
5556 self .max_cache_len = max_cache_len
@@ -193,7 +194,7 @@ def get_max_length(self) -> Optional[int]:
193194
194195 def reset (self ):
195196 """Resets the cache values while preserving the objects"""
196- self ._seen_tokens = torch .zeros ([self .batch_size ], dtype = torch .int32 , device = self .block_tables .device )
197+ self ._seen_tokens = torch .zeros ([self .max_batch_size ], dtype = torch .int32 , device = self .block_tables .device )
197198 self .block_tables .fill_ (- 1 )
198199 self .free_blocks = torch .ones ([self .num_blocks ], dtype = torch .int32 , device = self .block_tables .device )
199200 self .max_seq_len = 0
0 commit comments