
Commit ab1c547

fix: adjust window sizes of VSWA at torch backend (NVIDIA#5880)
Signed-off-by: Jaedeok Kim <[email protected]>
Parent commit: 9e871ca

File tree: 3 files changed (+251, -13 lines)

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 121 additions & 11 deletions
@@ -1,8 +1,9 @@
+import copy
 import enum
 import math
 from abc import ABC, abstractmethod
 from collections import OrderedDict, defaultdict
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -11,7 +12,7 @@
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 from tensorrt_llm.sampling_params import SamplingParams
 
-from ..._utils import nvtx_range
+from ..._utils import binding_dtype_size, nvtx_range
 from ...logger import logger
 from ...mapping import Mapping
 from .llm_request import LlmRequest, LlmRequestState, SamplingConfig
@@ -437,14 +438,10 @@ def calculate_max_num_blocks(self,
         cache_size_per_token = kv_factor * sum(
             self.num_kv_heads_per_layer) * head_dim
 
-        if dtype == DataType.FP8:
-            kv_cache_dtype_bytes = 1
-        elif dtype in (DataType.HALF, DataType.BF16):
-            kv_cache_dtype_bytes = 2
-        elif dtype == DataType.FLOAT:
-            kv_cache_dtype_bytes = 4
-        else:
+        if dtype not in (DataType.FP8, DataType.HALF, DataType.BF16,
+                         DataType.FLOAT):
             raise ValueError(f'Cannot support {dtype} KV cache.')
+        kv_cache_dtype_bytes = binding_dtype_size(dtype)
 
         cache_size_bytes_per_token = cache_size_per_token * kv_cache_dtype_bytes
         free_mem, total_mem = torch.cuda.mem_get_info()
@@ -603,6 +600,102 @@ def _get_window_size_to_layers(self) -> dict[int, list[int]]:
             window_size_to_layers_map[window_size].append(local_layer_idx)
         return window_size_to_layers_map
 
+    @staticmethod
+    def adjust_window_sizes_for_vswa(
+        window_size_to_layers: Dict[int, List[int]],
+        kv_cache_config: KvCacheConfigCpp,
+        model_config: ModelConfig,
+        pool_memory_bytes: int,
+        kv_factor: int,
+        dtype: DataType,
+        is_cross_attention: bool = False,
+    ) -> Dict[int, List[int]]:
+
+        assert is_cross_attention is False, 'Cross attention is not supported'
+
+        max_tokens_from_config = kv_cache_config.max_tokens
+
+        def calculate_cache_size_per_token(layers: Set[int]) -> int:
+            # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
+            total_kv_heads = sum(model_config.num_kv_heads_per_layer[i]
+                                 for i in layers)
+            return total_kv_heads * kv_factor * model_config.head_size
+
+        # Calculate the required memory bytes per sequence.
+        required_mem_bytes_per_seq = 0
+        for window_size in sorted(window_size_to_layers):
+            layers = window_size_to_layers[window_size]
+            cache_size_per_token = calculate_cache_size_per_token(layers)
+            cache_size_bytes_per_token = cache_size_per_token * binding_dtype_size(
+                dtype)
+            required_mem_bytes_per_seq += window_size * cache_size_bytes_per_token
+        logger.debug(
+            f'Required memory per sequence: {required_mem_bytes_per_seq} bytes')
+
+        if required_mem_bytes_per_seq < pool_memory_bytes:
+            # No need to adjust the window sizes.
+            return copy.deepcopy(window_size_to_layers)
+
+        logger.debug(
+            f'Adjusting the window sizes {list(window_size_to_layers)} to fit '
+            f'the memory {pool_memory_bytes} bytes.')
+        adjusted_window_size_to_layers = {}
+
+        remaining_mem_bytes = pool_memory_bytes
+        remaining_layers = set(i for layers in window_size_to_layers.values()
+                               for i in layers)
+
+        accum_max_tokens = 0
+        prev_window_size = 0
+
+        for window_size in sorted(window_size_to_layers):
+            layers = window_size_to_layers[window_size]
+            if remaining_mem_bytes > 0 and remaining_layers:
+                # Calculate cache size per token for remaining layers only
+                cache_size_per_token = calculate_cache_size_per_token(
+                    remaining_layers)
+                cache_size_bytes_per_token = cache_size_per_token * binding_dtype_size(
+                    dtype)
+                logger.debug(
+                    f'Cache size per token for {len(remaining_layers)} layers: '
+                    f'{cache_size_bytes_per_token} bytes')
+                # Calculate max tokens that can fit in this window with remaining memory.
+                max_tokens_in_window = min(
+                    remaining_mem_bytes // cache_size_bytes_per_token,
+                    window_size - prev_window_size)
+                remaining_mem_bytes -= max_tokens_in_window * cache_size_bytes_per_token
+                accum_max_tokens += max_tokens_in_window
+                logger.debug(f'Remaining memory: {remaining_mem_bytes} bytes')
+                logger.debug(
+                    f'Max token of window {window_size}: {accum_max_tokens}')
+
+                if accum_max_tokens < window_size:
+                    logger.debug(
+                        f'Max tokens ({accum_max_tokens}) cannot fill the current window ({window_size}). '
+                        f'The larger windows will have the same max tokens.')
+                    remaining_mem_bytes = 0
+
+                # Clamp the sequence length if provided explicitly.
+                if max_tokens_from_config is not None:
+                    accum_max_tokens = min(max_tokens_from_config,
+                                           accum_max_tokens)
+                    # If max tokens from config is reached, stop allocating
+                    # more memory. Since the maximum number of tokens is
+                    # already reached, for the remaining windows maxTokens
+                    # will be set by the current value of accumMaxTokens.
+                    if accum_max_tokens == max_tokens_from_config:
+                        remaining_mem_bytes = 0
+
+            if accum_max_tokens not in adjusted_window_size_to_layers:
+                adjusted_window_size_to_layers[accum_max_tokens] = layers.copy()
+            else:
+                adjusted_window_size_to_layers[accum_max_tokens].extend(layers)
+
+            remaining_layers -= set(layers)
+            prev_window_size = window_size
+
+        return adjusted_window_size_to_layers
+
     def calculate_max_num_blocks_from_cpp(
         self,
         kv_cache_config: KvCacheConfigCpp,
@@ -622,6 +715,9 @@ def calculate_max_num_blocks_from_cpp(
             A dict of (max_attention_window, (blocks_in_primary_pool, blocks_in_secondary_pool)).
         """
 
+        # Cross attention is not yet supported by VSWA on the Torch backend.
+        is_cross_attention = False
+
         # Construct WorldConfig from self.mapping
         world_config_cpp = WorldConfig(
             tensor_parallelism=self.mapping.tp_size,
@@ -636,12 +732,26 @@
         primary_pool_memory_bytes = free_mem
         secondary_pool_memory_bytes = 0
         logger.debug(
-            f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \nsecondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB"
+            f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes/1024**3}GB, \n"
+            f"secondary_pool_memory_bytes is set to {secondary_pool_memory_bytes/1024**3}GB"
+        )
+
+        # Adjust the window sizes if even a single sequence cannot fit into
+        # the available pool memory.
+        window_size_to_layers = self.adjust_window_sizes_for_vswa(
+            window_size_to_layers=window_size_to_layers,
+            model_config=model_config,
+            kv_cache_config=kv_cache_config,
+            pool_memory_bytes=primary_pool_memory_bytes,
+            kv_factor=self.kv_factor,
+            dtype=self.dtype,
+            is_cross_attention=is_cross_attention,
         )
 
         blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks(
             config=kv_cache_config,
-            is_cross_attention=False,  #TODO: support cross attention
+            # TODO: support cross attention
+            is_cross_attention=is_cross_attention,
             dtype=self.dtype,
             model_config=model_config,
             world_config=world_config_cpp,
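
Note on the new helper: the snippet below is a standalone, dependency-free sketch that mirrors the greedy clamping performed by adjust_window_sizes_for_vswa, so the behavior can be traced without loading TensorRT-LLM. It assumes a uniform KV-cache cost per layer and omits the per-sequence early return and the max_tokens clamp; the example numbers come from Case 1 of the unit test added below, and the function name clamp_windows is hypothetical.

# Standalone sketch (illustrative only): greedy clamping of VSWA window sizes
# under a fixed pool budget, with a uniform per-layer KV cache cost.
from typing import Dict, List


def clamp_windows(window_size_to_layers: Dict[int, List[int]],
                  bytes_per_token_per_layer: int,
                  pool_memory_bytes: int) -> Dict[int, List[int]]:
    remaining_mem = pool_memory_bytes
    remaining_layers = {i for ls in window_size_to_layers.values() for i in ls}
    adjusted: Dict[int, List[int]] = {}
    accum_tokens = 0
    prev_window = 0
    for window in sorted(window_size_to_layers):
        layers = window_size_to_layers[window]
        if remaining_mem > 0 and remaining_layers:
            # Byte cost of one token across the layers that still need cache.
            per_token = bytes_per_token_per_layer * len(remaining_layers)
            tokens = min(remaining_mem // per_token, window - prev_window)
            remaining_mem -= tokens * per_token
            accum_tokens += tokens
            if accum_tokens < window:
                remaining_mem = 0  # larger windows inherit the same cap
        adjusted.setdefault(accum_tokens, []).extend(layers)
        remaining_layers -= set(layers)
        prev_window = window
    return adjusted


# Mirrors Case 1 of the new unit test:
# 8 * (100 * 9 + 30 * 5) + 4 bytes -> {100: [0, 1, 2, 3], 130: [4, 5, 6, 7, 8]}
print(clamp_windows({100: [0, 1, 2, 3], 200: [4, 5, 6], 7000: [7, 8]},
                    bytes_per_token_per_layer=8,
                    pool_memory_bytes=8 * (100 * 9 + 30 * 5) + 4))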

tensorrt_llm/_utils.py

Lines changed: 16 additions & 0 deletions
@@ -180,6 +180,22 @@ def str_dtype_to_torch(dtype):
     fp8=DataType.FP8,
 )
 
+_binding_dtype_size = {
+    DataType.INT64: 8,
+    DataType.FLOAT: 4,
+    DataType.INT32: 4,
+    DataType.BF16: 2,
+    DataType.HALF: 2,
+    DataType.BOOL: 1,
+    DataType.FP8: 1,
+    DataType.INT8: 1,
+    DataType.UINT8: 1,
+}
+
+
+def binding_dtype_size(dtype: DataType):
+    return _binding_dtype_size[dtype]
+
 
 def str_dtype_to_binding(dtype):
     ret = _str_to_binding_dtype_dict.get(dtype)
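
With this helper in place, call sites can replace hand-rolled dtype-to-byte-size branches (as calculate_max_num_blocks does in the hunk above) with a single table lookup. A small illustrative snippet, using hypothetical head counts and head size, valid once this commit is applied:

from tensorrt_llm._utils import binding_dtype_size
from tensorrt_llm.bindings import DataType

# Hypothetical KV cache geometry, for illustration only.
kv_factor = 2        # keys and values
total_kv_heads = 8   # sum of KV heads over all layers
head_dim = 128

cache_size_per_token = kv_factor * total_kv_heads * head_dim  # 2048 elements

dtype = DataType.HALF
if dtype not in (DataType.FP8, DataType.HALF, DataType.BF16, DataType.FLOAT):
    raise ValueError(f'Cannot support {dtype} KV cache.')

# 2048 elements * 2 bytes (HALF) = 4096 bytes of KV cache per token.
cache_size_bytes_per_token = cache_size_per_token * binding_dtype_size(dtype)
print(cache_size_bytes_per_token)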

tests/unittest/_torch/test_resource_manager.py

Lines changed: 114 additions & 2 deletions
@@ -10,13 +10,15 @@
 
 import tensorrt_llm
 import tensorrt_llm.bindings
-from tensorrt_llm._torch.pyexecutor.resource_manager import (PeftCacheConfig,
+from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
+                                                              PeftCacheConfig,
                                                               PeftCacheManager)
 from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
 
+DataType = tensorrt_llm.bindings.DataType
 LoraModule = tensorrt_llm.bindings.LoraModule
 LoraModuleType = tensorrt_llm.bindings.LoraModuleType
 current_dir = pathlib.Path(__file__).parent.resolve()
@@ -66,7 +68,15 @@ def __init__(self):
             self.num_rnn_layers = 0
             self.num_attention_heads = 1
             self.hidden_size = 16
-            self.data_type = tensorrt_llm.bindings.DataType.HALF
+            self.data_type = DataType.HALF
+
+        @property
+        def num_kv_heads_per_layer(self):
+            return [self.num_attention_heads] * self.num_attention_layers
+
+        @property
+        def head_size(self):
+            return self.hidden_size // self.num_attention_heads
 
     class MockPeftCacheManagerConfig:
         """
@@ -416,3 +426,105 @@ def test_put_get(self):
             self.assertEqual(entry.layer_id, expected_values[i][5])
             self.assertEqual(entry.adapter_size, expected_values[i][6])
             self.assertEqual(entry.num_slots, expected_values[i][7])
+
+    def test_adjust_window_sizes_for_vswa(self):
+        window_size_to_layers = {
+            100: [0, 1, 2, 3],
+            200: [4, 5, 6],
+            7000: [7, 8],
+        }
+
+        model_config = self.MockModelConfig()
+        model_config.num_attention_heads = 2
+        model_config.hidden_size = 2
+        model_config.data_type = DataType.HALF
+
+        total_layers = [
+            i for layers in window_size_to_layers.values() for i in layers
+        ]
+
+        model_config.num_hidden_layers = len(total_layers)
+        model_config.num_attention_layers = len(total_layers)
+
+        kv_factor = 2
+        cache_bytes_per_token_per_layer = 8
+
+        # Define test cases:
+        # (memory_bytes, expected_window_sizes, max_tokens, description)
+        # If max_tokens is None, the default value of KvCacheConfig is used.
+        test_cases = [
+            (
+                # Case 1: Limited memory - windows get clamped.
+                cache_bytes_per_token_per_layer * (100 * 9 + 30 * 5) + 4,
+                {
+                    100: [0, 1, 2, 3],
+                    130: [4, 5, 6, 7, 8],
+                },
+                None,
+                "limited_memory_clamped_windows"),
+            (
+                # Case 2: Less limited memory - the largest window gets clamped.
+                cache_bytes_per_token_per_layer *
+                (100 * 9 + 100 * 5 + 817 * 2) + 4,
+                {
+                    100: [0, 1, 2, 3],
+                    200: [4, 5, 6],
+                    1017: [7, 8],
+                },
+                None,
+                "less_limited_memory_clamped_windows"),
+            (
+                # Case 3: Sufficient memory - no clamping needed.
+                cache_bytes_per_token_per_layer *
+                (100 * 4 + 200 * 3 + 7000 * 2) + 9402,
+                {
+                    100: [0, 1, 2, 3],
+                    200: [4, 5, 6],
+                    7000: [7, 8],
+                },
+                None,
+                "sufficient_memory_no_clamping"),
+            (
+                # Case 4: Very limited memory - all windows get small values.
+                cache_bytes_per_token_per_layer * (51 * 9) + 1,
+                {
+                    51: [0, 1, 2, 3, 4, 5, 6, 7, 8],
+                },
+                None,
+                "very_limited_memory_all_clamped"),
+            (
+                # Case 5: Less limited memory, but max_tokens is given.
+                # Memory is enough for 1017 tokens, but it is clamped by max_tokens=134.
+                cache_bytes_per_token_per_layer *
+                (100 * 9 + 100 * 5 + 817 * 2) + 4,
+                {
+                    100: [0, 1, 2, 3],
+                    134: [4, 5, 6, 7, 8],
+                },
+                134,
+                "less_limited_memory_but_clamped_by_max_tokens"),
+        ]
+
+        for memory_bytes, expected_window_sizes, max_tokens, description in test_cases:
+            with self.subTest(case=description, memory_bytes=memory_bytes):
+                kv_cache_config = tllm.KvCacheConfig(max_tokens=max_tokens)
+                adjusted = KVCacheManager.adjust_window_sizes_for_vswa(
+                    window_size_to_layers=window_size_to_layers,
+                    model_config=model_config,
+                    kv_cache_config=kv_cache_config,
+                    pool_memory_bytes=memory_bytes,
+                    kv_factor=kv_factor,
+                    dtype=model_config.data_type,
+                    is_cross_attention=False,
+                )
+
+                self.assertEqual(
+                    adjusted, expected_window_sizes,
+                    f"Test case '{description}' failed.\n"
+                    f"Memory bytes: {memory_bytes}\n"
+                    f"Actual: {adjusted}\n"
+                    f"Expected: {expected_window_sizes}")
+
+
+if __name__ == "__main__":
+    unittest.main()
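
For reference, here is the arithmetic behind the expected values in Cases 2 and 5, which share the same memory budget. The mock config costs 8 bytes of KV cache per token per layer (2 KV heads * kv_factor 2 * head_size 1 * 2 bytes for HALF); the short check below reuses only the numbers that appear in the test.

budget = 8 * (100 * 9 + 100 * 5 + 817 * 2) + 4  # = 24276 bytes
after_w100 = budget - 100 * 8 * 9               # 100 tokens, all 9 layers  -> 17076
after_w200 = after_w100 - 100 * 8 * 5           # 100 more tokens, 5 layers -> 13076
extra = after_w200 // (8 * 2)                   # 817 extra tokens for the last 2 layers
print(200 + extra)                              # 1017, the largest window in Case 2
# In Case 5 the accumulated token count is clamped to max_tokens=134 at the
# second window, so the remaining layers are capped at 134 as well.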
