diff --git a/Cargo.lock b/Cargo.lock index 7520a03c..017561a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,7 @@ dependencies = [ "clap", "compact_str", "config", + "dashmap", "derive_more", "displaydoc", "dotenvy", diff --git a/ai-gateway/Cargo.toml b/ai-gateway/Cargo.toml index b04065b3..9ff76d1d 100644 --- a/ai-gateway/Cargo.toml +++ b/ai-gateway/Cargo.toml @@ -99,6 +99,7 @@ uuid = { workspace = true, features = ["serde", "v7"] } weighted-balance = { workspace = true } workspace_root = { workspace = true, optional = true } ts-rs = { workspace = true, features = ["uuid-impl"] } +dashmap = "6.1.0" [dev-dependencies] cargo-husky = { workspace = true, features = ["user-hooks"] } @@ -178,4 +179,4 @@ required-features = ["testing"] [[test]] name = "retries" -required-features = ["testing"] \ No newline at end of file +required-features = ["testing"] diff --git a/ai-gateway/src/discover/monitor/health/provider.rs b/ai-gateway/src/discover/monitor/health/provider.rs index 923a1317..b9fe2f4f 100644 --- a/ai-gateway/src/discover/monitor/health/provider.rs +++ b/ai-gateway/src/discover/monitor/health/provider.rs @@ -1,16 +1,13 @@ //! Dynamically remove inference providers that fail health checks use std::sync::Arc; +use dashmap::DashMap; use futures::future::{self, BoxFuture}; use meltdown::Token; use opentelemetry::KeyValue; use rust_decimal::prelude::ToPrimitive; -use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use tokio::{ - sync::{RwLock, mpsc::Sender}, - task::JoinSet, - time, -}; +use rustc_hash::FxHashSet as HashSet; +use tokio::{sync::mpsc::Sender, task::JoinSet, time}; use tower::discover::Change; use tracing::{Instrument, debug, error, trace}; use weighted_balance::weight::Weight; @@ -38,8 +35,7 @@ use crate::{ types::{provider::InferenceProvider, router::RouterId}, }; -pub type HealthMonitorMap = - Arc>>; +pub type HealthMonitorMap = DashMap; #[derive(Debug, Clone)] pub enum ProviderHealthMonitor { @@ -668,16 +664,24 @@ impl HealthMonitor { loop { interval.tick().await; - let mut monitors = self.app_state.0.health_monitors.write().await; + let mut check_futures = Vec::new(); - for (router_id, monitor) in monitors.iter_mut() { + + for entry in self.app_state.0.health_monitors.iter() { + let router_id = entry.key().clone(); + let span = tracing::info_span!("health_monitor", router_id = ?router_id); + let health_monitors = &self.app_state.0.health_monitors; let check_future = async move { - let result = monitor.check_monitor().await; - if let Err(e) = &result { - error!(router_id = ?router_id, error = ?e, "Provider health monitor check failed"); + if let Some(mut entry) = health_monitors.get_mut(&router_id) { + let result = entry.value_mut().check_monitor().await; + if let Err(e) = &result { + error!(router_id = ?router_id, error = ?e, "Provider health monitor check failed"); + } + result + } else { + Ok(()) } - result }.instrument(span); check_futures.push(check_future); @@ -721,7 +725,7 @@ impl AppState { router_config: Arc, tx: Sender>, ) { - self.0.health_monitors.write().await.insert( + self.0.health_monitors.insert( router_id.clone(), ProviderHealthMonitor::provider_weighted( tx, @@ -738,7 +742,7 @@ impl AppState { router_config: Arc, tx: Sender>, ) { - self.0.health_monitors.write().await.insert( + self.0.health_monitors.insert( router_id.clone(), ProviderHealthMonitor::model_weighted( tx, @@ -755,7 +759,7 @@ impl AppState { router_config: Arc, tx: Sender>, ) { - self.0.health_monitors.write().await.insert( + self.0.health_monitors.insert( router_id.clone(), ProviderHealthMonitor::provider_latency( tx, @@ -772,7 +776,7 @@ impl AppState { router_config: Arc, tx: Sender>, ) { - self.0.health_monitors.write().await.insert( + self.0.health_monitors.insert( router_id.clone(), ProviderHealthMonitor::model_latency( tx,