Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion ai-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ uuid = { workspace = true, features = ["serde", "v7"] }
weighted-balance = { workspace = true }
workspace_root = { workspace = true, optional = true }
ts-rs = { workspace = true, features = ["uuid-impl"] }
dashmap = "6.1.0"

[dev-dependencies]
cargo-husky = { workspace = true, features = ["user-hooks"] }
Expand Down Expand Up @@ -178,4 +179,4 @@ required-features = ["testing"]

[[test]]
name = "retries"
required-features = ["testing"]
required-features = ["testing"]
40 changes: 22 additions & 18 deletions ai-gateway/src/discover/monitor/health/provider.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
//! Dynamically remove inference providers that fail health checks
use std::sync::Arc;

use dashmap::DashMap;
use futures::future::{self, BoxFuture};
use meltdown::Token;
use opentelemetry::KeyValue;
use rust_decimal::prelude::ToPrimitive;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use tokio::{
sync::{RwLock, mpsc::Sender},
task::JoinSet,
time,
};
use rustc_hash::FxHashSet as HashSet;
use tokio::{sync::mpsc::Sender, task::JoinSet, time};
use tower::discover::Change;
use tracing::{Instrument, debug, error, trace};
use weighted_balance::weight::Weight;
Expand Down Expand Up @@ -38,8 +35,7 @@ use crate::{
types::{provider::InferenceProvider, router::RouterId},
};

pub type HealthMonitorMap =
Arc<RwLock<HashMap<RouterId, ProviderHealthMonitor>>>;
pub type HealthMonitorMap = DashMap<RouterId, ProviderHealthMonitor>;

#[derive(Debug, Clone)]
pub enum ProviderHealthMonitor {
Expand Down Expand Up @@ -668,16 +664,24 @@ impl HealthMonitor {

loop {
interval.tick().await;
let mut monitors = self.app_state.0.health_monitors.write().await;

let mut check_futures = Vec::new();
for (router_id, monitor) in monitors.iter_mut() {

for entry in self.app_state.0.health_monitors.iter() {
let router_id = entry.key().clone();

let span = tracing::info_span!("health_monitor", router_id = ?router_id);
let health_monitors = &self.app_state.0.health_monitors;
let check_future = async move {
let result = monitor.check_monitor().await;
if let Err(e) = &result {
error!(router_id = ?router_id, error = ?e, "Provider health monitor check failed");
if let Some(mut entry) = health_monitors.get_mut(&router_id) {
let result = entry.value_mut().check_monitor().await;
if let Err(e) = &result {
error!(router_id = ?router_id, error = ?e, "Provider health monitor check failed");
}
result
} else {
Ok(())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Health Checks Fail Due to DashMap Race Condition

A race condition in the health monitoring loop, introduced by the DashMap migration, can cause health checks to be silently skipped or misapplied. The loop first iterates over the DashMap to collect router IDs, and only later — inside a spawned async future — calls get_mut() to retrieve the monitor for each ID. Because no lock is held between those two steps, an entry may be removed or replaced in the interim: a removed entry makes the check silently return Ok(()) without running, and a replaced entry causes the check to run against a different monitor than the one observed during iteration. The previous RwLock implementation held the write guard across the entire pass, so neither interleaving was possible; the migration therefore weakens the concurrency guarantee and can leave a failed provider unmonitored.

Fix in Cursor Fix in Web

}
result
}.instrument(span);

check_futures.push(check_future);
Expand Down Expand Up @@ -721,7 +725,7 @@ impl AppState {
router_config: Arc<RouterConfig>,
tx: Sender<Change<ProviderWeightedKey, DispatcherService>>,
) {
self.0.health_monitors.write().await.insert(
self.0.health_monitors.insert(
router_id.clone(),
ProviderHealthMonitor::provider_weighted(
tx,
Expand All @@ -738,7 +742,7 @@ impl AppState {
router_config: Arc<RouterConfig>,
tx: Sender<Change<ModelWeightedKey, DispatcherService>>,
) {
self.0.health_monitors.write().await.insert(
self.0.health_monitors.insert(
router_id.clone(),
ProviderHealthMonitor::model_weighted(
tx,
Expand All @@ -755,7 +759,7 @@ impl AppState {
router_config: Arc<RouterConfig>,
tx: Sender<Change<ProviderKey, DispatcherService>>,
) {
self.0.health_monitors.write().await.insert(
self.0.health_monitors.insert(
router_id.clone(),
ProviderHealthMonitor::provider_latency(
tx,
Expand All @@ -772,7 +776,7 @@ impl AppState {
router_config: Arc<RouterConfig>,
tx: Sender<Change<ModelKey, DispatcherService>>,
) {
self.0.health_monitors.write().await.insert(
self.0.health_monitors.insert(
router_id.clone(),
ProviderHealthMonitor::model_latency(
tx,
Expand Down