
Commit cc81028

[TRTLLM-8812][chore] Limit the scope of pybind based CacheTransceiverConfig (NVIDIA#8558)

Signed-off-by: junq <[email protected]>
1 parent ee21ea3 commit cc81028

File tree: 4 files changed, +19 / -25 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -11,7 +11,8 @@
     MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
 from tensorrt_llm.bindings.executor import DecodingMode
-from tensorrt_llm.llmapi.llm_args import (EagleDecodingConfig, KvCacheConfig,
+from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
+                                          EagleDecodingConfig, KvCacheConfig,
                                           MTPDecodingConfig, PeftCacheConfig,
                                           SamplerType, SchedulerConfig,
                                           SparseAttentionConfig,
@@ -666,7 +667,7 @@ def create_py_executor_instance(
     max_num_tokens: Optional[int] = None,
     peft_cache_config: Optional[PeftCacheConfig] = None,
     scheduler_config: Optional[SchedulerConfig] = None,
-    cache_transceiver_config: Optional[trtllm.CacheTransceiverConfig] = None,
+    cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
 ) -> PyExecutor:
     kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

```
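
For context, a minimal sketch of what callers now hand to `create_py_executor_instance`: the llmapi-level `CacheTransceiverConfig` is constructed directly, with a plain-string backend. The keyword arguments mirror the updated unit test further down; the call shape here is illustrative, not taken verbatim from this commit.

```python
from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig

# The llmapi-level config replaces trtllm.CacheTransceiverConfig in the
# signature above; backend is a plain string rather than a pybind enum.
cache_transceiver_config = CacheTransceiverConfig(backend="NIXL",
                                                  max_tokens_in_buffer=512)
```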

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -5,7 +5,7 @@
 from tensorrt_llm import logger
 from tensorrt_llm._torch.distributed.communicator import Distributed
 from tensorrt_llm.bindings import WorldConfig
-from tensorrt_llm.bindings.executor import CacheTransceiverConfig
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping

 from .llm_request import LlmRequest
@@ -36,13 +36,13 @@ def create_kv_cache_transceiver(
         logger.info("cache_transceiver is disabled")
         return None

-    if cache_transceiver_config.backend == BackendTypeCpp.DEFAULT:
+    if cache_transceiver_config.backend == "DEFAULT":
         # When cache_transceiver_config.backend is not set, fallback to env_vars settings
         # NIXL is the default backend
-        cache_transceiver_config.backend = BackendTypeCpp.NIXL
+        cache_transceiver_config.backend = "NIXL"
         # Ordered by priority
-        env_vars = [("TRTLLM_USE_UCX_KVCACHE", BackendTypeCpp.UCX),
-                    ("TRTLLM_USE_MPI_KVCACHE", BackendTypeCpp.MPI)]
+        env_vars = [("TRTLLM_USE_UCX_KVCACHE", "UCX"),
+                    ("TRTLLM_USE_MPI_KVCACHE", "MPI")]
         for env_var, be_type in env_vars:
             if getenv(env_var) == "1":
                 logger.warning(
@@ -51,10 +51,10 @@ def create_kv_cache_transceiver(
                 cache_transceiver_config.backend = be_type
                 break

-    if cache_transceiver_config.backend == BackendTypeCpp.MPI:
+    if cache_transceiver_config.backend == "MPI":
         logger.warning(
             "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended")
-    elif cache_transceiver_config.backend == BackendTypeCpp.UCX:
+    elif cache_transceiver_config.backend == "UCX":
         logger.info(
             f"Using UCX kv-cache transceiver. If your devices are not in the same domain, please consider setting "
             f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server "
@@ -116,7 +116,7 @@ def __init__(self, mapping: Mapping, dist: Distributed,
                                      tokens_per_block, world_config,
                                      pp_layer_num_per_pp_rank, dtype,
                                      attention_type,
-                                     cache_transceiver_config)
+                                     cache_transceiver_config._to_pybind())

     def respond_and_send_async(self, req: LlmRequest):
         return self.impl.respond_and_send_async(req)
```
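
The backend check now compares plain strings instead of `BackendTypeCpp` enum members, and the pybind object is only materialized at the C++ boundary via `_to_pybind()`. Below is a standalone sketch of the resolution logic, with the fallback order copied from the hunk above; the function name is ours, the env var names are from the diff.

```python
from os import getenv

def resolve_backend(backend: str) -> str:
    """Illustrative restatement of the fallback above, on plain strings."""
    if backend == "DEFAULT":
        # NIXL is the default when no backend is configured explicitly...
        backend = "NIXL"
        # ...but the legacy env vars still take effect, in priority order
        # (UCX is checked before MPI, so UCX wins if both are set).
        for env_var, be_type in [("TRTLLM_USE_UCX_KVCACHE", "UCX"),
                                 ("TRTLLM_USE_MPI_KVCACHE", "MPI")]:
            if getenv(env_var) == "1":
                backend = be_type
                break
    return backend
```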

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 5 deletions

```diff
@@ -17,7 +17,7 @@
 from tensorrt_llm.bindings.executor import GuidedDecodingConfig
 from tensorrt_llm.llmapi.llm_args import (CapacitySchedulerPolicy,
                                           ContextChunkingPolicy, LoadFormat,
-                                          PybindMirror, TorchLlmArgs)
+                                          TorchLlmArgs)
 from tensorrt_llm.llmapi.tokenizer import (TokenizerBase,
                                            _llguidance_tokenizer_info,
                                            _xgrammar_tokenizer_info)
@@ -289,10 +289,7 @@ def create_py_executor(
     else:
         dist = MPIDist(mapping=mapping)

-    cache_transceiver_config = None
-    if llm_args.cache_transceiver_config is not None:
-        cache_transceiver_config = PybindMirror.maybe_to_pybind(
-            llm_args.cache_transceiver_config)
+    cache_transceiver_config = llm_args.cache_transceiver_config

     has_draft_model_engine = False
     has_spec_drafter = False
```
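
With this hunk, `create_py_executor` no longer calls `PybindMirror.maybe_to_pybind`; the Python-level config flows through unchanged. A self-contained sketch of the pattern the commit applies: the class below is illustrative, not the TensorRT-LLM implementation.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class IllustrativeConfig:
    """Stands in for the llmapi CacheTransceiverConfig in this sketch."""
    backend: str = "DEFAULT"
    max_tokens_in_buffer: Optional[int] = None

    def _to_pybind(self):
        # The real method would build a
        # tensorrt_llm.bindings.executor.CacheTransceiverConfig; a dict
        # keeps this sketch runnable without the bindings installed.
        return {"backend": self.backend,
                "max_tokens_in_buffer": self.max_tokens_in_buffer}

# Python-level code passes the config around as-is; only the C++
# transceiver constructor receives the converted object.
cfg = IllustrativeConfig(backend="NIXL", max_tokens_in_buffer=512)
binding_args = cfg._to_pybind()
```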

tests/unittest/others/test_kv_cache_transceiver.py

Lines changed: 6 additions & 10 deletions

```diff
@@ -12,6 +12,7 @@
 from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
                                                         LlmRequestState)
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm.llmapi.llm_args import CacheTransceiverConfig
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams

@@ -67,11 +68,7 @@ def ctx_gen_kv_cache_dtype(request):
 @pytest.mark.parametrize("attention_type",
                          [AttentionTypeCpp.DEFAULT, AttentionTypeCpp.MLA],
                          ids=["mha", "mla"])
-@pytest.mark.parametrize("backend", [
-    trtllm.CacheTransceiverBackendType.NIXL,
-    trtllm.CacheTransceiverBackendType.UCX
-],
-                         ids=["NIXL", "UCX"])
+@pytest.mark.parametrize("backend", ["NIXL", "UCX"], ids=["NIXL", "UCX"])
 def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
                                              attention_type, backend):
     # Init kv_cache manager and cache transceiver
@@ -80,8 +77,8 @@ def test_kv_cache_transceiver_single_process(ctx_gen_kv_cache_dtype,
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)

-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=backend, max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend=backend,
+                                                      max_tokens_in_buffer=512)
     dist = MPIDist(mapping=mapping)
     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
@@ -147,9 +144,8 @@ def test_cancel_request_in_transmission(attention_type):
     kv_cache_manager_ctx = create_kv_cache_manager(mapping, ctx_kv_cache_dtype)
     kv_cache_manager_gen = create_kv_cache_manager(mapping, gen_kv_cache_dtype)

-    cache_transceiver_config = trtllm.CacheTransceiverConfig(
-        backend=trtllm.CacheTransceiverBackendType.DEFAULT,
-        max_tokens_in_buffer=512)
+    cache_transceiver_config = CacheTransceiverConfig(backend="DEFAULT",
+                                                      max_tokens_in_buffer=512)

     kv_cache_transceiver_ctx = create_kv_cache_transceiver(
         mapping, dist, kv_cache_manager_ctx, attention_type,
```
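
The test parametrization shrinks in step: plain strings replace `trtllm.CacheTransceiverBackendType` members, so the enum import is no longer needed. A minimal standalone pytest sketch of the same idea; the test name is ours.

```python
import pytest

@pytest.mark.parametrize("backend", ["NIXL", "UCX"], ids=["NIXL", "UCX"])
def test_backend_is_plain_string(backend):
    # With string-valued backends, parametrization needs no enum import.
    assert backend in ("NIXL", "UCX")
```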
