Skip to content

Commit 84e71e2

Browse files
feat: predictive active blocks for routing without load metrics (ai-dynamo#1731)
Signed-off-by: Yan Ru Pei <yanrpei@gmail.com> Co-authored-by: Alec <35311602+alec-flowers@users.noreply.github.com>
1 parent ffccc72 commit 84e71e2

File tree

25 files changed

+1158
-517
lines changed

25 files changed

+1158
-517
lines changed

components/metrics/src/lib.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ use serde::{Deserialize, Serialize};
8383
use std::net::SocketAddr;
8484
use std::time::Duration as StdDuration;
8585

86-
use dynamo_llm::kv_router::protocols::ForwardPassMetrics;
86+
use dynamo_llm::kv_router::protocols::{ForwardPassMetrics, LoadMetrics};
8787
use dynamo_llm::kv_router::scheduler::Endpoint;
8888
use dynamo_llm::kv_router::scoring::ProcessedEndpoints;
8989

@@ -449,7 +449,10 @@ impl PrometheusMetrics {
449449
// Update per-worker metrics
450450
for (worker_id, endpoint) in processed.endpoints.iter() {
451451
let worker_id = worker_id.to_string();
452-
let metrics = endpoint.data.clone();
452+
let load_metrics = endpoint.data.clone();
453+
let LoadMetrics::EngineLoadMetrics(metrics) = load_metrics else {
454+
panic!("Can only update with ForwardPassMetrics");
455+
};
453456

454457
self.set_worker_gauge(
455458
&self.kv_blocks_active,
@@ -602,7 +605,7 @@ pub fn postprocess_metrics(
602605
e.id().ok().map(|id| Endpoint {
603606
name: format!("worker-{id}"),
604607
subject: e.subject.clone(),
605-
data: m.clone(),
608+
data: LoadMetrics::EngineLoadMetrics(m.clone()),
606609
})
607610
})
608611
.collect();

components/router/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ async fn app(runtime: Runtime) -> Result<()> {
6666

6767
let selector = Box::new(CustomWorkerSelector::default());
6868

69-
let router = KvRouter::new(component.clone(), args.block_size, Some(selector)).await?;
69+
let router = KvRouter::new(component.clone(), args.block_size, Some(selector), true).await?;
7070
let router = Ingress::for_engine(Arc::new(router))?;
7171

7272
component

container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3392,14 +3392,8 @@ index cafd8150b..6a5e45b4e 100644
33923392
+ num_requests_waiting: int
33933393
+ gpu_cache_usage_perc: float
33943394
+ gpu_prefix_cache_hit_rate: float
3395-
+ spec_decode_draft_acceptance_rate: Optional[float] = None
3396-
+ spec_decode_system_efficiency: Optional[float] = None
3397-
+ spec_decode_draft_tokens: Optional[int] = None
3398-
+ spec_decode_emitted_tokens: Optional[int] = None
3399-
+ spec_decode_accepted_tokens: Optional[int] = None
3400-
+ spec_decode_num_spec_tokens: Optional[int] = None
34013395
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py
3402-
index f058b1329..2fdb5b8bf 100644
3396+
index f058b1329..fd5610a3c 100644
34033397
--- a/vllm/engine/multiprocessing/client.py
34043398
+++ b/vllm/engine/multiprocessing/client.py
34053399
@@ -1,4 +1,17 @@
@@ -3460,24 +3454,33 @@ index f058b1329..2fdb5b8bf 100644
34603454
from vllm.engine.protocol import EngineClient
34613455
# yapf: enable
34623456
from vllm.envs import VLLM_RPC_TIMEOUT
3463-
@@ -48,6 +66,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
3457+
@@ -48,6 +66,17 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
34643458
from vllm.sampling_params import SamplingParams
34653459
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
34663460
from vllm.utils import Device, deprecate_kwargs
34673461
+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
34683462
+from vllm.distributed.device_communicators.nixl import NixlMetadata
3463+
+
3464+
+# Import ForwardPassMetrics and related classes from dynamo
3465+
+try:
3466+
+ from dynamo.llm import ForwardPassMetrics, WorkerStats, KvStats
3467+
+except ImportError:
3468+
+ # Fallback if dynamo imports are not available
3469+
+ ForwardPassMetrics = None
3470+
+ WorkerStats = None
3471+
+ KvStats = None
34693472

34703473
logger = init_logger(__name__)
34713474

3472-
@@ -93,6 +113,7 @@ class MQLLMEngineClient(EngineClient):
3475+
@@ -93,6 +122,7 @@ class MQLLMEngineClient(EngineClient):
34733476
self._errored_with: Optional[BaseException] = None
34743477

34753478
# Get the configs.
34763479
+ self.vllm_config = engine_config
34773480
self.model_config = engine_config.model_config
34783481
self.decoding_config = engine_config.decoding_config
34793482

3480-
@@ -117,6 +138,10 @@ class MQLLMEngineClient(EngineClient):
3483+
@@ -117,6 +147,10 @@ class MQLLMEngineClient(EngineClient):
34813484
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
34823485
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
34833486

@@ -3488,7 +3491,7 @@ index f058b1329..2fdb5b8bf 100644
34883491
# IPC path for the data socket.
34893492
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
34903493

3491-
@@ -131,8 +156,27 @@ class MQLLMEngineClient(EngineClient):
3494+
@@ -131,8 +165,27 @@ class MQLLMEngineClient(EngineClient):
34923495
# Loop to check health of the LLMEngine periodically.
34933496
# Started after the MQLLMEngine is ready.
34943497
self.health_loop: Optional[asyncio.Task] = None
@@ -3516,7 +3519,7 @@ index f058b1329..2fdb5b8bf 100644
35163519
@staticmethod
35173520
def is_unsupported_config(vllm_config: VllmConfig):
35183521
# Pipeline parallel not yet supported
3519-
@@ -182,6 +226,61 @@ class MQLLMEngineClient(EngineClient):
3522+
@@ -182,6 +235,76 @@ class MQLLMEngineClient(EngineClient):
35203523
except Exception as e:
35213524
self._set_errored(e)
35223525

@@ -3553,13 +3556,28 @@ index f058b1329..2fdb5b8bf 100644
35533556
+ if self.metrics_publisher is not None and isinstance(
35543557
+ metrics, KvMetrics
35553558
+ ):
3556-
+ self.metrics_publisher.publish(metrics.request_active_slots,
3557-
+ metrics.request_total_slots,
3558-
+ metrics.kv_active_blocks,
3559-
+ metrics.kv_total_blocks,
3560-
+ metrics.num_requests_waiting,
3561-
+ metrics.gpu_cache_usage_perc,
3562-
+ metrics.gpu_prefix_cache_hit_rate)
3559+
+ # Construct structured metrics objects
3560+
+ worker_stats = WorkerStats(
3561+
+ request_active_slots=metrics.request_active_slots,
3562+
+ request_total_slots=metrics.request_total_slots,
3563+
+ num_requests_waiting=metrics.num_requests_waiting,
3564+
+ data_parallel_rank=None
3565+
+ )
3566+
+
3567+
+ kv_stats = KvStats(
3568+
+ kv_active_blocks=metrics.kv_active_blocks,
3569+
+ kv_total_blocks=metrics.kv_total_blocks,
3570+
+ gpu_cache_usage_perc=metrics.gpu_cache_usage_perc,
3571+
+ gpu_prefix_cache_hit_rate=metrics.gpu_prefix_cache_hit_rate
3572+
+ )
3573+
+
3574+
+ forward_pass_metrics = ForwardPassMetrics(
3575+
+ worker_stats=worker_stats,
3576+
+ kv_stats=kv_stats,
3577+
+ spec_decode_stats=None
3578+
+ )
3579+
+
3580+
+ self.metrics_publisher.publish(forward_pass_metrics)
35633581
+ logger.debug("Metrics successful.")
35643582
+
35653583
+ # TODO: Investigate sending whole stats object
@@ -3578,7 +3596,7 @@ index f058b1329..2fdb5b8bf 100644
35783596
async def run_output_handler_loop(self):
35793597
"""Get RequestOutputs from Engine and stream to Request Queues"""
35803598

3581-
@@ -250,7 +349,7 @@ class MQLLMEngineClient(EngineClient):
3599+
@@ -250,7 +373,7 @@ class MQLLMEngineClient(EngineClient):
35823600
# Put each output into the appropriate queue.
35833601
elif isinstance(
35843602
request_outputs,
@@ -3587,7 +3605,7 @@ index f058b1329..2fdb5b8bf 100644
35873605
self._add_output(request_outputs)
35883606
else:
35893607
for request_output in request_outputs:
3590-
@@ -261,7 +360,7 @@ class MQLLMEngineClient(EngineClient):
3608+
@@ -261,7 +384,7 @@ class MQLLMEngineClient(EngineClient):
35913609

35923610
def _add_output(self, request_output: Union[RequestOutput,
35933611
RPCAdapterLoadedResponse,
@@ -3596,7 +3614,7 @@ index f058b1329..2fdb5b8bf 100644
35963614
queue = self.output_queues.get(request_output.request_id)
35973615
if queue is not None:
35983616
queue.put_nowait(request_output)
3599-
@@ -283,12 +382,25 @@ class MQLLMEngineClient(EngineClient):
3617+
@@ -283,12 +406,25 @@ class MQLLMEngineClient(EngineClient):
36003618
# Wait until server is ready.
36013619
response = await self._wait_for_server_rpc(socket)
36023620

@@ -3622,7 +3640,7 @@ index f058b1329..2fdb5b8bf 100644
36223640

36233641
def close(self):
36243642
"""Destroy the ZeroMQ Context."""
3625-
@@ -298,6 +410,8 @@ class MQLLMEngineClient(EngineClient):
3643+
@@ -298,6 +434,8 @@ class MQLLMEngineClient(EngineClient):
36263644
# Cancel background tasks.
36273645
if self.health_loop is not None:
36283646
self.health_loop.cancel()
@@ -3631,7 +3649,7 @@ index f058b1329..2fdb5b8bf 100644
36313649
if self.output_loop is not None:
36323650
self.output_loop.cancel()
36333651

3634-
@@ -420,6 +534,9 @@ class MQLLMEngineClient(EngineClient):
3652+
@@ -420,6 +558,9 @@ class MQLLMEngineClient(EngineClient):
36353653
"""
36363654
if self._errored_with is not None:
36373655
raise self._errored_with
@@ -3641,15 +3659,15 @@ index f058b1329..2fdb5b8bf 100644
36413659

36423660
@property
36433661
def is_running(self) -> bool:
3644-
@@ -478,6 +595,7 @@ class MQLLMEngineClient(EngineClient):
3662+
@@ -478,6 +619,7 @@ class MQLLMEngineClient(EngineClient):
36453663
trace_headers: Optional[Mapping[str, str]] = None,
36463664
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
36473665
priority: int = 0,
36483666
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
36493667
*,
36503668
inputs: Optional[PromptType] = None # DEPRECATED
36513669
) -> AsyncGenerator[RequestOutput, None]:
3652-
@@ -507,7 +625,8 @@ class MQLLMEngineClient(EngineClient):
3670+
@@ -507,7 +649,8 @@ class MQLLMEngineClient(EngineClient):
36533671

36543672
return self._process_request(prompt, sampling_params, request_id,
36553673
lora_request, trace_headers,
@@ -3659,15 +3677,15 @@ index f058b1329..2fdb5b8bf 100644
36593677

36603678
@overload
36613679
def encode(
3662-
@@ -591,6 +710,7 @@ class MQLLMEngineClient(EngineClient):
3680+
@@ -591,6 +734,7 @@ class MQLLMEngineClient(EngineClient):
36633681
trace_headers: Optional[Mapping[str, str]] = None,
36643682
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
36653683
priority: int = 0,
36663684
+ remote_prefill_params: Optional[RemotePrefillParams] = None,
36673685
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
36683686
PoolingRequestOutput, None]]:
36693687
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
3670-
@@ -636,6 +756,12 @@ class MQLLMEngineClient(EngineClient):
3688+
@@ -636,6 +780,12 @@ class MQLLMEngineClient(EngineClient):
36713689
else:
36723690
lp_bytes = None
36733691

@@ -3680,7 +3698,7 @@ index f058b1329..2fdb5b8bf 100644
36803698
request_bytes = pickle.dumps(
36813699
RPCProcessRequest(
36823700
prompt=prompt,
3683-
@@ -645,11 +771,11 @@ class MQLLMEngineClient(EngineClient):
3701+
@@ -645,11 +795,11 @@ class MQLLMEngineClient(EngineClient):
36843702
trace_headers=trace_headers,
36853703
prompt_adapter_request=prompt_adapter_request,
36863704
priority=priority,
@@ -3694,7 +3712,7 @@ index f058b1329..2fdb5b8bf 100644
36943712
await self.input_socket.send_multipart(parts, copy=False)
36953713

36963714
# 4) Stream the RequestOutputs from the output queue. Note
3697-
@@ -740,3 +866,22 @@ class MQLLMEngineClient(EngineClient):
3715+
@@ -740,3 +890,22 @@ class MQLLMEngineClient(EngineClient):
36983716
# Raise on error, otherwise happily return None
36993717
if isinstance(request_output, BaseException):
37003718
raise request_output

docs/architecture/kv_cache_routing.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ See the License for the specific language governing permissions and
1515
limitations under the License.
1616
-->
1717

18+
>[!NOTE]
19+
>This information is temporary and will change soon.
1820
1921
# KV Cache Routing
2022
This documentation explains how Key-Value (KV) cache routing works in Dynamo, providing optimized inference for large language models by intelligently directing requests to workers with the most relevant cached data while simultaneously load balancing based on utilization metrics sent by the workers.

docs/guides/dynamo_run.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ It supports these engines: mistralrs, llamacpp, sglang, vllm, and tensorrt-llm.
88

99
Usage:
1010
```
11-
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=2.0] [--kv-gpu-cache-usage-weight=1.0] [--kv-waiting-requests-weight=1.0] [--verbosity (-v|-vv)]
11+
dynamo-run in=[http|text|dyn://<path>|batch:<folder>] out=echo_core|echo_full|mistralrs|llamacpp|sglang|vllm|dyn [--http-port 8080] [--model-path <path>] [--model-name <served-model-name>] [--model-config <hf-repo>] [--tensor-parallel-size=1] [--context-length=N] [--num-nodes=1] [--node-rank=0] [--leader-addr=127.0.0.1:9876] [--base-gpu-id=0] [--extra-engine-args=args.json] [--router-mode random|round-robin|kv] [--kv-overlap-score-weight=1.0] [--router-temperature=0.5] [--verbosity (-v|-vv)]
1212
```
1313

1414
Example: `dynamo run Qwen/Qwen3-0.6B`
@@ -201,6 +201,8 @@ The only difference from the distributed system above is `--router-mode kv`. The
201201
202202
For performance testing, compare a typical workload with `--router-mode random|round-robin` to see if it can benefit from KV-aware routing.
203203
204+
The argument `--kv-overlap-score-weight` sets the amount weighting on overlaps with prefix caches, which directly contributes to the prefill cost, so a large weight is expected to yield a better TTFT (at the expense of worse ITL). When this is set 0, we do not consider the prefix caches at all (falling back to pure load balancing behavior on the active blocks), in which case we do not require the backend engines to emit any KV events. The argument `--router-temperature` sets the temperature when randomly selecting the workers to route to via softmax sampling on the router cost logits, setting it to 0 recovers the deterministic behavior where the min logit is picked.
205+
204206
## Full usage details
205207
206208
`dynamo run` executes `dynamo-run`. `dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.

docs/guides/kv_router_perf_tuning.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ See the License for the specific language governing permissions and
1515
limitations under the License.
1616
-->
1717

18+
>[!NOTE]
19+
>This information is temporary and will change soon.
20+
1821
# KV Router Performance Tuning
1922

2023
## Overview

launch/dynamo-run/src/flags.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,20 +110,23 @@ pub struct Flags {
110110
#[arg(long, default_value = "round-robin")]
111111
pub router_mode: RouterMode,
112112

113+
/// Maximum number of batched tokens for KV routing
114+
/// Needed for informing the KV router
115+
/// TODO: derive from vllm args
116+
/// NOTE: this is not actually used for now
117+
#[arg(long, default_value = "8192")]
118+
pub max_num_batched_tokens: Option<u32>,
119+
113120
/// KV Router: Weight for overlap score in worker selection.
114121
/// Higher values prioritize KV cache reuse. Default: 2.0
115122
#[arg(long)]
116123
pub kv_overlap_score_weight: Option<f64>,
117124

118-
/// KV Router: Weight for GPU cache usage in worker selection.
119-
/// Higher values avoid workers with nearly full KV caches. Default: 1.0
120-
#[arg(long)]
121-
pub kv_gpu_cache_usage_weight: Option<f64>,
122-
123-
/// KV Router: Weight for waiting requests in worker selection.
124-
/// Higher values avoid workers with queued requests. Default: 1.0
125+
/// KV Router: Temperature for worker sampling via softmax.
126+
/// Higher values promote more randomness, and 0 fallbacks to deterministic.
127+
/// Default: 0.5
125128
#[arg(long)]
126-
pub kv_waiting_requests_weight: Option<f64>,
129+
pub router_temperature: Option<f64>,
127130

128131
/// Max model context length. Reduce this if you don't have enough VRAM for the full model
129132
/// context length (e.g. Llama 4).
@@ -211,8 +214,8 @@ impl Flags {
211214
self.router_mode.into(),
212215
KvRouterConfig::new(
213216
self.kv_overlap_score_weight,
214-
self.kv_gpu_cache_usage_weight,
215-
self.kv_waiting_requests_weight,
217+
self.router_temperature,
218+
self.max_num_batched_tokens,
216219
),
217220
)
218221
}

launch/dynamo-run/src/subprocess/vllm_inc.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@
2626
)
2727
from vllm.inputs import TokensPrompt
2828

29-
from dynamo.llm import ModelType, WorkerMetricsPublisher, register_llm
29+
from dynamo.llm import (
30+
ForwardPassMetrics,
31+
KvStats,
32+
ModelType,
33+
WorkerMetricsPublisher,
34+
WorkerStats,
35+
register_llm,
36+
)
3037
from dynamo.runtime import DistributedRuntime, dynamo_worker
3138
from dynamo.runtime.logging import configure_dynamo_logging
3239

@@ -70,15 +77,29 @@ def setup_kv_metrics(self):
7077
self.engine_client.set_metrics_publisher(self.metrics_publisher)
7178
# Initially send dummy metrics to kick start,
7279
# vLLM will not update stat until forward pass is triggered
73-
self.metrics_publisher.publish(
74-
0, # request_active_slots
75-
1024, # request_total_slots
76-
0, # kv_active_blocks
77-
1024, # kv_total_blocks
78-
0, # num_requests_waiting
79-
0.0, # gpu_cache_usage_perc
80-
0.0, # gpu_prefix_cache_hit_rate
80+
81+
# Create the structured metrics objects
82+
worker_stats = WorkerStats(
83+
request_active_slots=0,
84+
request_total_slots=1024,
85+
num_requests_waiting=0,
86+
data_parallel_rank=None,
87+
)
88+
89+
kv_stats = KvStats(
90+
kv_active_blocks=0,
91+
kv_total_blocks=1024,
92+
gpu_cache_usage_perc=0.0,
93+
gpu_prefix_cache_hit_rate=0.0,
8194
)
95+
96+
metrics = ForwardPassMetrics(
97+
worker_stats=worker_stats, kv_stats=kv_stats, spec_decode_stats=None
98+
)
99+
100+
# Publish the metrics as a single object
101+
self.metrics_publisher.publish(metrics)
102+
82103
task = asyncio.create_task(self.create_metrics_publisher_endpoint())
83104
task.add_done_callback(
84105
lambda _: logging.debug("metrics publisher endpoint created")

lib/bindings/python/rust/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
7272
m.add_class::<Client>()?;
7373
m.add_class::<EtcdClient>()?;
7474
m.add_class::<AsyncResponseStream>()?;
75-
m.add_class::<llm::kv::KvRouter>()?;
7675
m.add_class::<llm::disagg_router::DisaggregatedRouter>()?;
7776
m.add_class::<llm::kv::WorkerMetricsPublisher>()?;
7877
m.add_class::<llm::model_card::ModelDeploymentCard>()?;

0 commit comments

Comments
 (0)