Skip to content

Commit 6e2b22e

Browse files
kthui and rmccorm4 authored
feat: ETCD high availability client failover - lease watch resilience (#3950)
Signed-off-by: Jacky <[email protected]> Co-authored-by: Ryan McCormick <[email protected]>
1 parent 25fc732 commit 6e2b22e

File tree

10 files changed

+1403
-88
lines changed

10 files changed

+1403
-88
lines changed

lib/bindings/python/rust/planner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ impl InnerClient {
479479
/// Wait for a new scaling decision. Use `get` when this returns to fetch the values.
480480
async fn wait(&self) -> anyhow::Result<()> {
481481
let watcher = self.etcd_client.kv_watch_prefix(&self.key).await?;
482-
let (_prefix, _watcher, mut receiver) = watcher.dissolve();
482+
let (_prefix, mut receiver) = watcher.dissolve();
483483
tokio::select! {
484484
_ = receiver.recv() => {
485485
Ok(())

lib/llm/src/disagg_router.rs

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use serde::{Deserialize, Serialize};
5+
use std::sync::{Arc, Mutex};
6+
use tokio::sync::watch;
7+
use tracing;
8+
9+
use dynamo_runtime::DistributedRuntime;
10+
use dynamo_runtime::transports::etcd::WatchEvent;
11+
12+
#[derive(Clone, Debug, Serialize, Deserialize)]
13+
pub struct DisaggRouterConf {
14+
pub max_local_prefill_length: i32,
15+
}
16+
17+
impl Default for DisaggRouterConf {
18+
fn default() -> Self {
19+
Self {
20+
max_local_prefill_length: 1000,
21+
}
22+
}
23+
}
24+
25+
impl DisaggRouterConf {
26+
pub async fn from_etcd_with_watcher(
27+
drt: Arc<DistributedRuntime>,
28+
model_name: &str,
29+
) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
30+
let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
31+
32+
// Get the initial value if it exists
33+
let Some(etcd_client) = drt.etcd_client() else {
34+
anyhow::bail!("Static components don't have an etcd client");
35+
};
36+
let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
37+
Ok(kvs) => {
38+
if let Some(kv) = kvs.first() {
39+
match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
40+
Ok(config) => {
41+
tracing::debug!(
42+
"Found initial config for key {}: {:?}",
43+
etcd_key,
44+
config
45+
);
46+
config
47+
}
48+
Err(e) => {
49+
tracing::warn!(
50+
"Failed to parse initial config for key {}: {}",
51+
etcd_key,
52+
e
53+
);
54+
DisaggRouterConf::default()
55+
}
56+
}
57+
} else {
58+
tracing::debug!(
59+
"No initial config found for key {}, using default",
60+
etcd_key
61+
);
62+
DisaggRouterConf::default()
63+
}
64+
}
65+
Err(e) => {
66+
tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
67+
DisaggRouterConf::default()
68+
}
69+
};
70+
71+
// Create watch channel for config updates
72+
let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
73+
74+
// Set up the watcher after getting the initial value
75+
let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
76+
let (key, mut kv_event_rx) = prefix_watcher.dissolve();
77+
78+
// Spawn background task to watch for config changes
79+
drt.runtime().secondary().spawn(async move {
80+
tracing::info!("Starting config watcher for disagg router key: {}", key);
81+
82+
loop {
83+
let kv_event = tokio::select! {
84+
_ = watch_tx.closed() => {
85+
tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
86+
break;
87+
}
88+
kv_event = kv_event_rx.recv() => {
89+
match kv_event {
90+
Some(kv_event) => kv_event,
91+
None => {
92+
tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
93+
break;
94+
}
95+
}
96+
}
97+
};
98+
99+
tracing::debug!("Received watch event for key {}", key);
100+
101+
match kv_event {
102+
WatchEvent::Put(kv) => {
103+
let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
104+
if let Ok(config) = val {
105+
tracing::info!("Config updated for key {}: {:?}", key, config);
106+
// Broadcast the update
107+
if watch_tx.send(config).is_err() {
108+
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
109+
break;
110+
}
111+
} else {
112+
tracing::error!("Unable to parse router config for key {}", key);
113+
break;
114+
}
115+
}
116+
WatchEvent::Delete(_) => {
117+
tracing::warn!("Config key was deleted: {}", key);
118+
// Reset to default values
119+
if watch_tx.send(DisaggRouterConf::default()).is_err() {
120+
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
121+
break;
122+
}
123+
}
124+
}
125+
}
126+
127+
tracing::debug!("Completed config watcher for key: {}", key);
128+
});
129+
130+
Ok((initial_config, watch_rx))
131+
}
132+
}
133+
134+
#[derive(Clone)]
135+
pub struct DisaggregatedRouter {
136+
max_local_prefill_length: Arc<Mutex<i32>>,
137+
model_name: String,
138+
config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
139+
}
140+
141+
impl DisaggregatedRouter {
142+
pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
143+
DisaggregatedRouter {
144+
max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
145+
model_name,
146+
config_watcher: None,
147+
}
148+
}
149+
150+
pub async fn new_with_etcd_and_default(
151+
drt: Arc<DistributedRuntime>,
152+
model_name: String,
153+
default_max_local_prefill_length: i32,
154+
) -> anyhow::Result<Self> {
155+
let (mut config, watcher) =
156+
DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
157+
158+
// Use the provided default if no etcd value was found (when config is the default value)
159+
if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
160+
config.max_local_prefill_length = default_max_local_prefill_length;
161+
}
162+
163+
let router = Self {
164+
max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
165+
model_name: model_name.clone(),
166+
config_watcher: Some(watcher),
167+
};
168+
169+
// Start background task to watch for config updates
170+
router.start_config_watcher();
171+
172+
Ok(router)
173+
}
174+
175+
fn start_config_watcher(&self) {
176+
if let Some(watcher) = self.config_watcher.clone() {
177+
let mut watcher = watcher;
178+
// Create a clone for the task
179+
let model_name = self.model_name.clone();
180+
let max_local_prefill_length = self.max_local_prefill_length.clone();
181+
182+
tokio::spawn(async move {
183+
tracing::info!("Starting config update watcher for model: {}", model_name);
184+
185+
while watcher.changed().await.is_ok() {
186+
let config = watcher.borrow().clone();
187+
let new_value = config.max_local_prefill_length;
188+
189+
// Update the value using the mutex
190+
let mut current_value = max_local_prefill_length.lock().unwrap();
191+
let old_value = *current_value;
192+
if old_value != new_value {
193+
*current_value = new_value;
194+
tracing::info!(
195+
"Applied config update for model {}: max_local_prefill_length changed from {} to {}",
196+
model_name,
197+
old_value,
198+
new_value
199+
);
200+
}
201+
}
202+
203+
tracing::debug!("Config watcher closed for model: {}", model_name);
204+
});
205+
}
206+
}
207+
208+
pub fn check_for_updates(&self) {
209+
if let Some(watcher) = &self.config_watcher
210+
&& watcher.has_changed().unwrap_or(false)
211+
{
212+
let config = watcher.borrow().clone();
213+
let new_value = config.max_local_prefill_length;
214+
215+
// Update the value using the mutex
216+
let mut current_value = self.max_local_prefill_length.lock().unwrap();
217+
let old_value = *current_value;
218+
if old_value != new_value {
219+
*current_value = new_value;
220+
tracing::info!(
221+
"Applied config update for model {}: max_local_prefill_length changed from {} to {}",
222+
self.model_name,
223+
old_value,
224+
new_value
225+
);
226+
}
227+
}
228+
}
229+
230+
pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
231+
// Check for updates before making the decision
232+
self.check_for_updates();
233+
234+
// Get the current value from the mutex
235+
let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();
236+
237+
// schedule the request purely based on the prefill length
238+
// TODO: apply math models and compare local vs remote prefill TTFT
239+
prefill_length - prefix_hit_length > max_local_prefill_length
240+
}
241+
242+
pub fn update_value(&self, max_local_prefill_length: i32) {
243+
let mut current = self.max_local_prefill_length.lock().unwrap();
244+
*current = max_local_prefill_length;
245+
}
246+
247+
pub fn get_model_name(&self) -> &str {
248+
&self.model_name
249+
}
250+
}

