1 change: 1 addition & 0 deletions components/src/dynamo/vllm/main.py
@@ -8,6 +8,7 @@
from typing import Optional

import uvloop
# from kvbm.vllm_integration.consolidator_config import get_consolidator_endpoints
from prometheus_client import REGISTRY
from vllm.distributed.kv_events import ZmqEventPublisher
from vllm.usage.usage_lib import UsageContext
161 changes: 75 additions & 86 deletions lib/llm/src/discovery/watcher.rs
@@ -3,25 +3,26 @@

use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tokio::sync::Notify;

use anyhow::Context as _;
use tokio::sync::{Notify, mpsc::Receiver};
use futures::StreamExt;

use dynamo_runtime::{
DistributedRuntime,
discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryKey, DiscoveryStream},
pipeline::{
ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter,
},
protocols::{EndpointId, annotated::Annotated},
storage::key_value_store::WatchEvent,
};

use crate::{
backend::Backend,
entrypoint,
kv_router::{KvRouterConfig, PrefillRouter},
model_card::{self, ModelDeploymentCard},
model_card::ModelDeploymentCard,
model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{
@@ -99,17 +100,45 @@ impl ModelWatcher {
}

/// Common watch logic with optional namespace filtering
pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
pub async fn watch(&self, mut discovery_stream: DiscoveryStream, target_namespace: Option<&str>) {
let global_namespace = target_namespace.is_none_or(is_global_namespace);

while let Some(event) = events_rx.recv().await {
while let Some(result) = discovery_stream.next().await {
let event = match result {
Ok(event) => event,
Err(err) => {
tracing::error!(%err, "Error in discovery stream");
continue;
}
};

match event {
WatchEvent::Put(kv) => {
let key = kv.key_str();
let endpoint_id = match key_extract(key) {
Ok((eid, _)) => eid,
Err(err) => {
tracing::error!(%key, %err, "Failed extracting EndpointId from key. Ignoring instance.");
DiscoveryEvent::Added(instance) => {
// Extract EndpointId, instance_id, and card from the discovery instance
let (endpoint_id, instance_id, mut card) = match &instance {
DiscoveryInstance::ModelCard {
namespace,
component,
endpoint,
instance_id,
..
} => {
let eid = EndpointId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
};

match instance.deserialize_model_card::<ModelDeploymentCard>() {
Ok(card) => (eid, *instance_id, card),
Err(err) => {
tracing::error!(%err, instance_id, "Failed to deserialize model card");
continue;
}
}
}
_ => {
tracing::error!("Unexpected discovery instance type (expected ModelCard)");
continue;
}
};
@@ -127,21 +156,6 @@ impl ModelWatcher {
continue;
}

let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(card) => card,
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};

// If we already have a worker for this model and the ModelDeploymentCards
// don't match, alert and don't add the new instance
let can_add =
@@ -164,7 +178,10 @@ impl ModelWatcher {
continue;
}

match self.handle_put(key, &endpoint_id, &mut card).await {
// Use instance_id as the HashMap key (simpler and sufficient since keys are opaque)
let key = format!("{:x}", instance_id);

match self.handle_put(&key, &endpoint_id, &mut card).await {
Ok(()) => {
tracing::info!(
model_name = card.name(),
@@ -183,10 +200,12 @@
}
}
}
WatchEvent::Delete(key) => {
let deleted_key = key.as_ref();
DiscoveryEvent::Removed(instance_id) => {
// Use instance_id hex as the HashMap key (matches what we saved with)
let key = format!("{:x}", instance_id);

match self
.handle_delete(deleted_key, target_namespace, global_namespace)
.handle_delete(&key, target_namespace, global_namespace)
.await
{
Ok(Some(model_name)) => {
@@ -212,6 +231,8 @@ impl ModelWatcher {
target_namespace: Option<&str>,
is_global_namespace: bool,
) -> anyhow::Result<Option<String>> {
tracing::warn!("DISCOVERY_VALIDATION: handle_delete: key={}", key);

let card = match self.manager.remove_model_card(key) {
Some(card) => card,
None => {
@@ -303,6 +324,8 @@ impl ModelWatcher {
endpoint_id: &EndpointId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
tracing::warn!("DISCOVERY_VALIDATION: handle_put: key={}", key);

card.download_config().await?;

let component = self
@@ -559,35 +582,37 @@ impl ModelWatcher {

/// All the registered ModelDeploymentCards, each paired with the EndpointId it is attached to, one per instance
async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let store = self.drt.store();
let Some(card_bucket) = store.get_bucket(model_card::ROOT_PATH).await? else {
// no cards
return Ok(vec![]);
};
let entries = card_bucket.entries().await?;
let discovery = self.drt.discovery_client();
let instances = discovery.list(DiscoveryKey::AllModelCards).await?;

let mut results = Vec::with_capacity(entries.len());
for (key, card_bytes) in entries {
let r = match serde_json::from_slice::<ModelDeploymentCard>(&card_bytes) {
let mut results = Vec::with_capacity(instances.len());
for instance in instances {
match instance.deserialize_model_card::<ModelDeploymentCard>() {
Ok(card) => {
let maybe_endpoint_id =
key_extract(&key).map(|(endpoint_id, _instance_id)| endpoint_id);
let endpoint_id = match maybe_endpoint_id {
Ok(eid) => eid,
Err(err) => {
tracing::error!(%err, "Skipping invalid key, not string or not EndpointId");
// Extract EndpointId from the instance
let endpoint_id = match &instance {
dynamo_runtime::discovery::DiscoveryInstance::ModelCard {
namespace,
component,
endpoint,
..
} => EndpointId {
namespace: namespace.clone(),
component: component.clone(),
name: endpoint.clone(),
},
_ => {
tracing::error!("Unexpected discovery instance type (expected ModelCard)");
continue;
}
};
(endpoint_id, card)
results.push((endpoint_id, card));
}
Err(err) => {
let value = String::from_utf8_lossy(&card_bytes);
tracing::error!(%err, %value, "Invalid JSON in model card");
tracing::error!(%err, "Failed to deserialize model card");
continue;
}
};
results.push(r);
}
}
Ok(results)
}
@@ -612,40 +637,4 @@ impl ModelWatcher {
}
}

/// The ModelDeploymentCard is published in store with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
if !s.starts_with(model_card::ROOT_PATH) {
anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
}
let parts: Vec<&str> = s.split('/').collect();

// Need at least prefix model_card::ROOT_PATH (2 parts) + namespace, component, name (3 parts)
if parts.len() <= 5 {
anyhow::bail!("Invalid format: not enough path segments in {s}");
}

let endpoint_id = EndpointId {
namespace: parts[2].to_string(),
component: parts[3].to_string(),
name: parts[4].to_string(),
};
Ok((endpoint_id, parts[parts.len() - 1].to_string()))
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_key_extract() {
let input = format!(
"{}/dynamo/backend/generate/694d9981145a61ad",
model_card::ROOT_PATH
);
let (endpoint_id, _) = key_extract(&input).unwrap();
assert_eq!(endpoint_id.namespace, "dynamo");
assert_eq!(endpoint_id.component, "backend");
assert_eq!(endpoint_id.name, "generate");
}
}
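
Taken together, the watcher.rs changes replace key-string parsing (the removed key_extract) with typed discovery events. Below is a minimal sketch of the resulting consumption pattern; it relies only on the DiscoveryStream / DiscoveryEvent / deserialize_model_card API visible in this diff, and on_added / on_removed are hypothetical stand-ins for handle_put / handle_delete.

// Sketch only (not part of this diff): the event-handling skeleton the new watch() follows,
// with model bookkeeping elided. Assumes the types imported at the top of watcher.rs.
use futures::StreamExt;
use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryStream};

async fn consume(mut stream: DiscoveryStream) {
    while let Some(result) = stream.next().await {
        // Stream errors are logged and skipped rather than ending the watch loop.
        let event = match result {
            Ok(event) => event,
            Err(err) => {
                tracing::error!(%err, "Error in discovery stream");
                continue;
            }
        };
        match event {
            // Added: deserialize the typed ModelCard payload instead of parsing a store key.
            DiscoveryEvent::Added(instance) => {
                match instance.deserialize_model_card::<ModelDeploymentCard>() {
                    Ok(card) => on_added(card).await, // hypothetical, stands in for handle_put
                    Err(err) => tracing::error!(%err, "Failed to deserialize model card"),
                }
            }
            // Removed: the hex instance_id is the HashMap key that handle_put registered.
            DiscoveryEvent::Removed(instance_id) => {
                on_removed(&format!("{:x}", instance_id)).await; // hypothetical, stands in for handle_delete
            }
        }
    }
}
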
41 changes: 24 additions & 17 deletions lib/llm/src/discovery/worker_monitor.rs
@@ -3,12 +3,12 @@

use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::component::Client;
use dynamo_runtime::discovery::{watch_and_extract_field, DiscoveryKey};
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;
@@ -79,21 +79,13 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let endpoint = &self.client.endpoint;
let component = endpoint.component();

let Some(etcd_client) = component.drt().etcd_client() else {
// Static mode, no monitoring needed
return Ok(());
};

// Watch for runtime config updates from model deployment cards
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
model_card::ROOT_PATH,
key_extractors::lease_id,
|card: ModelDeploymentCard| Some(card.runtime_config),
component.drt().child_token(),
)
.await?;
let mut config_events_rx = runtime_configs_watcher.receiver();
// Watch for runtime config updates from model deployment cards via discovery interface
let discovery = component.drt().discovery_client();
let discovery_stream = discovery.list_and_watch(DiscoveryKey::AllModelCards).await?;
let mut config_events_rx = watch_and_extract_field(
discovery_stream,
|card: ModelDeploymentCard| card.runtime_config,
);

// Subscribe to KV metrics events
let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
@@ -117,6 +109,21 @@
// Handle runtime config updates
_ = config_events_rx.changed() => {
let runtime_configs = config_events_rx.borrow().clone();

tracing::warn!(
worker_count = runtime_configs.len(),
"DISCOVERY: Runtime config updates received"
);

// Log detailed config state for comparison
let config_details: Vec<(u64, Option<u64>)> = runtime_configs
.iter()
.map(|(&lease_id, config)| (lease_id, config.total_kv_blocks))
.collect();
tracing::warn!(
"DISCOVERY_VALIDATION: config_state: configs={:?}",
config_details
);

let mut states = worker_load_states.write().unwrap();
states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
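
The worker monitor makes the same move: instead of an etcd prefix watcher, it derives a per-worker map of runtime_config values from the discovery stream. The condensed sketch below keeps only the config branch of the select! loop and simplifies error handling; it assumes watch_and_extract_field yields a watch-style receiver over a map keyed by lease/instance id, as the changed()/borrow() calls above imply, and reuses names from the diff (component, worker_load_states).

// Sketch only (not part of this diff): discovery-driven runtime_config tracking.
let discovery = component.drt().discovery_client();
let stream = discovery.list_and_watch(DiscoveryKey::AllModelCards).await?;
let mut configs_rx =
    watch_and_extract_field(stream, |card: ModelDeploymentCard| card.runtime_config);

loop {
    // Wakes whenever a ModelDeploymentCard is added, updated, or removed.
    configs_rx.changed().await?;
    let runtime_configs = configs_rx.borrow().clone();
    // Drop load state for workers that no longer appear in discovery.
    let mut states = worker_load_states.write().unwrap();
    states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
}
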
12 changes: 4 additions & 8 deletions lib/llm/src/entrypoint/input/common.rs
@@ -10,7 +10,7 @@ use crate::{
entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter, PrefillRouter},
migration::Migration,
model_card::{self, ModelDeploymentCard},
model_card::ModelDeploymentCard,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate,
@@ -59,7 +59,6 @@ pub async fn prepare_engine(
) -> anyhow::Result<PreparedEngine> {
match engine_config {
EngineConfig::Dynamic(local_model) => {
let store = Arc::new(distributed_runtime.store().clone());
let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime.clone(),
@@ -68,14 +67,11 @@
None,
None,
));
let (_, receiver) = store.watch(
model_card::ROOT_PATH,
None,
distributed_runtime.primary_token(),
);
let discovery = distributed_runtime.discovery_client();
let discovery_stream = discovery.list_and_watch(dynamo_runtime::discovery::DiscoveryKey::AllModelCards).await?;
let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver, None).await;
inner_watch_obj.watch(discovery_stream, None).await;
});
tracing::info!("Waiting for remote model..");
