Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pub async fn track_metrics(req: Request<Body>, next: Next) -> Response {
.map(|b| b.0.clone())
.unwrap_or_else(|| "none".to_string());

histogram!("rpc_request_duration_seconds").record(duration);
histogram!("rpc_request_duration_seconds", "rpc_method" => rpc_method.clone(), "backend" => backend.clone()).record(duration);
counter!("rpc_requests_total", "method" => method, "status" => status, "rpc_method" => rpc_method, "backend" => backend).increment(1);

response
Expand All @@ -146,12 +146,12 @@ pub async fn proxy(
// Valid key
}
Ok(None) => {
info!("API key '{}' is invalid", api_key);
info!("Invalid API key presented (prefix={}...)", &api_key[..api_key.len().min(6)]);
return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
}
Err(e) => {
if e == "Rate limit exceeded" {
warn!("API key '{}' rate limited", api_key);
warn!("API key rate limited (prefix={}...)", &api_key[..api_key.len().min(6)]);
return (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
} else {
error!("Key validation error: {}", e);
Expand All @@ -177,18 +177,23 @@ pub async fn proxy(
}
};

// Rebuild URI (remove ?api-key=... from request)
let request_path_and_query = req
// Rebuild URI: strip api-key from query params while preserving others
let path = req.uri().path();
let cleaned_query = req
.uri()
.path_and_query()
.map(|x| x.as_str())
.unwrap_or("/");

// Remove api-key from the incoming request's query parameters
let cleaned_request_path = if let Some(pos) = request_path_and_query.find("?api-key=") {
&request_path_and_query[..pos]
.query()
.map(|q| {
q.split('&')
.filter(|p| !p.starts_with("api-key="))
.collect::<Vec<_>>()
.join("&")
})
.unwrap_or_default();

let cleaned_request_path = if cleaned_query.is_empty() {
path.to_string()
} else {
request_path_and_query
format!("{}?{}", path, cleaned_query)
};

// Build URI with selected backend
Expand Down Expand Up @@ -331,14 +336,14 @@ pub async fn ws_proxy(
// Authorized
}
Ok(None) => {
info!("WebSocket: API key '{}' is invalid from {}", api_key, addr);
info!("WebSocket: Invalid API key from {} (prefix={}...)", addr, &api_key[..api_key.len().min(6)]);
return (StatusCode::UNAUTHORIZED, "Unauthorized").into_response();
}
Err(e) => {
if e == "Rate limit exceeded" {
warn!(
"WebSocket: API key '{}' rate limited from {}",
api_key, addr
"WebSocket: API key rate limited from {} (prefix={}...)",
addr, &api_key[..api_key.len().min(6)]
);
return (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded").into_response();
}
Expand Down Expand Up @@ -479,8 +484,14 @@ async fn handle_ws_connection(

// Run both directions concurrently, stop when either ends
tokio::select! {
_ = client_to_backend => {},
_ = backend_to_client => {},
_ = client_to_backend => {
// Client side ended; send close to backend
let _ = backend_write.send(TungsteniteMessage::Close(None)).await;
},
_ = backend_to_client => {
// Backend side ended; send close to client
let _ = client_write.send(Message::Close(None)).await;
},
}

info!(
Expand Down
53 changes: 38 additions & 15 deletions src/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{

use arc_swap::ArcSwap;
use axum::{body::Body, http::Request};
use futures_util::future;
use hyper_tls::HttpsConnector;
use hyper_util::client::legacy::{connect::HttpConnector, Client};
use metrics::gauge;
Expand Down Expand Up @@ -54,11 +55,15 @@ impl HealthState {
}

pub fn get_status(&self, label: &str) -> Option<BackendHealthStatus> {
self.statuses.read().unwrap().get(label).cloned()
self.statuses
.read()
.unwrap_or_else(|e| e.into_inner())
.get(label)
.cloned()
}

pub fn update_status(&self, label: &str, status: BackendHealthStatus) {
let mut statuses = self.statuses.write().unwrap();
let mut statuses = self.statuses.write().unwrap_or_else(|e| e.into_inner());
if let Some(s) = statuses.get_mut(label) {
*s = status;
} else {
Expand All @@ -68,7 +73,10 @@ impl HealthState {
}

pub fn get_all_statuses(&self) -> HashMap<String, BackendHealthStatus> {
self.statuses.read().unwrap().clone()
self.statuses
.read()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
}

Expand Down Expand Up @@ -127,20 +135,35 @@ pub async fn health_check_loop(
) {
loop {
// Load the current state for this iteration
// We load it once per loop iteration to ensure consistency during the check cycle
let current_state = router_state.load();

let backends = &current_state.backends;

let health_config = &current_state.health_check_config;
let health_state = &current_state.health_state;
let check_interval = Duration::from_secs(health_config.interval_secs);

for backend in backends {
let check_result = perform_health_check(&client, &backend.config, health_config).await;
// Run all health checks concurrently so one slow backend doesn't block others
let check_futures: Vec<_> = current_state
.backends
.iter()
.map(|backend| {
let client = client.clone();
let config = backend.config.clone();
let hc = health_config.clone();
async move {
let result = perform_health_check(&client, &config, &hc).await;
(config.label.clone(), result)
}
})
.collect();

let results = future::join_all(check_futures).await;

for (i, (label, check_result)) in results.into_iter().enumerate() {
let backend = &current_state.backends[i];

// Get current status from the detailed state
let mut current_status = health_state
.get_status(&backend.config.label)
.get_status(&label)
.unwrap_or_default();

let previous_healthy = current_status.healthy;
Expand All @@ -160,7 +183,7 @@ pub async fn health_check_loop(

tracing::debug!(
"Health check succeeded for backend {} (consecutive successes: {})",
backend.config.label,
label,
current_status.consecutive_successes
);
}
Expand All @@ -178,7 +201,7 @@ pub async fn health_check_loop(

tracing::warn!(
"Health check failed for backend {} (consecutive failures: {}): {}",
backend.config.label,
label,
current_status.consecutive_failures,
error
);
Expand All @@ -191,23 +214,23 @@ pub async fn health_check_loop(
if previous_healthy && !current_status.healthy {
tracing::warn!(
"Backend {} marked as UNHEALTHY after {} consecutive failures",
backend.config.label,
label,
current_status.consecutive_failures
);
} else if !previous_healthy && current_status.healthy {
tracing::info!(
"Backend {} marked as HEALTHY after {} consecutive successes",
backend.config.label,
label,
current_status.consecutive_successes
);
}

// Update metrics
gauge!("rpc_backend_health", "backend" => backend.config.label.clone())
gauge!("rpc_backend_health", "backend" => label.clone())
.set(if current_status.healthy { 1.0 } else { 0.0 });

// Update detailed state (locked)
health_state.update_status(&backend.config.label, current_status.clone());
health_state.update_status(&label, current_status.clone());

// Update atomic boolean (lock-free)
backend
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async fn main() {
let config = load_config(&args.config).expect("Failed to load router configuration");

info!("Loaded configuration from: {}", args.config);
info!("Redis URL configured: {}", config.redis_url);
info!("Redis URL configured (host redacted)");

info!("Loaded {} backends", config.backends.len());
for backend in &config.backends {
Expand Down
6 changes: 6 additions & 0 deletions src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ pub struct MockKeyStore {
pub error_keys: Arc<Mutex<HashMap<String, String>>>,
}

impl Default for MockKeyStore {
fn default() -> Self {
Self::new()
}
}

impl MockKeyStore {
pub fn new() -> Self {
Self {
Expand Down
6 changes: 3 additions & 3 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use axum::body::Body;
use hyper_tls::HttpsConnector;
use hyper_util::client::legacy::{connect::HttpConnector, Client};
use rand::Rng;
use tracing::info;
use tracing::{debug, info};

use crate::{
config::{Backend, HealthCheckConfig},
Expand Down Expand Up @@ -55,11 +55,11 @@ impl AppState {
.find(|b| b.config.label == *backend_label)
{
if backend.healthy.load(Ordering::Relaxed) {
info!("Method {} routed to label={}", method, backend_label);
debug!("Method {} routed to label={}", method, backend_label);
return Some((backend.config.label.clone(), backend.config.url.clone()));
} else {
info!(
"Method {} routed to label={} but backend is unhealthy, falling back to weighted selection",
"Method {} target label={} is unhealthy, falling back to weighted selection",
method, backend_label
);
}
Expand Down