lib/llm/src/kv_router/subscriber.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,14 @@ pub async fn start_kv_router_background(
274274
cleanup_orphaned_consumers(&mut nats_queue, &etcd_client, &component, &consumer_uuid).await;
275275

276276
// Watch for router deletions to clean up orphaned consumers
277-
let (_prefix_str, _watcher, mut router_replicas_rx) = etcd_client
277+
let (_prefix_str, mut router_replicas_rx) = etcd_client
278278
.kv_get_and_watch_prefix(&format!("{}/", KV_ROUTERS_ROOT_PATH))
279279
.await?
280280
.dissolve();
281281

282282
// Get the generate endpoint and watch for instance deletions
283283
let generate_endpoint = component.endpoint("generate");
284-
let (_instance_prefix, _instance_watcher, mut instance_event_rx) = etcd_client
284+
let (_instance_prefix, mut instance_event_rx) = etcd_client
285285
.kv_get_and_watch_prefix(generate_endpoint.etcd_root())
286286
.await?
287287
.dissolve();

lib/runtime/src/component/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ impl Client {
213213
.kv_get_and_watch_prefix(endpoint.etcd_root())
214214
.await?;
215215

216-
let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
216+
let (prefix, mut kv_event_rx) = prefix_watcher.dissolve();
217217

218218
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
219219

0 commit comments

Comments
 (0)