@@ -175,6 +175,7 @@ def __init__(
175175 enable_indexer_k_cache : bool = False ,
176176 indexer_k_cache_quant_block_size : int = 128 ,
177177 indexer_k_cache_index_head_dim : int = 0 ,
178+ is_estimating_kv_cache : bool = False ,
178179 ** kwargs ,
179180 ) -> None :
180181 self .mapping = mapping
@@ -269,37 +270,61 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
269270 # Determine if this is VSWA (Variable Sliding Window Attention)
270271 self .is_vswa = len (set (self .max_attention_window_vec )) > 1
271272
272- # Calculate blocks per window using appropriate method
273- if self .is_vswa :
274- # VSWA case: use C++ implementation for variable window sizes
275- # model config check
276- if model_config is None :
277- raise ValueError (
278- "model_config is required for VSWA (Variable Sliding Window Attention)"
279- )
280- # kv cache config check
281- assert isinstance (
282- kv_cache_config , KvCacheConfig
283- ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
284- blocks_per_window = self .calculate_max_num_blocks_from_cpp (
285- kv_cache_config = kv_cache_config ,
286- model_config = model_config ,
287- extra_cost_memory = 0 ,
288- )
289- else :
290- # Standard case: use original Python implementation
291- self .blocks_in_primary_pool , self .blocks_in_secondary_pool = self .calculate_max_num_blocks (
292- kv_cache_config = kv_cache_config ,
293- head_dim = head_dim ,
294- tokens_per_block = tokens_per_block ,
295- mapping = mapping ,
296- dtype = dtype ,
297- kv_factor = self .kv_factor ,
273+ # Calculate kv cache blocks for each window size
274+ # FIXME: flashinfer.py accesses kv_cache_manager.blocks_in_primary_pool
275+ # This dependency should be adjusted as it only covers the single window
276+ # case and not VSWA scheme.
277+ if is_estimating_kv_cache :
278+ # If this is an estimation dry run, we have already calculated the
279+ # max_tokens under _util.py::try_prepare_estimation
280+ # Since this is a dry run, assigning the same max_tokens capacity
281+ # to all window sizes as they are full attentions is enough.
282+ self .blocks_in_primary_pool = int (kv_cache_config .max_tokens //
283+ tokens_per_block )
284+
285+ host_cache_size = kv_cache_config .host_cache_size if kv_cache_config .host_cache_size else 0
286+ max_tokens_secondary = host_cache_size // self .get_cache_bytes_per_token (
298287 )
288+ self .blocks_in_secondary_pool = int (max_tokens_secondary //
289+ tokens_per_block )
290+
299291 blocks_per_window = {
300- self . max_attention_window_vec [ 0 ] :
292+ window_size :
301293 (self .blocks_in_primary_pool , self .blocks_in_secondary_pool )
294+ for window_size in set (self .max_attention_window_vec )
302295 }
296+ logger .info (
297+ f"[kv cache manager] Primary/secondary blocks for window sizes set to { blocks_per_window } for estimation dry run"
298+ )
299+ else :
300+ if self .is_vswa :
301+ # VSWA case: use C++ implementation for variable window sizes
302+ if model_config is None :
303+ raise ValueError (
304+ "model_config is required for VSWA (Variable Sliding Window Attention)"
305+ )
306+ assert isinstance (
307+ kv_cache_config , KvCacheConfig
308+ ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
309+ blocks_per_window = self .calculate_max_num_blocks_from_cpp (
310+ kv_cache_config = kv_cache_config ,
311+ model_config = model_config ,
312+ extra_cost_memory = 0 ,
313+ )
314+ else :
315+ # Standard case: use original Python implementation
316+ self .blocks_in_primary_pool , self .blocks_in_secondary_pool = self .calculate_max_num_blocks (
317+ kv_cache_config = kv_cache_config ,
318+ head_dim = head_dim ,
319+ tokens_per_block = tokens_per_block ,
320+ mapping = mapping ,
321+ dtype = dtype ,
322+ kv_factor = self .kv_factor ,
323+ )
324+ blocks_per_window = {
325+ self .max_attention_window_vec [0 ]:
326+ (self .blocks_in_primary_pool , self .blocks_in_secondary_pool )
327+ }
303328
304329 # Validate and adjust attention windows against their upper bounds if needed
305330 blocks_per_window , self .max_seq_len , self .max_attention_window_vec = self ._validate_and_adjust_attention_windows (
@@ -736,11 +761,13 @@ def calculate_max_num_blocks(self,
736761 max_tokens = mpi_comm ().allreduce (max_tokens , op = MPI .MIN )
737762
738763 # get number of blocks
739- blocks_in_primary_pool = math .ceil (max_tokens / tokens_per_block )
764+ blocks_in_primary_pool = int (max_tokens // tokens_per_block )
765+
740766 host_cache_size = kv_cache_config .host_cache_size if kv_cache_config .host_cache_size else 0
741- max_tokens_secondary = host_cache_size / cache_size_bytes_per_token
742- blocks_in_secondary_pool = max (
743- 0 , int (max_tokens_secondary / tokens_per_block ))
767+ max_tokens_secondary = host_cache_size // self .get_cache_bytes_per_token (
768+ )
769+ blocks_in_secondary_pool = int (max_tokens_secondary // tokens_per_block )
770+
744771 return blocks_in_primary_pool , blocks_in_secondary_pool
745772
746773 def get_max_atten_window_upper_bound (self , blocks_in_primary_pool ,
0 commit comments