Commit 8d1b068

[TRTLLM-8477][chore] Replace KvCacheConfigCpp with KvCacheConfig inside PyExecutor (#8259)
Signed-off-by: leslie-fang25 <leslief@nvidia.com>
1 parent 1a90449 commit 8d1b068
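
The change is mechanical: annotations and helpers inside PyExecutor that previously took the pybind-generated config now take the Python-level one, and conversion to the binding type happens only where C++ is actually invoked. As a quick orientation, a sketch of the two types involved, using only the import paths that appear in the diffs below:

    import tensorrt_llm.bindings
    from tensorrt_llm.llmapi.llm_args import KvCacheConfig  # Python-level config, now used inside PyExecutor

    # The module-level alias removed by this commit pointed at the pybind type:
    KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig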

File tree

6 files changed: +23 -44 lines


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
-from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig,
+from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
                                           SamplerType, SpeculativeConfig,
                                           TorchLlmArgs)
@@ -58,7 +58,7 @@ def __init__(
         tokens_per_block: int,
         max_seq_len: int,
         max_batch_size: int,
-        kv_cache_config: trtllm.KvCacheConfig,
+        kv_cache_config: KvCacheConfig,
         pytorch_backend_config: PyTorchConfig,
         speculative_config: SpeculativeConfig,
     ):
@@ -790,7 +790,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         max_seq_len: int, mm_encoder_only: bool,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
-                        kv_cache_config: trtllm.KvCacheConfig):
+                        kv_cache_config: KvCacheConfig):
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,

tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py

Lines changed: 3 additions & 3 deletions
@@ -19,9 +19,9 @@
 
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.resource_manager import (
-    BaseResourceManager, CacheTypeCpp, DataType, KvCacheConfigCpp,
-    KVCacheManager, get_pp_layers)
+    BaseResourceManager, CacheTypeCpp, DataType, KVCacheManager, get_pp_layers)
 from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig
 from tensorrt_llm.mapping import Mapping
 
 
@@ -180,7 +180,7 @@ def __init__(
         mamba_ssm_cache_dtype: torch.dtype,
 
         # kv cache parameters
-        kv_cache_config: KvCacheConfigCpp,
+        kv_cache_config: KvCacheConfig,
         kv_cache_type: CacheTypeCpp,
         *,
         num_layers: int,

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 2 deletions
@@ -223,7 +223,7 @@ def create_py_executor(
                                           llm_args.peft_cache_config)
 
     assert llm_args.kv_cache_config, "Expect llm_args.kv_cache_config is not None"
-    kv_cache_config = PybindMirror.maybe_to_pybind(llm_args.kv_cache_config)
+    kv_cache_config = llm_args.kv_cache_config
     if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
         # Disable KV cache reuse for deterministic mode
         kv_cache_config.enable_block_reuse = False
@@ -251,7 +251,7 @@ def create_py_executor(
     if max_num_tokens is None:
         max_num_tokens = 8192
 
-    tokens_per_block = llm_args.kv_cache_config.tokens_per_block
+    tokens_per_block = kv_cache_config.tokens_per_block
 
     if pytorch_backend_config.attn_backend in [
             "FLASHINFER", "FLASHINFER_STAR_ATTENTION"

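The net effect of the two hunks above is that create_py_executor keeps llm_args.kv_cache_config as the Python object instead of converting it up front. A minimal standalone sketch of that usage; the field names come from the hunks, the values are illustrative:

    from tensorrt_llm.llmapi.llm_args import KvCacheConfig

    # The executor creator now holds the Python-level config directly ...
    kv_cache_config = KvCacheConfig(enable_block_reuse=True)

    # ... and reads or mutates its fields without a pybind round-trip.
    kv_cache_config.enable_block_reuse = False   # e.g. the FORCE_DETERMINISTIC path
    tokens_per_block = kv_cache_config.tokens_per_block
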
tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 8 additions & 30 deletions
@@ -11,6 +11,7 @@
 import tensorrt_llm.bindings
 from tensorrt_llm._utils import mpi_disabled
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
 from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
@@ -31,7 +32,6 @@
 
 BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager
 KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
-KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
 CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
 ModelConfigCpp = tensorrt_llm.bindings.ModelConfig
 DataType = tensorrt_llm.bindings.DataType
@@ -145,7 +145,7 @@ class KVCacheManager(BaseResourceManager):
 
     def __init__(
         self,
-        kv_cache_config: KvCacheConfigCpp,
+        kv_cache_config: KvCacheConfig,
         kv_cache_type: CacheTypeCpp,
         *,
         num_layers: int,
@@ -268,8 +268,8 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
             )
             # kv cache config check
             assert isinstance(
-                kv_cache_config, KvCacheConfigCpp
-            ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfigCpp"
+                kv_cache_config, KvCacheConfig
+            ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
             blocks_per_window = self.calculate_max_num_blocks_from_cpp(
                 kv_cache_config=kv_cache_config,
                 model_config=model_config,
@@ -370,28 +370,6 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
     def shutdown(self):
         self.impl.release_pools()
 
-    @classmethod
-    def from_model_config(cls,
-                          model_config: ModelConfigCpp,
-                          kv_cache_config: KvCacheConfigCpp,
-                          mapping: Mapping,
-                          kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF,
-                          dtype: DataType = DataType.HALF) -> "KVCacheManager":
-        return cls(
-            kv_cache_config,
-            kv_cache_type,
-            num_layers=model_config.num_attention_layers(mapping.pp_size),
-            # NOTE: this preserves existing behavior in KV cache manager.
-            # But we should change this to pass a list at some point.
-            # We're assuming the KV cache is homogeneous here.
-            num_kv_heads=model_config.num_kv_heads(0),
-            head_dim=model_config.size_per_head,
-            tokens_per_block=model_config.tokens_per_block,
-            max_seq_len=model_config.max_seq_len,
-            max_batch_size=model_config.max_batch_size,
-            mapping=mapping,
-            dtype=dtype)
-
     def get_max_resource_count(self) -> int:
         return self.impl.max_num_blocks
 
@@ -566,7 +544,7 @@ def calculate_scaling_factor_size_bytes(
                                                 scaling_factor_dtype)
 
     def calculate_max_num_blocks(self,
-                                 kv_cache_config: KvCacheConfigCpp,
+                                 kv_cache_config: KvCacheConfig,
                                  head_dim: int,
                                  tokens_per_block: int,
                                  mapping: Mapping,
@@ -772,7 +750,7 @@ def _get_window_size_to_layers(self) -> dict[int, list[int]]:
     def adjust_window_sizes_for_vswa(
             window_size_to_layers: Dict[int, List[int]],
             max_attention_window_vec: List[int],
-            kv_cache_config: KvCacheConfigCpp,
+            kv_cache_config: KvCacheConfig,
             model_config: ModelConfigCpp,
             pool_memory_bytes: int,
             kv_factor: int,
@@ -887,7 +865,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int:
 
     def calculate_max_num_blocks_from_cpp(
             self,
-            kv_cache_config: KvCacheConfigCpp,
+            kv_cache_config: KvCacheConfig,
             model_config: ModelConfigCpp,
             extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
         """
@@ -945,7 +923,7 @@ def calculate_max_num_blocks_from_cpp(
         self.max_attention_window_vec = max_attention_window_vec
 
         blocks_per_window = KVCacheManagerCpp.calculate_max_num_blocks(
-            config=kv_cache_config,
+            config=PybindMirror.maybe_to_pybind(kv_cache_config),
             # TODO: support cross attention
             is_cross_attention=is_cross_attention,
             dtype=self.dtype,
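
The final hunk is the one place where the Python config still meets the C++ binding: KVCacheManagerCpp.calculate_max_num_blocks keeps receiving the pybind type, produced on the fly. A minimal sketch of that boundary pattern, shown standalone rather than inside the class:

    from tensorrt_llm.llmapi.llm_args import KvCacheConfig, PybindMirror

    # The Python object is what KVCacheManager carries around for checks and sizing ...
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)

    # ... and only the C++ call site converts it, mirroring
    # config=PybindMirror.maybe_to_pybind(kv_cache_config) above.
    kv_cache_config_cpp = PybindMirror.maybe_to_pybind(kv_cache_config)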

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 2 additions & 1 deletion
@@ -17,14 +17,15 @@
 from tensorrt_llm.bindings import (CudaStream, DataType, ModelConfig,
                                    WorldConfig, make_sampling_config)
 from tensorrt_llm.bindings.executor import (DecodingConfig, DecodingMode,
-                                            FinishReason, KvCacheConfig)
+                                            FinishReason)
 from tensorrt_llm.bindings.internal.algorithms import CreateNewDecoderRequests
 from tensorrt_llm.bindings.internal.batch_manager import (
     DecoderInputBuffers, add_new_tokens_to_requests, make_decoding_batch_input)
 from tensorrt_llm.bindings.internal.runtime import (BufferManager, CudaEvent,
                                                     DecoderState,
                                                     GptDecoderBatched)
 from tensorrt_llm.executor.result import Logprob
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams
 
tests/unittest/_torch/executor/test_resource_manager.py

Lines changed: 5 additions & 5 deletions
@@ -20,6 +20,7 @@
 from tensorrt_llm.bindings import executor as tllm
 from tensorrt_llm.bindings.internal.batch_manager import \
     PeftTaskNotCachedException
+from tensorrt_llm.llmapi.llm_args import KvCacheConfig
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
 
@@ -574,11 +575,11 @@ def _create_model_config_for_kv_cache_manager() -> ModelConfigCpp:
 
     @staticmethod
     def _create_kv_cache_config_for_kv_cache_manager(
-            params: dict) -> tllm.KvCacheConfig:
+            params: dict) -> KvCacheConfig:
         """
         Create a KV cache config for KVCacheManager test.
         """
-        return tllm.KvCacheConfig(**params)
+        return KvCacheConfig(**params)
 
     def test_calculate_max_num_blocks_from_cpp(self):
         # Construct a minimal mapping (single-rank, no TP/PP)
@@ -633,9 +634,8 @@ class MemTestCase(NamedTuple):
                 "free_gpu_memory_fraction": free_gpu_memory_fraction,
                 "enable_block_reuse": enable_block_reuse,
             },
-            # NOTE: use np.float32 to avoid float precision issue between python(double in most cases) and cpp binding(float)
-            expected_memory_bytes=(int(
-                fixed_free_mem * np.float32(free_gpu_memory_fraction)), 0),
+            expected_memory_bytes=(int(fixed_free_mem *
+                                       free_gpu_memory_fraction), 0),
         ),
     ]
 
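
The removed NOTE existed because the old expectation had to mimic the binding's 32-bit float handling of free_gpu_memory_fraction; the new expectation is plain double arithmetic. A small illustration of the rounding gap that comment referred to, with a made-up free-memory figure:

    import numpy as np

    free_gpu_memory_fraction = 0.9
    fixed_free_mem = 8 * 1024**3   # hypothetical 8 GiB of free GPU memory

    # Old-style expectation: push the fraction through float32 first.
    int(fixed_free_mem * float(np.float32(free_gpu_memory_fraction)))   # 7730940928
    # New-style expectation: ordinary Python doubles.
    int(fixed_free_mem * free_gpu_memory_fraction)                      # 7730941132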
