+import copy
 import enum
 import math
 from abc import ABC, abstractmethod
 from collections import OrderedDict, defaultdict
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -11,7 +12,7 @@
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 from tensorrt_llm.sampling_params import SamplingParams
 
-from ..._utils import nvtx_range
+from ..._utils import binding_dtype_size, nvtx_range
 from ...logger import logger
 from ...mapping import Mapping
 from .llm_request import LlmRequest, LlmRequestState, SamplingConfig
@@ -437,14 +438,10 @@ def calculate_max_num_blocks(self,
         cache_size_per_token = kv_factor * sum(
             self.num_kv_heads_per_layer) * head_dim
 
-        if dtype == DataType.FP8:
-            kv_cache_dtype_bytes = 1
-        elif dtype in (DataType.HALF, DataType.BF16):
-            kv_cache_dtype_bytes = 2
-        elif dtype == DataType.FLOAT:
-            kv_cache_dtype_bytes = 4
-        else:
+        if dtype not in (DataType.FP8, DataType.HALF, DataType.BF16,
+                         DataType.FLOAT):
             raise ValueError(f'Cannot support {dtype} KV cache.')
+        kv_cache_dtype_bytes = binding_dtype_size(dtype)
 
         cache_size_bytes_per_token = cache_size_per_token * kv_cache_dtype_bytes
         free_mem, total_mem = torch.cuda.mem_get_info()
@@ -603,6 +600,102 @@ def _get_window_size_to_layers(self) -> dict[int, list[int]]:
             window_size_to_layers_map[window_size].append(local_layer_idx)
         return window_size_to_layers_map
 
+    @staticmethod
+    def adjust_window_sizes_for_vswa(
+        window_size_to_layers: Dict[int, List[int]],
+        kv_cache_config: KvCacheConfigCpp,
+        model_config: ModelConfig,
+        pool_memory_bytes: int,
+        kv_factor: int,
+        dtype: DataType,
+        is_cross_attention: bool = False,
+    ) -> Dict[int, List[int]]:
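+        """Shrink the VSWA window sizes so one sequence fits in the KV pool.
+
+        Given a mapping from attention window size to the layers that use it,
+        reduce the window sizes (smallest window first) until the KV cache of
+        a single sequence fits in ``pool_memory_bytes``. If it already fits,
+        the input mapping is returned unchanged as a deep copy. When
+        ``kv_cache_config.max_tokens`` is set, it caps the adjusted sizes.
+        """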
+
+        assert is_cross_attention is False, 'Cross attention is not supported'
+
+        max_tokens_from_config = kv_cache_config.max_tokens
+
+        def calculate_cache_size_per_token(layers: Set[int]) -> int:
+            # Same as BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize
+            total_kv_heads = sum(model_config.num_kv_heads_per_layer[i]
+                                 for i in layers)
+            return total_kv_heads * kv_factor * model_config.head_size
+
+        # Calculate the required memory bytes per sequence.
+        required_mem_bytes_per_seq = 0
+        for window_size in sorted(window_size_to_layers):
+            layers = window_size_to_layers[window_size]
+            cache_size_per_token = calculate_cache_size_per_token(layers)
+            cache_size_bytes_per_token = cache_size_per_token * binding_dtype_size(
+                dtype)
+            required_mem_bytes_per_seq += window_size * cache_size_bytes_per_token
+        logger.debug(
+            f'Required memory per sequence: {required_mem_bytes_per_seq} bytes')
+
+        if required_mem_bytes_per_seq < pool_memory_bytes:
+            # No need to adjust the window sizes.
+            return copy.deepcopy(window_size_to_layers)
+
+        logger.debug(
+            f'Adjusting the window sizes {list(window_size_to_layers)} to fit '
+            f'the memory {pool_memory_bytes} bytes.')
+        adjusted_window_size_to_layers = {}
+
+        remaining_mem_bytes = pool_memory_bytes
+        remaining_layers = set(i for layers in window_size_to_layers.values()
+                               for i in layers)
+
+        accum_max_tokens = 0
+        prev_window_size = 0
+
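+        # Greedily grant each window (smallest first) up to
+        # (window_size - prev_window_size) additional tokens while pool memory
+        # remains, then group the layers by the accumulated token budget.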
+        for window_size in sorted(window_size_to_layers):
+            layers = window_size_to_layers[window_size]
+            if remaining_mem_bytes > 0 and remaining_layers:
+                # Calculate cache size per token for remaining layers only.
+                cache_size_per_token = calculate_cache_size_per_token(
+                    remaining_layers)
+                cache_size_bytes_per_token = cache_size_per_token * binding_dtype_size(
+                    dtype)
+                logger.debug(
+                    f'Cache size per token for {len(remaining_layers)} layers: '
+                    f'{cache_size_bytes_per_token} bytes')
+                # Calculate max tokens that can fit in this window with remaining memory.
+                max_tokens_in_window = min(
+                    remaining_mem_bytes // cache_size_bytes_per_token,
+                    window_size - prev_window_size)
+                remaining_mem_bytes -= max_tokens_in_window * cache_size_bytes_per_token
+                accum_max_tokens += max_tokens_in_window
+                logger.debug(f'Remaining memory: {remaining_mem_bytes} bytes')
+                logger.debug(
+                    f'Max token of window {window_size}: {accum_max_tokens}')
+
+                if accum_max_tokens < window_size:
+                    logger.debug(
+                        f'Max tokens ({accum_max_tokens}) cannot fill the current window ({window_size}). '
+                        f'The larger windows will have the same max tokens.')
+                    remaining_mem_bytes = 0
+
+                # Clamp the sequence length if provided explicitly.
+                if max_tokens_from_config is not None:
+                    accum_max_tokens = min(max_tokens_from_config,
+                                           accum_max_tokens)
+                    # If max tokens from config is reached, stop allocating
+                    # more memory. Since the maximum number of tokens is
+                    # already reached, for the remaining windows maxTokens
+                    # will be set by the current value of accumMaxTokens.
+                    if accum_max_tokens == max_tokens_from_config:
+                        remaining_mem_bytes = 0
+
+            if accum_max_tokens not in adjusted_window_size_to_layers:
+                adjusted_window_size_to_layers[accum_max_tokens] = layers.copy()
+            else:
+                adjusted_window_size_to_layers[accum_max_tokens].extend(layers)
+
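+            # Layers of this window are now assigned, so they stop contributing
+            # to the per-token cache cost of the remaining (larger) windows.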
+            remaining_layers -= set(layers)
+            prev_window_size = window_size
+
+        return adjusted_window_size_to_layers
+
     def calculate_max_num_blocks_from_cpp(
             self,
             kv_cache_config: KvCacheConfigCpp,
@@ -622,6 +715,9 @@ def calculate_max_num_blocks_from_cpp(
             A dict of (max_attention_window, (blocks_in_primary_pool, blocks_in_secondary_pool)).
         """
 
+        # VSWA on the Torch backend does not support cross attention yet.
+        is_cross_attention = False
+
         # Construct WorldConfig from self.mapping
         world_config_cpp = WorldConfig(
             tensor_parallelism=self.mapping.tp_size,
@@ -636,12 +732,26 @@ def calculate_max_num_blocks_from_cpp(
         primary_pool_memory_bytes = free_mem
         secondary_pool_memory_bytes = 0
         logger.debug(
-            f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes / 1024**3} GB, \nsecondary_pool_memory_bytes is set to {secondary_pool_memory_bytes / 1024**3} GB"
+            f"primary_pool_memory_bytes is set to {primary_pool_memory_bytes / 1024**3} GB, \n"
+            f"secondary_pool_memory_bytes is set to {secondary_pool_memory_bytes / 1024**3} GB"
+        )
+
+        # If even a single sequence cannot fit in the available memory, shrink
+        # the window sizes so that it does.
+        window_size_to_layers = self.adjust_window_sizes_for_vswa(
+            window_size_to_layers=window_size_to_layers,
+            model_config=model_config,
+            kv_cache_config=kv_cache_config,
+            pool_memory_bytes=primary_pool_memory_bytes,
+            kv_factor=self.kv_factor,
+            dtype=self.dtype,
+            is_cross_attention=is_cross_attention,
         )
 
         blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks(
             config=kv_cache_config,
-            is_cross_attention=False,  #TODO: support cross attention
+            # TODO: support cross attention
+            is_cross_attention=is_cross_attention,
            dtype=self.dtype,
             model_config=model_config,
             world_config=world_config_cpp,