
Commit a317fa2

feat: (cherry pick) Use ForwardPassCallback api from TRTLLM to register end of forward pass callback to enable cuda graphs #3297 (#4109)
Signed-off-by: Kyle McGill <[email protected]>
1 parent 1660b0c commit a317fa2

File tree

  lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs
  lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_worker.py
  tests/kvbm/engine_config_with_cuda_graph_and_kvbm.yaml
  tests/kvbm/engine_config_without_cuda_graph_and_kvbm.yaml
  tests/kvbm/test_cuda_graph.py
  tests/utils/managed_process.py

6 files changed, +326 -6 lines changed
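
In short, the commit moves KV-block offloading out of the per-layer save_kv_layer path and into a callback that TRTLLM invokes once the whole forward pass has finished, so no host-side synchronization has to run inside a region that may be captured as a CUDA graph. Below is a minimal sketch of that callback shape, using only the torch.cuda.Event calls that appear in the Python diff further down; the class and attribute names here are illustrative, not the actual API.

import torch


class EndOfForwardPassOffload:
    """Illustrative sketch of the end-of-forward-pass callback pattern added in this commit."""

    def __init__(self, connector):
        self.connector = connector        # anything exposing execute_offload_operations()
        self.event = torch.cuda.Event()   # one event per worker, recorded each forward pass

    def register_forward_pass_callable(self):
        def callback():
            # Record on the current stream once the forward pass is fully enqueued,
            # wait for it on the host, then drain the queued offload operations.
            self.event.record()
            self.event.synchronize()
            self.connector.execute_offload_operations()

        return callback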

lib/bindings/python/rust/llm/block_manager/vllm/connector/trtllm_worker.rs

Lines changed: 19 additions & 3 deletions
@@ -39,6 +39,8 @@ pub trait Worker: Send + Sync {
 
     fn start_load_kv(&mut self) -> anyhow::Result<()>;
 
+    fn execute_offload_operations(&mut self) -> anyhow::Result<()>;
+
     fn save_kv_layer(&mut self, layer_idx: usize) -> anyhow::Result<()>;
 
     fn get_finished(
@@ -215,16 +217,24 @@ impl Worker for KvConnectorWorker {
         Ok(())
     }
 
+    // Assumes the operations are in a valid state for offloading.
+    fn execute_offload_operations(&mut self) -> anyhow::Result<()> {
+        let offloading_operations = std::mem::take(&mut self.offloading_operations);
+        for operation in offloading_operations {
+            self.connector.enqueue_request(operation);
+        }
+        Ok(())
+    }
+
     fn save_kv_layer(&mut self, _layer_idx: usize) -> anyhow::Result<()> {
         self.layers_complete += 1;
         if self.layers_complete == self.layer_events.len() {
-            let offloading_operations = std::mem::take(&mut self.offloading_operations);
             // block on the the completion of the last layer
             // todo(ryan): capture the context, pass this to the scheduler to do the await on another thread
             // or put the event on a stream and use stream waits to keep it all on device.
             event_sync_blocking(self.layer_events[self.layers_complete - 1]);
-            for operation in offloading_operations {
-                self.connector.enqueue_request(operation);
+            if let Err(e) = self.execute_offload_operations() {
+                tracing::error!("Failed to execute offload operations: {}", e);
             }
         }
         Ok(())
@@ -431,6 +441,12 @@ impl PyTrtllmKvConnectorWorker {
             .map_err(to_pyerr)
     }
 
+    pub fn execute_offload_operations(&mut self) -> PyResult<()> {
+        self.connector_worker
+            .execute_offload_operations()
+            .map_err(to_pyerr)
+    }
+
     pub fn save_kv_layer(&mut self, layer_idx: usize) -> PyResult<()> {
         self.connector_worker
             .save_kv_layer(layer_idx)
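
On the Rust side, the enqueue loop that used to live inline in save_kv_layer is factored into execute_offload_operations, which takes the pending operation list with std::mem::take (leaving an empty queue behind) and is also exposed through the PyO3 wrapper so Python can trigger it from the end-of-forward-pass callback. A rough Python analogue of that take-and-drain pattern, with illustrative names:

class OffloadQueue:
    """Sketch of the take-and-drain behavior of execute_offload_operations."""

    def __init__(self, connector):
        self.connector = connector
        self.offloading_operations = []

    def execute_offload_operations(self):
        # Analogue of std::mem::take: swap in an empty list before draining,
        # so the queue is already reset for the next forward pass.
        operations, self.offloading_operations = self.offloading_operations, []
        for operation in operations:
            self.connector.enqueue_request(operation)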

lib/bindings/python/src/dynamo/llm/trtllm_integration/connector/kvbm_connector_worker.py

Lines changed: 30 additions & 3 deletions
@@ -13,6 +13,21 @@
 
 
 class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
+    def _callable_object(self) -> callable:
+        assert (
+            self._connector is not None
+        ), "Expected cache connector worker to have non-None _connector obj"
+        assert (
+            self.event is not None
+        ), "Expected cache connector worker to have non-None event obj"
+
+        def callback():
+            self.event.record()
+            self.event.synchronize()
+            self._connector.execute_offload_operations()
+
+        return callback
+
     def __init__(self, llm_args: TorchLlmArgs):
         super().__init__(llm_args)
 
@@ -22,6 +37,18 @@ def __init__(self, llm_args: TorchLlmArgs):
         self.rank = mappings.rank
 
         self._connector = RustKvConnectorWorker(self.drt, str(self.rank))
+        self.event = torch.cuda.Event()
+
+        # Default to old way of processing offload
+        self.use_forward_pass_callable = False
+
+    def register_forward_pass_callable(self) -> callable:
+        """
+        Register a callable object which will be called at the
+        end of the forward pass.
+        """
+        self.use_forward_pass_callable = True
+        return self._callable_object()
 
     def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
         """
@@ -30,7 +57,6 @@ def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
         Args:
             kv_cache_tensor: The contiguous KV cache tensor.
         """
-        print(f"Register KV Caches on rank {self.rank}")
         logger.info(
             f"KvConnectorWorker started registering the kv caches on rank {self.rank}"
         )
@@ -104,8 +130,9 @@ def save_kv_layer(self, layer_idx: int, stream: torch.cuda.Stream):
             layer_idx: The index of the layer to save.
             stream: The stream the forward pass is being executed on.
         """
-        self.events[layer_idx].record(stream)
-        self._connector.save_kv_layer(layer_idx)
+        if not self.use_forward_pass_callable:
+            self.events[layer_idx].record(stream)
+            self._connector.save_kv_layer(layer_idx)
 
     def get_finished(
         self, finished_gen_req_ids: list[int], started_loading_req_ids: list[int]
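
The worker now supports two offload paths: the legacy per-layer path (record an event per layer and call save_kv_layer) and the new callback path, which is selected when TRTLLM requests the forward-pass callable via register_forward_pass_callable, per the commit message. A hedged sketch of how a runner might drive the two paths; the stub connector and the layer loop are hypothetical, only the method names come from the diff above:

class StubConnector:
    """Stand-in for the Rust connector, just to make the control flow visible."""

    def save_kv_layer(self, layer_idx):
        print(f"per-layer offload bookkeeping for layer {layer_idx}")

    def execute_offload_operations(self):
        print("draining all queued offload operations")


connector = StubConnector()
use_forward_pass_callable = True  # what register_forward_pass_callable() switches on


def end_of_forward_pass():
    # In the real worker this also records and synchronizes a torch.cuda.Event
    # before draining, so all KV writes are visible to the offload engine.
    connector.execute_offload_operations()


for layer_idx in range(4):  # pretend forward pass over four layers
    if not use_forward_pass_callable:
        connector.save_kv_layer(layer_idx)  # legacy per-layer path

if use_forward_pass_callable:
    end_of_forward_pass()  # new CUDA-graph-friendly path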

tests/kvbm/engine_config_with_cuda_graph_and_kvbm.yaml

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+backend: pytorch
+cuda_graph_config:
+  max_batch_size: 8
+kv_cache_config:
+  enable_partial_reuse: false
+  free_gpu_memory_fraction: 0.80
+  max_tokens: 8192
+kv_connector_config:
+  connector_module: dynamo.llm.trtllm_integration.connector
+  connector_scheduler_class: DynamoKVBMConnectorLeader
+  connector_worker_class: DynamoKVBMConnectorWorker

tests/kvbm/engine_config_without_cuda_graph_and_kvbm.yaml

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+backend: pytorch
+cuda_graph_config: null
+kv_cache_config:
+  enable_partial_reuse: false
+  free_gpu_memory_fraction: 0.80
+  max_tokens: 8192
+kv_connector_config:
+  connector_module: dynamo.llm.trtllm_integration.connector
+  connector_scheduler_class: DynamoKVBMConnectorLeader
+  connector_worker_class: DynamoKVBMConnectorWorker
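
Both configs are consumed by the worker through --extra-engine-args; the test below assembles exactly this command via ManagedProcess. A standalone way to launch the same worker by hand might look like the following sketch, with the model name and config path taken from the test and assumed to be available locally:

import subprocess

cmd = [
    "python3", "-m", "dynamo.trtllm",
    "--model", "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "--served-model-name", "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "--extra-engine-args", "tests/kvbm/engine_config_with_cuda_graph_and_kvbm.yaml",
]
subprocess.run(cmd, check=True)  # blocks while the worker runs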

tests/kvbm/test_cuda_graph.py

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Determinism test for language model API using pytest.
+
+This test suite checks if the model produces deterministic outputs
+when given the same inputs with fixed seed and temperature=0.
+
+The test uses comprehensive server warmup (sending all test prompts
+before validation) to avoid server initialization effects that could
+impact determinism measurements.
+"""
+
+import logging
+import os
+import shutil
+
+import pytest
+import requests
+
+from tests.utils.engine_process import FRONTEND_PORT
+from tests.utils.managed_process import DynamoFrontendProcess, ManagedProcess
+from tests.utils.payloads import check_models_api
+
+logger = logging.getLogger(__name__)
+
+# Just need a model to show the config works rather than any stress of the system.
+MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+SERVED_MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
+PROMPT = "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
+
+
+class DynamoWorkerProcess(ManagedProcess):
+    """Process manager for Dynamo worker with TRTLLM backend"""
+
+    def __init__(self, request, worker_id: str, engine_config: str):
+        self.worker_id = worker_id
+
+        command = [
+            "python3",
+            "-m",
+            "dynamo.trtllm",
+            "--model",
+            MODEL_PATH,
+            "--served-model-name",
+            SERVED_MODEL_NAME,
+            "--extra-engine-args",
+            engine_config,
+        ]
+
+        # Set debug logging environment
+        env = os.environ.copy()
+        env["DYN_LOG"] = "debug"
+        env["DYN_SYSTEM_ENABLED"] = "true"
+        env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
+        env["DYN_SYSTEM_PORT"] = "9345"
+        env["DYN_KVBM_CPU_CACHE_GB"] = "20"
+        env["DYN_KVBM_DISK_CACHE_GB"] = "60"
+        env["DYN_KVBM_LEADER_WORKER_INIT_TIMEOUT_SECS"] = "1200"
+
+        # TODO: Have the managed process take a command name explicitly to distinguish
+        # between processes started with the same command.
+        log_dir = f"{request.node.name}_{worker_id}"
+
+        # Clean up any existing log directory from previous runs
+        try:
+            shutil.rmtree(log_dir)
+            logger.info(f"Cleaned up existing log directory: {log_dir}")
+        except FileNotFoundError:
+            # Directory doesn't exist, which is fine
+            pass
+
+        super().__init__(
+            command=command,
+            env=env,
+            health_check_urls=[
+                (f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
+                ("http://localhost:9345/health", self.is_ready),
+            ],
+            timeout=300,
+            display_output=True,
+            terminate_existing=False,
+            log_dir=log_dir,
+        )
+
+    def get_pid(self) -> int | None:
+        """Get the PID of the worker process"""
+        return self.proc.pid if hasattr(self, "proc") and self.proc else None
+
+    def is_ready(self, response) -> bool:
+        """Check the health of the worker process"""
+        try:
+            data = response.json()
+            if data.get("status") == "ready":
+                logger.info(
+                    f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is ready"
+                )
+                return True
+            logger.warning(
+                f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is not ready: {data.get('status')}"
+            )
+        except ValueError:
+            logger.warning(
+                f"{self.__class__.__name__} {{ name: {self.worker_id} }} health response is not valid JSON"
+            )
+        return False
+
+
+def send_completion_request(
+    prompt: str, max_tokens: int, timeout: int = 120
+) -> requests.Response:
+    """Send a completion request to the frontend"""
+    payload = {
+        "model": SERVED_MODEL_NAME,
+        "prompt": prompt,
+        "stream": False,
+        "max_tokens": max_tokens,
+    }
+
+    headers = {"Content-Type": "application/json"}
+
+    logger.info(
+        f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
+    )
+
+    try:
+        response = requests.post(
+            "http://localhost:8000/v1/completions",
+            headers=headers,
+            json=payload,
+            timeout=timeout,
+        )
+        return response
+    except requests.exceptions.Timeout:
+        logger.error(f"Request timed out after {timeout} seconds")
+        raise
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Request failed with error: {e}")
+        raise
+
+
+# Test markers to align with repository conventions
+# Todo: enable the rest when kvbm is built in the ci
[email protected]
[email protected]_marker
[email protected]
[email protected]
[email protected]_1
[email protected](
+    reason="Enable these tests once `main` dynamo upgrades to TRTLLM 1.2+"
+)
+def test_kvbm_without_cuda_graph_enabled(request, runtime_services):
+    """
+    End-to-end test for TRTLLM worker with cuda_graph_config not defined and
+    KVBM enabled.
+
+    This test verifies a TRTLLM worker is able to serve requests when
+    cuda graphs are not enabled in pytorch. KVBM should be able to offload
+    blocks regardless.
+    """
+
+    logger.info("Starting frontend...")
+    with DynamoFrontendProcess(request):
+        logger.info("Frontend started.")
+
+        engine_config_with_cuda_graph_and_kvbm = (
+            "tests/kvbm/engine_config_without_cuda_graph_and_kvbm.yaml"
+        )
+        logger.info("Starting worker...")
+        with DynamoWorkerProcess(
+            request, "decode", engine_config_with_cuda_graph_and_kvbm
+        ) as worker:
+            logger.info(f"Worker PID: {worker.get_pid()}")
+
+            response = send_completion_request(PROMPT, 100, timeout=10)
+            assert (
+                response.ok
+            ), f"Expected successful status, got {response.status_code}"
+            logger.info(f"Completion request succeeded: {response.status_code}")
+
+
[email protected]
[email protected]_marker
[email protected]
[email protected]
[email protected]_1
[email protected](
+    reason="Enable these tests once dynamo `main` upgrades to TRTLLM 1.2+"
+)
+def test_kvbm_with_cuda_graph_enabled(request, runtime_services):
+    """
+    End-to-end test for TRTLLM worker with cuda_graph_config defined and
+    KVBM enabled.
+
+    This test verifies a TRTLLM worker is able to serve requests when
+    cuda graphs are enabled in pytorch. KVBM should be able to offload
+    blocks regardless.
+    """
+
+    logger.info("Starting frontend...")
+    with DynamoFrontendProcess(request):
+        logger.info("Frontend started.")
+
+        engine_config_with_cuda_graph_and_kvbm = (
+            "tests/kvbm/engine_config_with_cuda_graph_and_kvbm.yaml"
+        )
+        logger.info("Starting worker...")
+        with DynamoWorkerProcess(
+            request, "decode", engine_config_with_cuda_graph_and_kvbm
+        ) as worker:
+            logger.info(f"Worker PID: {worker.get_pid()}")
+
+            response = send_completion_request(PROMPT, 100, timeout=10)
+            assert (
+                response.ok
+            ), f"Expected successful status, got {response.status_code}"
+            logger.info(f"Completion request succeeded: {response.status_code}")

tests/utils/managed_process.py

Lines changed: 30 additions & 0 deletions
@@ -568,6 +568,36 @@ def subprocesses(self) -> list[psutil.Process]:
         return []
 
 
+class DynamoFrontendProcess(ManagedProcess):
+    """Process manager for Dynamo frontend"""
+
+    _logger = logging.getLogger()
+
+    def __init__(self, request):
+        command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"]
+
+        log_dir = f"{request.node.name}_frontend"
+
+        # Clean up any existing log directory from previous runs
+        try:
+            shutil.rmtree(log_dir)
+            self._logger.info(f"Cleaned up existing log directory: {log_dir}")
+        except FileNotFoundError:
+            # Directory doesn't exist, which is fine
+            pass
+
+        super().__init__(
+            command=command,
+            display_output=True,
+            terminate_existing=True,
+            log_dir=log_dir,
+        )
+
+    def get_pid(self) -> int | None:
+        """Get the PID of the worker process"""
+        return self.proc.pid if self.proc else None
+
+
 def main():
     with ManagedProcess(
         command=[
