Skip to content

Commit a38d91a

Browse files
authored
[https://nvbugs/5537996][fix] Let KV cache manager block initialization be aware of whether it is doing a dry run or not (#9093)
Before this commit, the KV cache manager did the same thing regardless, which caused a miscalculation of the free memory available to allocate for the KV cache manager, hence causing a crash. This commit fixes this by letting KV cache manager initialization be aware of whether it is doing the dry run or not. If it is a dry run, use the max_tokens setting that is already pre-calculated and filled into kv_cache_config.max_tokens. Signed-off-by: eopXD <yuehtingc@nvidia.com>
1 parent 4742c13 commit a38d91a

File tree

3 files changed

+64
-31
lines changed

3 files changed

+64
-31
lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def _create_kv_cache_manager(
472472
kv_connector_manager=self._kv_connector_manager
473473
if not estimating_kv_cache else None,
474474
sparse_attn_config=sparse_attn_config,
475+
is_estimating_kv_cache=estimating_kv_cache,
475476
)
476477
elif is_nemotron_hybrid(config):
477478
if self._max_beam_width > 1:
@@ -518,6 +519,7 @@ def _create_kv_cache_manager(
518519
mapping=mapping,
519520
dtype=kv_cache_dtype,
520521
spec_config=spec_config,
522+
is_estimating_kv_cache=estimating_kv_cache,
521523
)
522524
elif is_qwen3_next(config):
523525
if self._max_beam_width > 1:
@@ -568,6 +570,7 @@ def _create_kv_cache_manager(
568570
mapping=mapping,
569571
dtype=kv_cache_dtype,
570572
spec_config=spec_config,
573+
is_estimating_kv_cache=estimating_kv_cache,
571574
)
572575
else:
573576
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCacheManager
@@ -595,6 +598,7 @@ def _create_kv_cache_manager(
595598
kv_connector_manager=self._kv_connector_manager
596599
if not estimating_kv_cache else None,
597600
sparse_attn_config=sparse_attn_config,
601+
is_estimating_kv_cache=estimating_kv_cache,
598602
)
599603
# KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
600604
if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def __init__(
195195
mapping: Mapping,
196196
dtype: DataType = DataType.HALF,
197197
spec_config: Optional["DecodingBaseConfig"] = None,
198+
is_estimating_kv_cache: bool = False,
198199
) -> None:
199200

200201
# mamba hybrid cache requires block reuse to be disabled in KV cache config
@@ -231,6 +232,7 @@ def __init__(
231232
dtype=dtype,
232233
spec_config=spec_config,
233234
layer_mask=layer_mask,
235+
is_estimating_kv_cache=is_estimating_kv_cache,
234236
)
235237

236238
def prepare_resources(self, scheduled_batch: ScheduledRequests):

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def __init__(
175175
enable_indexer_k_cache: bool = False,
176176
indexer_k_cache_quant_block_size: int = 128,
177177
indexer_k_cache_index_head_dim: int = 0,
178+
is_estimating_kv_cache: bool = False,
178179
**kwargs,
179180
) -> None:
180181
self.mapping = mapping
@@ -269,37 +270,61 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
269270
# Determine if this is VSWA (Variable Sliding Window Attention)
270271
self.is_vswa = len(set(self.max_attention_window_vec)) > 1
271272

272-
# Calculate blocks per window using appropriate method
273-
if self.is_vswa:
274-
# VSWA case: use C++ implementation for variable window sizes
275-
# model config check
276-
if model_config is None:
277-
raise ValueError(
278-
"model_config is required for VSWA (Variable Sliding Window Attention)"
279-
)
280-
# kv cache config check
281-
assert isinstance(
282-
kv_cache_config, KvCacheConfig
283-
), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
284-
blocks_per_window = self.calculate_max_num_blocks_from_cpp(
285-
kv_cache_config=kv_cache_config,
286-
model_config=model_config,
287-
extra_cost_memory=0,
288-
)
289-
else:
290-
# Standard case: use original Python implementation
291-
self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
292-
kv_cache_config=kv_cache_config,
293-
head_dim=head_dim,
294-
tokens_per_block=tokens_per_block,
295-
mapping=mapping,
296-
dtype=dtype,
297-
kv_factor=self.kv_factor,
273+
# Calculate kv cache blocks for each window size
274+
# FIXME: flashinfer.py accesses kv_cache_manager.blocks_in_primary_pool
275+
# This dependency should be adjusted as it only covers the single window
276+
# case and not VSWA scheme.
277+
if is_estimating_kv_cache:
278+
# If this is an estimation dry run, we have already calculated the
279+
# max_tokens under _util.py::try_prepare_estimation
280+
# Since this is a dry run, assigning the same max_tokens capacity
281+
# to all window sizes as they are full attentions is enough.
282+
self.blocks_in_primary_pool = int(kv_cache_config.max_tokens //
283+
tokens_per_block)
284+
285+
host_cache_size = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0
286+
max_tokens_secondary = host_cache_size // self.get_cache_bytes_per_token(
298287
)
288+
self.blocks_in_secondary_pool = int(max_tokens_secondary //
289+
tokens_per_block)
290+
299291
blocks_per_window = {
300-
self.max_attention_window_vec[0]:
292+
window_size:
301293
(self.blocks_in_primary_pool, self.blocks_in_secondary_pool)
294+
for window_size in set(self.max_attention_window_vec)
302295
}
296+
logger.info(
297+
f"[kv cache manager] Primary/secondary blocks for window sizes set to {blocks_per_window} for estimation dry run"
298+
)
299+
else:
300+
if self.is_vswa:
301+
# VSWA case: use C++ implementation for variable window sizes
302+
if model_config is None:
303+
raise ValueError(
304+
"model_config is required for VSWA (Variable Sliding Window Attention)"
305+
)
306+
assert isinstance(
307+
kv_cache_config, KvCacheConfig
308+
), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
309+
blocks_per_window = self.calculate_max_num_blocks_from_cpp(
310+
kv_cache_config=kv_cache_config,
311+
model_config=model_config,
312+
extra_cost_memory=0,
313+
)
314+
else:
315+
# Standard case: use original Python implementation
316+
self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
317+
kv_cache_config=kv_cache_config,
318+
head_dim=head_dim,
319+
tokens_per_block=tokens_per_block,
320+
mapping=mapping,
321+
dtype=dtype,
322+
kv_factor=self.kv_factor,
323+
)
324+
blocks_per_window = {
325+
self.max_attention_window_vec[0]:
326+
(self.blocks_in_primary_pool, self.blocks_in_secondary_pool)
327+
}
303328

304329
# Validate and adjust attention windows against their upper bounds if needed
305330
blocks_per_window, self.max_seq_len, self.max_attention_window_vec = self._validate_and_adjust_attention_windows(
@@ -736,11 +761,13 @@ def calculate_max_num_blocks(self,
736761
max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)
737762

738763
# get number of blocks
739-
blocks_in_primary_pool = math.ceil(max_tokens / tokens_per_block)
764+
blocks_in_primary_pool = int(max_tokens // tokens_per_block)
765+
740766
host_cache_size = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0
741-
max_tokens_secondary = host_cache_size / cache_size_bytes_per_token
742-
blocks_in_secondary_pool = max(
743-
0, int(max_tokens_secondary / tokens_per_block))
767+
max_tokens_secondary = host_cache_size // self.get_cache_bytes_per_token(
768+
)
769+
blocks_in_secondary_pool = int(max_tokens_secondary // tokens_per_block)
770+
744771
return blocks_in_primary_pool, blocks_in_secondary_pool
745772

746773
def get_max_atten_window_upper_bound(self, blocks_in_primary_pool,

0 commit comments

Comments (0)