diff --git a/CLAUDE.md b/CLAUDE.md index a5021aa..67daf75 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,13 +17,16 @@ This application showcases how to build a production-ready message streaming ser - Comprehensive error handling with `Result` types (no `unwrap()`/`expect()` in production code) - **Zero clippy warnings** - strict lints enforced, no `#[allow(...)]` in production code - **Connection resilience** with automatic reconnection and exponential backoff +- **Circuit breaker** pattern for fail-fast during outages - **Rate limiting** with token bucket algorithm (configurable RPS and burst) - **API key authentication** with constant-time comparison (timing attack resistant) - **Request ID propagation** for distributed tracing +- **Request timeout propagation** via `X-Request-Timeout` header - **Configurable CORS** with origin whitelist support - **Background stats caching** to avoid expensive queries on each request - **Structured concurrency** with proper task lifecycle management - **Background health checks** for early connection issue detection +- **Prometheus metrics** export for observability ## Architecture @@ -35,6 +38,7 @@ This application showcases how to build a production-ready message streaming ser │ Middleware Stack (src/middleware/) │ │ - rate_limit.rs: Token bucket rate limiting │ │ - auth.rs: API key authentication │ +│ - timeout.rs: Request timeout propagation │ │ - request_id.rs: Request ID propagation │ │ + tower_http: Tracing, CORS │ ├─────────────────────────────────────────────────────────────┤ @@ -50,6 +54,7 @@ This application showcases how to build a production-ready message streaming ser ├─────────────────────────────────────────────────────────────┤ │ IggyClientWrapper (src/iggy_client.rs) │ │ High-level wrapper with automatic reconnection │ +│ + Circuit breaker for fail-fast during outages │ │ + PollParams builder for cleaner polling API │ ├─────────────────────────────────────────────────────────────┤ │ Background Tasks (managed by 
TaskTracker) │ @@ -91,15 +96,23 @@ src/ ├── lib.rs # Library exports ├── config.rs # Configuration from environment ├── error.rs # Error types with HTTP status codes +├── metrics.rs # Prometheus metrics export ├── state.rs # Shared application state with stats caching ├── routes.rs # Route definitions and middleware stack -├── iggy_client.rs # Iggy SDK wrapper with auto-reconnection +├── iggy_client/ # Iggy SDK wrapper module +│ ├── mod.rs # Client wrapper with auto-reconnection +│ ├── circuit_breaker.rs # Circuit breaker pattern implementation +│ ├── connection.rs # Connection state management +│ ├── helpers.rs # Utility functions +│ ├── params.rs # PollParams builder +│ └── scopeguard.rs # Scope guard utilities ├── validation.rs # Input validation utilities ├── middleware/ │ ├── mod.rs # Middleware exports │ ├── ip.rs # Client IP extraction (shared by rate_limit and auth) │ ├── rate_limit.rs # Token bucket rate limiting (Governor) │ ├── auth.rs # API key authentication +│ ├── timeout.rs # Request timeout propagation │ └── request_id.rs # Request ID propagation ├── models/ │ ├── mod.rs # Model exports @@ -197,6 +210,13 @@ Environment variables (see `.env.example`): | `HEALTH_CHECK_INTERVAL_SECS` | `30` | Connection health check interval | | `OPERATION_TIMEOUT_SECS` | `30` | Timeout for Iggy operations | +### Circuit Breaker +| Variable | Default | Description | +|----------|---------|-------------| +| `CIRCUIT_BREAKER_FAILURE_THRESHOLD` | `5` | Failures before opening circuit | +| `CIRCUIT_BREAKER_SUCCESS_THRESHOLD` | `2` | Successes in half-open to close | +| `CIRCUIT_BREAKER_OPEN_DURATION_SECS` | `30` | How long circuit stays open | + ### Rate Limiting | Variable | Default | Description | |----------|---------|-------------| @@ -246,6 +266,7 @@ When empty (default), all X-Forwarded-For headers are trusted. 
**This is not rec | Variable | Default | Description | |----------|---------|-------------| | `STATS_CACHE_TTL_SECS` | `5` | Stats cache refresh interval | +| `METRICS_PORT` | `9090` | Prometheus metrics port (0 = disabled) | #### Log Levels @@ -394,7 +415,7 @@ environment: ### Running Tests ```bash -# Unit tests (93 tests) +# Unit tests (130 tests) cargo test --lib # Integration tests (24 tests, requires Docker for testcontainers) @@ -481,6 +502,7 @@ Error types and HTTP status codes: - `connection_failed` (503): Initial connection to Iggy server failed - `disconnected` (503): Lost connection during operation - `connection_reset` (503): Connection was reset by peer +- `circuit_open` (503): Circuit breaker is open, failing fast - `stream_error` (500): Stream operation failed - `topic_error` (500): Topic operation failed - `send_error` (500): Message send failed @@ -533,7 +555,7 @@ Iggy uses **0-indexed partitions**: ## Dependencies Key dependencies (see `Cargo.toml`): -- `iggy 0.8.0-edge.6`: Iggy Rust SDK (edge version for latest server features) +- `iggy 0.8.0`: Iggy Rust SDK - `axum 0.8`: Web framework - `tokio 1.48`: Async runtime - `tokio-util 0.7`: Task tracking and cancellation tokens @@ -543,8 +565,10 @@ Key dependencies (see `Cargo.toml`): - `governor 0.8`: Rate limiting with token bucket algorithm - `subtle 2.6`: Constant-time comparison for security - `tower-http 0.6`: HTTP middleware (CORS, tracing, request ID) -- `rust_decimal 1.37`: Exact decimal arithmetic for monetary values -- `testcontainers 0.24`: Integration testing with containerized Iggy +- `rust_decimal 1.39`: Exact decimal arithmetic for monetary values +- `metrics 0.24`: Application metrics +- `metrics-exporter-prometheus 0.16`: Prometheus metrics export +- `testcontainers 0.26`: Integration testing with containerized Iggy ## Structured Concurrency @@ -625,7 +649,7 @@ let messages = client.poll_with_params("stream", "topic", params).await?; Request flow (applied in order): ``` -Request → 
Rate Limit → Auth → Request ID → Tracing → CORS → Handler +Request → Rate Limit → Auth → Timeout → Request ID → Tracing → CORS → Handler ``` ### Client IP Extraction (`src/middleware/ip.rs`) @@ -648,6 +672,12 @@ Request → Rate Limit → Auth → Request ID → Tracing → CORS → Handler - Accepts key via `X-API-Key` header or `api_key` query parameter - Bypasses `/health` and `/ready` for health checks (exact path matching) +### Request Timeout (`src/middleware/timeout.rs`) +- Clients can specify `X-Request-Timeout: ` header +- Bounded: 100ms minimum, 5 minutes maximum +- Stored in request extensions for handler use +- `RequestTimeoutExt` trait for easy extraction in handlers + ## Deployment Security ### Reverse Proxy Configuration (Required) diff --git a/Cargo.lock b/Cargo.lock index 325f311..0d75b0b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1419,6 +1419,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "foldhash" version = "0.2.0" @@ -1671,6 +1677,15 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash 0.1.5", +] + [[package]] name = "hashbrown" version = "0.16.1" @@ -1679,7 +1694,7 @@ checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.2.0", ] [[package]] @@ -2131,6 +2146,8 @@ dependencies = [ "exitcode", 
"governor", "iggy", + "metrics", + "metrics-exporter-prometheus", "rand 0.9.2", "reqwest", "rust_decimal", @@ -2398,6 +2415,52 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "metrics" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5312e9ba3771cfa961b585728215e3d972c950a3eed9252aa093d6301277e8" +dependencies = [ + "ahash 0.8.12", + "portable-atomic", +] + +[[package]] +name = "metrics-exporter-prometheus" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd7399781913e5393588a8d8c6a2867bf85fb38eaf2502fdce465aad2dc6f034" +dependencies = [ + "base64 0.22.1", + "http-body-util", + "hyper", + "hyper-util", + "indexmap 2.12.1", + "ipnet", + "metrics", + "metrics-util", + "quanta", + "thiserror 1.0.69", + "tokio", + "tracing", +] + +[[package]] +name = "metrics-util" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8496cc523d1f94c1385dd8f0f0c2c480b2b8aeccb5b7e4485ad6365523ae376" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", + "hashbrown 0.15.5", + "metrics", + "quanta", + "rand 0.9.2", + "rand_xoshiro", + "sketches-ddsketch", +] + [[package]] name = "mime" version = "0.3.17" @@ -3081,6 +3144,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_xoshiro" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f703f4665700daf5512dcca5f43afa6af89f09db47fb56be587f80636bda2d41" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "raw-cpuid" version = "11.6.0" @@ -3730,6 +3802,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" +[[package]] +name = "sketches-ddsketch" +version = "0.3.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1e9a774a6c28142ac54bb25d25562e6bcf957493a184f15ad4eebccb23e410a" + [[package]] name = "slab" version = "0.4.11" diff --git a/Cargo.toml b/Cargo.toml index bce4361..42a303c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,10 @@ subtle = "2.6" # Exit codes (BSD sysexits compatible) exitcode = "1.1" +# Metrics for Prometheus +metrics = "0.24" +metrics-exporter-prometheus = { version = "0.16", default-features = false, features = ["http-listener"] } + [dev-dependencies] reqwest = { version = "0.12", features = ["json"] } testcontainers = "0.26" diff --git a/src/config.rs b/src/config.rs index 1bcb3a1..41c51e7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -77,6 +77,18 @@ pub struct Config { /// Prevents operations from hanging indefinitely on network issues pub operation_timeout: Duration, + // ========================================================================= + // Circuit Breaker Configuration + // ========================================================================= + /// Number of consecutive failures before opening the circuit (default: 5) + pub circuit_breaker_failure_threshold: u32, + + /// Number of consecutive successes in half-open state to close circuit (default: 2) + pub circuit_breaker_success_threshold: u32, + + /// How long the circuit stays open before transitioning to half-open (default: 30s) + pub circuit_breaker_open_duration: Duration, + // ========================================================================= // Rate Limiting Configuration // ========================================================================= @@ -139,6 +151,9 @@ pub struct Config { /// Interval for background stats cache refresh (default: 5 seconds) pub stats_cache_ttl: Duration, + + /// Port for Prometheus metrics endpoint (default: 9090, 0 = disabled) + pub metrics_port: u16, } impl Config { @@ -180,6 +195,20 @@ impl Config { )?), operation_timeout: 
Duration::from_secs(Self::parse_env("OPERATION_TIMEOUT_SECS", 30)?), + // Circuit breaker + circuit_breaker_failure_threshold: Self::parse_env( + "CIRCUIT_BREAKER_FAILURE_THRESHOLD", + 5, + )?, + circuit_breaker_success_threshold: Self::parse_env( + "CIRCUIT_BREAKER_SUCCESS_THRESHOLD", + 2, + )?, + circuit_breaker_open_duration: Duration::from_secs(Self::parse_env( + "CIRCUIT_BREAKER_OPEN_DURATION_SECS", + 30, + )?), + // Rate limiting rate_limit_rps: Self::parse_env("RATE_LIMIT_RPS", 100)?, rate_limit_burst: Self::parse_env("RATE_LIMIT_BURST", 50)?, @@ -198,6 +227,7 @@ impl Config { // Observability log_level: env::var("RUST_LOG").unwrap_or_else(|_| "info".to_string()), stats_cache_ttl: Duration::from_secs(Self::parse_env("STATS_CACHE_TTL_SECS", 5)?), + metrics_port: Self::parse_env("METRICS_PORT", 9090)?, }; // Validate configuration before returning @@ -266,6 +296,25 @@ impl Config { !self.trusted_proxies.is_empty() } + /// Check if Prometheus metrics export is enabled. + pub fn metrics_enabled(&self) -> bool { + self.metrics_port > 0 + } + + /// Get the metrics endpoint address. + /// + /// Returns `None` if metrics are disabled (port = 0). + pub fn metrics_addr(&self) -> Option<std::net::SocketAddr> { + if self.metrics_enabled() { + Some(std::net::SocketAddr::from(( + [0, 0, 0, 0], + self.metrics_port, + ))) + } else { + None + } + } + /// Parse an environment variable into the specified type with a default value. 
fn parse_env<T>(name: &str, default: T) -> AppResult<T> where @@ -344,6 +393,10 @@ impl Default for Config { reconnect_max_delay: Duration::from_secs(30), health_check_interval: Duration::from_secs(30), operation_timeout: Duration::from_secs(30), + // Circuit breaker + circuit_breaker_failure_threshold: 5, + circuit_breaker_success_threshold: 2, + circuit_breaker_open_duration: Duration::from_secs(30), // Rate limiting rate_limit_rps: 100, rate_limit_burst: 50, @@ -359,6 +412,7 @@ impl Default for Config { // Observability log_level: "info".to_string(), stats_cache_ttl: Duration::from_secs(5), + metrics_port: 9090, } } } diff --git a/src/error.rs b/src/error.rs index a858674..69df216 100644 --- a/src/error.rs +++ b/src/error.rs @@ -53,6 +53,9 @@ pub enum AppError { #[error("Operation timed out: {0}")] OperationTimeout(String), + + #[error("Circuit breaker open: {0}")] + CircuitOpen(String), } /// Error response body for API endpoints. @@ -128,6 +131,13 @@ impl IntoResponse for AppError { "Operation timed out. Please try again.", ), + // Circuit breaker open - service is protecting itself from cascading failures + AppError::CircuitOpen(_) => ( + StatusCode::SERVICE_UNAVAILABLE, + "circuit_open", + "Service is temporarily unavailable due to recent failures. Please retry later.", + ), + // Client errors - safe to show the message as it's user-facing AppError::SerializationError(e) => { // Serde errors can be helpful for clients debugging their payload diff --git a/src/iggy_client/circuit_breaker.rs b/src/iggy_client/circuit_breaker.rs new file mode 100644 index 0000000..1cf681d --- /dev/null +++ b/src/iggy_client/circuit_breaker.rs @@ -0,0 +1,477 @@ +//! Circuit breaker pattern for connection resilience. +//! +//! The circuit breaker prevents request pile-up during outages by failing fast +//! when the system is known to be unavailable. This reduces load on the failing +//! service and improves recovery time. +//! +//! # States +//! +//! ```text +//! 
┌────────────────────────────────────────────────────────────────────┐ +//! │ Circuit Breaker │ +//! │ │ +//! │ ┌─────────┐ failures ≥ threshold ┌─────────┐ │ +//! │ │ Closed │ ────────────────────────► │ Open │ │ +//! │ │ (Normal)│ │ (Fail │ │ +//! │ └────┬────┘ │ Fast) │ │ +//! │ │ ▲ └────┬────┘ │ +//! │ │ │ │ │ +//! │ │ │ success │ timeout expires │ +//! │ │ │ ▼ │ +//! │ │ │ ┌───────────────┐ │ +//! │ │ └─────────────────────────── │ HalfOpen │ │ +//! │ │ success │ (Probe with │ │ +//! │ │ │ one request) │ │ +//! │ │ └───────┬───────┘ │ +//! │ │ │ │ +//! │ │ │ failure │ +//! │ │ ▼ │ +//! │ │ ┌─────────┐ │ +//! │ └───────────────────────────── │ Open │ ◄────────────────┘ +//! │ reset └─────────┘ │ +//! └────────────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Configuration +//! +//! - `failure_threshold`: Number of consecutive failures before opening +//! - `success_threshold`: Number of consecutive successes in half-open to close +//! - `open_duration`: How long to stay open before trying half-open +//! +//! # Usage +//! +//! ```rust,ignore +//! let cb = CircuitBreaker::new(CircuitBreakerConfig::default()); +//! +//! // Check if request should be allowed (the check is async) +//! if !cb.allow_request().await { +//! return Err(AppError::CircuitOpen("circuit breaker is open".to_string())); +//! } +//! +//! // Execute the operation +//! match operation().await { +//! Ok(result) => { +//! cb.record_success().await; +//! Ok(result) +//! } +//! Err(e) => { +//! cb.record_failure().await; +//! Err(e) +//! } +//! } +//! ``` + +use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use std::time::{Duration, Instant}; + +use tokio::sync::RwLock; +use tracing::{debug, info, warn}; + +/// Circuit breaker state. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CircuitState { + /// Normal operation - all requests pass through. + Closed, + /// Failing fast - all requests are rejected immediately. + Open, + /// Testing recovery - allowing limited requests through. 
+ HalfOpen, +} + +impl std::fmt::Display for CircuitState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + CircuitState::Closed => write!(f, "closed"), + CircuitState::Open => write!(f, "open"), + CircuitState::HalfOpen => write!(f, "half-open"), + } + } +} + +/// Configuration for the circuit breaker. +#[derive(Debug, Clone)] +pub struct CircuitBreakerConfig { + /// Number of consecutive failures before opening the circuit. + pub failure_threshold: u32, + /// Number of consecutive successes in half-open state to close the circuit. + pub success_threshold: u32, + /// How long to stay in open state before transitioning to half-open. + pub open_duration: Duration, +} + +impl Default for CircuitBreakerConfig { + fn default() -> Self { + Self { + failure_threshold: 5, + success_threshold: 2, + open_duration: Duration::from_secs(30), + } + } +} + +impl CircuitBreakerConfig { + /// Create a new circuit breaker configuration. + pub fn new(failure_threshold: u32, success_threshold: u32, open_duration: Duration) -> Self { + Self { + failure_threshold, + success_threshold, + open_duration, + } + } +} + +/// Internal state for the circuit breaker. +struct CircuitBreakerState { + /// Current circuit state. + state: CircuitState, + /// When the circuit was opened (for timeout calculation). + opened_at: Option<Instant>, + /// Number of consecutive failures (in closed state). + consecutive_failures: u32, + /// Number of consecutive successes (in half-open state). + consecutive_successes: u32, +} + +impl CircuitBreakerState { + fn new() -> Self { + Self { + state: CircuitState::Closed, + opened_at: None, + consecutive_failures: 0, + consecutive_successes: 0, + } + } +} + +/// Thread-safe circuit breaker implementation. +/// +/// Prevents cascading failures by failing fast when a service is unavailable. +/// Uses RwLock internally for thread-safe state management. +pub struct CircuitBreaker { + /// Configuration parameters. 
+ config: CircuitBreakerConfig, + /// Internal state protected by RwLock. + state: RwLock<CircuitBreakerState>, + /// Total number of times the circuit has been opened (for metrics). + times_opened: AtomicU32, + /// Total number of requests rejected due to open circuit (for metrics). + requests_rejected: AtomicU64, +} + +impl CircuitBreaker { + /// Create a new circuit breaker with the given configuration. + pub fn new(config: CircuitBreakerConfig) -> Self { + Self { + config, + state: RwLock::new(CircuitBreakerState::new()), + times_opened: AtomicU32::new(0), + requests_rejected: AtomicU64::new(0), + } + } + + /// Check if a request should be allowed through the circuit breaker. + /// + /// Returns `true` if the request can proceed, `false` if it should be rejected. + /// + /// # State Transitions + /// + /// - **Closed**: Always allows requests + /// - **Open**: Rejects requests; transitions to HalfOpen after timeout + /// - **HalfOpen**: Allows requests (for probing) + pub async fn allow_request(&self) -> bool { + // First, check with a read lock for the common case + { + let state = self.state.read().await; + match state.state { + CircuitState::Closed => return true, + CircuitState::HalfOpen => return true, + CircuitState::Open => { + // Check if timeout has expired + if let Some(opened_at) = state.opened_at + && opened_at.elapsed() < self.config.open_duration + { + self.requests_rejected.fetch_add(1, Ordering::Relaxed); + return false; + } + // Timeout expired - need to transition to half-open + } + } + } + + // Need write lock to transition from Open to HalfOpen + let mut state = self.state.write().await; + + // Re-check state in case another task already transitioned + if state.state == CircuitState::Open { + if let Some(opened_at) = state.opened_at + && opened_at.elapsed() >= self.config.open_duration + { + state.state = CircuitState::HalfOpen; + state.consecutive_successes = 0; + info!("Circuit breaker transitioning from Open to HalfOpen"); + return true; + } + 
self.requests_rejected.fetch_add(1, Ordering::Relaxed); + return false; + } + + true + } + + /// Record a successful operation. + /// + /// In HalfOpen state, consecutive successes can close the circuit. + pub async fn record_success(&self) { + let mut state = self.state.write().await; + + match state.state { + CircuitState::Closed => { + // Reset failure counter on success + state.consecutive_failures = 0; + } + CircuitState::HalfOpen => { + state.consecutive_successes += 1; + debug!( + consecutive_successes = state.consecutive_successes, + threshold = self.config.success_threshold, + "Circuit breaker recorded success in HalfOpen state" + ); + + if state.consecutive_successes >= self.config.success_threshold { + state.state = CircuitState::Closed; + state.opened_at = None; + state.consecutive_failures = 0; + info!("Circuit breaker closed after successful recovery"); + } + } + CircuitState::Open => { + // Shouldn't happen - requests are rejected in Open state + warn!("Unexpected success recorded in Open state"); + } + } + } + + /// Record a failed operation. + /// + /// In Closed state, consecutive failures can open the circuit. + /// In HalfOpen state, any failure reopens the circuit. 
+ pub async fn record_failure(&self) { + let mut state = self.state.write().await; + + match state.state { + CircuitState::Closed => { + state.consecutive_failures += 1; + debug!( + consecutive_failures = state.consecutive_failures, + threshold = self.config.failure_threshold, + "Circuit breaker recorded failure" + ); + + if state.consecutive_failures >= self.config.failure_threshold { + state.state = CircuitState::Open; + state.opened_at = Some(Instant::now()); + self.times_opened.fetch_add(1, Ordering::Relaxed); + warn!( + failures = state.consecutive_failures, + open_duration = ?self.config.open_duration, + "Circuit breaker opened due to consecutive failures" + ); + } + } + CircuitState::HalfOpen => { + // Any failure in half-open state reopens the circuit + state.state = CircuitState::Open; + state.opened_at = Some(Instant::now()); + state.consecutive_successes = 0; + self.times_opened.fetch_add(1, Ordering::Relaxed); + warn!("Circuit breaker reopened after failure in HalfOpen state"); + } + CircuitState::Open => { + // Already open, refresh the timer + state.opened_at = Some(Instant::now()); + } + } + } + + /// Get the current circuit state. + pub async fn state(&self) -> CircuitState { + self.state.read().await.state + } + + /// Get the number of times the circuit has been opened. + pub fn times_opened(&self) -> u32 { + self.times_opened.load(Ordering::Relaxed) + } + + /// Get the number of requests rejected due to open circuit. + pub fn requests_rejected(&self) -> u64 { + self.requests_rejected.load(Ordering::Relaxed) + } + + /// Force the circuit to close (for testing or manual recovery). + pub async fn force_close(&self) { + let mut state = self.state.write().await; + state.state = CircuitState::Closed; + state.opened_at = None; + state.consecutive_failures = 0; + state.consecutive_successes = 0; + info!("Circuit breaker forcibly closed"); + } + + /// Force the circuit to open (for testing or manual intervention). 
+ pub async fn force_open(&self) { + let mut state = self.state.write().await; + state.state = CircuitState::Open; + state.opened_at = Some(Instant::now()); + self.times_opened.fetch_add(1, Ordering::Relaxed); + warn!("Circuit breaker forcibly opened"); + } +} + +impl Default for CircuitBreaker { + fn default() -> Self { + Self::new(CircuitBreakerConfig::default()) + } +} + +#[cfg(test)] +#[allow(clippy::unwrap_used, clippy::expect_used)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_circuit_breaker_starts_closed() { + let cb = CircuitBreaker::default(); + assert_eq!(cb.state().await, CircuitState::Closed); + assert!(cb.allow_request().await); + } + + #[tokio::test] + async fn test_circuit_opens_after_threshold_failures() { + let config = CircuitBreakerConfig::new(3, 2, Duration::from_secs(30)); + let cb = CircuitBreaker::new(config); + + // Record failures below threshold + cb.record_failure().await; + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Closed); + + // One more failure should open the circuit + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + assert_eq!(cb.times_opened(), 1); + } + + #[tokio::test] + async fn test_circuit_rejects_when_open() { + let config = CircuitBreakerConfig::new(1, 1, Duration::from_secs(30)); + let cb = CircuitBreaker::new(config); + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + + // Requests should be rejected + assert!(!cb.allow_request().await); + assert_eq!(cb.requests_rejected(), 1); + } + + #[tokio::test] + async fn test_circuit_transitions_to_half_open() { + let config = CircuitBreakerConfig::new(1, 1, Duration::from_millis(10)); + let cb = CircuitBreaker::new(config); + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + + // Wait for timeout + tokio::time::sleep(Duration::from_millis(20)).await; + + // Should allow request and transition to half-open + 
assert!(cb.allow_request().await); + assert_eq!(cb.state().await, CircuitState::HalfOpen); + } + + #[tokio::test] + async fn test_circuit_closes_after_success_in_half_open() { + let config = CircuitBreakerConfig::new(1, 2, Duration::from_millis(10)); + let cb = CircuitBreaker::new(config); + + // Open the circuit + cb.record_failure().await; + tokio::time::sleep(Duration::from_millis(20)).await; + + // Transition to half-open + assert!(cb.allow_request().await); + assert_eq!(cb.state().await, CircuitState::HalfOpen); + + // Record successes + cb.record_success().await; + assert_eq!(cb.state().await, CircuitState::HalfOpen); + + cb.record_success().await; + assert_eq!(cb.state().await, CircuitState::Closed); + } + + #[tokio::test] + async fn test_circuit_reopens_on_failure_in_half_open() { + let config = CircuitBreakerConfig::new(1, 2, Duration::from_millis(10)); + let cb = CircuitBreaker::new(config); + + // Open the circuit + cb.record_failure().await; + tokio::time::sleep(Duration::from_millis(20)).await; + + // Transition to half-open + assert!(cb.allow_request().await); + assert_eq!(cb.state().await, CircuitState::HalfOpen); + + // Failure should reopen + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + assert_eq!(cb.times_opened(), 2); + } + + #[tokio::test] + async fn test_success_resets_failure_counter() { + let config = CircuitBreakerConfig::new(3, 1, Duration::from_secs(30)); + let cb = CircuitBreaker::new(config); + + cb.record_failure().await; + cb.record_failure().await; + // Success should reset the counter + cb.record_success().await; + + // Now we need 3 more failures to open + cb.record_failure().await; + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Closed); + + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + } + + #[tokio::test] + async fn test_force_close() { + let cb = CircuitBreaker::default(); + cb.record_failure().await; + cb.record_failure().await; + 
cb.record_failure().await; + cb.record_failure().await; + cb.record_failure().await; + assert_eq!(cb.state().await, CircuitState::Open); + + cb.force_close().await; + assert_eq!(cb.state().await, CircuitState::Closed); + assert!(cb.allow_request().await); + } + + #[tokio::test] + async fn test_force_open() { + let cb = CircuitBreaker::default(); + assert_eq!(cb.state().await, CircuitState::Closed); + + cb.force_open().await; + assert_eq!(cb.state().await, CircuitState::Open); + assert!(!cb.allow_request().await); + } +} diff --git a/src/iggy_client/mod.rs b/src/iggy_client/mod.rs index ed56847..c2804fc 100644 --- a/src/iggy_client/mod.rs +++ b/src/iggy_client/mod.rs @@ -48,6 +48,7 @@ //! client.send_event_default(&event, None).await?; //! ``` +mod circuit_breaker; mod connection; mod helpers; mod params; @@ -67,6 +68,7 @@ use crate::error::{AppError, AppResult}; use crate::models::Event; // Re-exports for public API +pub use circuit_breaker::{CircuitBreaker, CircuitBreakerConfig, CircuitState}; pub use connection::ConnectionState; pub use helpers::{rand_jitter, to_identifier}; pub use params::PollParams; @@ -93,8 +95,16 @@ const MIN_RECONNECT_DELAY_MS: u64 = 100; /// Production-ready wrapper around the Iggy client. /// -/// Provides automatic reconnection, health monitoring, and consistent error handling -/// for all Iggy operations. Thread-safe and designed for concurrent access. +/// Provides automatic reconnection, health monitoring, circuit breaker protection, +/// and consistent error handling for all Iggy operations. Thread-safe and designed +/// for concurrent access. 
+/// +/// # Circuit Breaker +/// +/// The client includes a circuit breaker that prevents request pile-up during outages: +/// - **Closed** (normal): All requests pass through +/// - **Open** (failing): Requests fail fast without attempting the operation +/// - **Half-Open** (recovery): Limited requests allowed to test if service recovered /// /// # Performance Considerations /// @@ -127,6 +137,8 @@ pub struct IggyClientWrapper { config: Config, /// Connection state tracking state: Arc<ConnectionState>, + /// Circuit breaker for fail-fast during outages + circuit_breaker: Arc<CircuitBreaker>, } impl IggyClientWrapper { @@ -148,10 +160,18 @@ impl IggyClientWrapper { let client = IggyClient::from_connection_string(&config.iggy_connection_string) .map_err(|e| AppError::ConnectionFailed(e.to_string()))?; + // Initialize circuit breaker from config + let circuit_breaker_config = CircuitBreakerConfig::new( + config.circuit_breaker_failure_threshold, + config.circuit_breaker_success_threshold, + config.circuit_breaker_open_duration, + ); + let wrapper = Self { client: Arc::new(RwLock::new(client)), config, state: Arc::new(ConnectionState::new()), + circuit_breaker: Arc::new(CircuitBreaker::new(circuit_breaker_config)), }; wrapper.connect().await?; @@ -292,9 +312,20 @@ impl IggyClientWrapper { /// Execute an operation with automatic reconnection on connection failure. /// /// This is the core resilience mechanism. 
Features: + /// - **Circuit Breaker**: Fail fast when service is known to be unavailable /// - **Timeout**: All operations are bounded by `config.operation_timeout` /// - **Retry**: On connection failure, attempts reconnect and retries once /// + /// # Circuit Breaker Integration + /// + /// Before attempting the operation, the circuit breaker is checked: + /// - If **Open**: Returns `CircuitOpen` error immediately (fail fast) + /// - If **Closed** or **HalfOpen**: Proceeds with the operation + /// + /// After the operation, the circuit breaker is updated: + /// - Success: Records success (may close circuit if in HalfOpen) + /// - Failure: Records failure (may open circuit if threshold exceeded) + /// /// # Timeout vs Disconnection /// /// The function differentiates between: @@ -308,33 +339,62 @@ impl IggyClientWrapper { F: Fn() -> Fut, Fut: std::future::Future<Output = AppResult<T>>, { + // Check circuit breaker before attempting operation + if !self.circuit_breaker.allow_request().await { + let state = self.circuit_breaker.state().await; + return Err(AppError::CircuitOpen(format!( + "Circuit breaker is {} - service temporarily unavailable", + state + ))); + } + let timeout_duration = self.config.operation_timeout; // First attempt with timeout let result = tokio::time::timeout(timeout_duration, operation()).await; match result { - Ok(Ok(value)) => Ok(value), + Ok(Ok(value)) => { + self.circuit_breaker.record_success().await; + Ok(value) + } Ok(Err(e)) if Self::is_connection_error(&e) => { + self.circuit_breaker.record_failure().await; warn!(error = %e, "Operation failed due to connection error, attempting reconnect"); self.reconnect().await?; // Retry with timeout - tokio::time::timeout(timeout_duration, operation()) - .await - .map_err(|_| { - AppError::OperationTimeout(format!( + let retry_result = tokio::time::timeout(timeout_duration, operation()).await; + match retry_result { + Ok(Ok(value)) => { + self.circuit_breaker.record_success().await; + Ok(value) + } + Ok(Err(e)) => { + 
if Self::is_connection_error(&e) { + self.circuit_breaker.record_failure().await; + } + Err(e) + } + Err(_) => { + self.circuit_breaker.record_failure().await; + Err(AppError::OperationTimeout(format!( "Operation timed out after {:?} on retry", timeout_duration - )) - })? + ))) + } + } + } + Ok(Err(e)) => { + // Non-connection error - don't record as circuit breaker failure + Err(e) } - Ok(Err(e)) => Err(e), Err(_) => { // Timeout on first attempt // Only reconnect if we have evidence the connection is actually lost. // A timeout alone doesn't mean disconnection - could just be slow. if !self.state.is_connected() { + self.circuit_breaker.record_failure().await; warn!( timeout = ?timeout_duration, "Operation timed out and connection state is disconnected, attempting reconnect" @@ -342,16 +402,29 @@ impl IggyClientWrapper { self.reconnect().await?; // Retry with timeout - tokio::time::timeout(timeout_duration, operation()) - .await - .map_err(|_| { - AppError::OperationTimeout(format!( + let retry_result = tokio::time::timeout(timeout_duration, operation()).await; + match retry_result { + Ok(Ok(value)) => { + self.circuit_breaker.record_success().await; + Ok(value) + } + Ok(Err(e)) => { + if Self::is_connection_error(&e) { + self.circuit_breaker.record_failure().await; + } + Err(e) + } + Err(_) => { + self.circuit_breaker.record_failure().await; + Err(AppError::OperationTimeout(format!( "Operation timed out after {:?} on retry", timeout_duration - )) - })? + ))) + } + } } else { // Connection appears healthy - this is just a slow operation + // Don't record as circuit breaker failure (not a connection issue) debug!( timeout = ?timeout_duration, "Operation timed out but connection state is healthy, not reconnecting" @@ -864,6 +937,26 @@ impl IggyClientWrapper { pub fn config(&self) -> &Config { &self.config } + + /// Get the current circuit breaker state. 
+ pub async fn circuit_breaker_state(&self) -> CircuitState { + self.circuit_breaker.state().await + } + + /// Get circuit breaker metrics. + /// + /// Returns a tuple of (times_opened, requests_rejected). + pub fn circuit_breaker_metrics(&self) -> (u32, u64) { + ( + self.circuit_breaker.times_opened(), + self.circuit_breaker.requests_rejected(), + ) + } + + /// Force close the circuit breaker (for manual recovery). + pub async fn force_close_circuit(&self) { + self.circuit_breaker.force_close().await; + } } #[cfg(test)] @@ -901,6 +994,7 @@ mod tests { AppError::PollError("poll failed".to_string()), AppError::ConfigError("config issue".to_string()), AppError::OperationTimeout("timed out".to_string()), + AppError::CircuitOpen("circuit open".to_string()), ]; for error in test_cases { diff --git a/src/lib.rs b/src/lib.rs index 6d30428..11ac36e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,6 +61,7 @@ pub mod config; pub mod error; pub mod handlers; pub mod iggy_client; +pub mod metrics; pub mod middleware; pub mod models; pub mod routes; diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..57d2c74 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,242 @@ +//! Prometheus metrics for application observability. +//! +//! This module provides Prometheus-compatible metrics for monitoring the application. +//! Metrics are exposed via a dedicated HTTP endpoint (default: `/metrics`). +//! +//! # Available Metrics +//! +//! ## Counters +//! - `iggy_messages_sent_total` - Total messages sent (with labels: stream, topic, status) +//! - `iggy_messages_polled_total` - Total messages polled (with labels: stream, topic) +//! - `iggy_connection_reconnects_total` - Total reconnection attempts +//! - `iggy_circuit_breaker_opens_total` - Times the circuit breaker opened +//! - `iggy_circuit_breaker_rejections_total` - Requests rejected by circuit breaker +//! +//! ## Histograms +//! 
- `iggy_request_duration_seconds` - Request duration (with labels: endpoint, method, status) +//! - `iggy_send_duration_seconds` - Message send duration +//! - `iggy_poll_duration_seconds` - Message poll duration +//! +//! ## Gauges +//! - `iggy_connection_status` - Current connection status (1 = connected, 0 = disconnected) +//! - `iggy_circuit_breaker_state` - Circuit breaker state (0 = closed, 1 = half-open, 2 = open) +//! +//! # Usage +//! +//! ```rust,ignore +//! use iggy_sample::metrics::{init_metrics, record_message_sent, record_request_duration}; +//! +//! // Initialize metrics (call once at startup) +//! init_metrics(metrics_addr)?; // metrics_addr: std::net::SocketAddr +//! +//! // Record metrics in handlers +//! record_message_sent("my-stream", "my-topic", "success"); +//! record_request_duration("/messages", "POST", "200", 0.045); +//! ``` + +use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; +use metrics_exporter_prometheus::PrometheusBuilder; +use std::net::SocketAddr; +use tracing::{error, info}; + +/// Metric names as constants for consistency. +pub mod names { + pub const MESSAGES_SENT_TOTAL: &str = "iggy_messages_sent_total"; + pub const MESSAGES_POLLED_TOTAL: &str = "iggy_messages_polled_total"; + pub const CONNECTION_RECONNECTS_TOTAL: &str = "iggy_connection_reconnects_total"; + pub const CIRCUIT_BREAKER_OPENS_TOTAL: &str = "iggy_circuit_breaker_opens_total"; + pub const CIRCUIT_BREAKER_REJECTIONS_TOTAL: &str = "iggy_circuit_breaker_rejections_total"; + pub const REQUEST_DURATION_SECONDS: &str = "iggy_request_duration_seconds"; + pub const SEND_DURATION_SECONDS: &str = "iggy_send_duration_seconds"; + pub const POLL_DURATION_SECONDS: &str = "iggy_poll_duration_seconds"; + pub const CONNECTION_STATUS: &str = "iggy_connection_status"; + pub const CIRCUIT_BREAKER_STATE: &str = "iggy_circuit_breaker_state"; +} + +/// Initialize the Prometheus metrics exporter. 
+/// +/// This sets up metric descriptions and starts the Prometheus HTTP listener +/// on the specified address (default: 0.0.0.0:9090). +/// +/// # Arguments +/// +/// * `metrics_addr` - Address for the Prometheus metrics endpoint +/// +/// # Returns +/// +/// `Ok(())` if initialization succeeds, `Err` with message otherwise. +pub fn init_metrics(metrics_addr: SocketAddr) -> Result<(), String> { + // Set up Prometheus exporter + PrometheusBuilder::new() + .with_http_listener(metrics_addr) + .install() + .map_err(|e| format!("Failed to install Prometheus exporter: {e}"))?; + + // Describe all metrics + describe_counter!( + names::MESSAGES_SENT_TOTAL, + "Total number of messages sent to Iggy" + ); + describe_counter!( + names::MESSAGES_POLLED_TOTAL, + "Total number of messages polled from Iggy" + ); + describe_counter!( + names::CONNECTION_RECONNECTS_TOTAL, + "Total number of connection reconnection attempts" + ); + describe_counter!( + names::CIRCUIT_BREAKER_OPENS_TOTAL, + "Total number of times the circuit breaker opened" + ); + describe_counter!( + names::CIRCUIT_BREAKER_REJECTIONS_TOTAL, + "Total number of requests rejected by circuit breaker" + ); + + describe_histogram!( + names::REQUEST_DURATION_SECONDS, + "HTTP request duration in seconds" + ); + describe_histogram!( + names::SEND_DURATION_SECONDS, + "Message send operation duration in seconds" + ); + describe_histogram!( + names::POLL_DURATION_SECONDS, + "Message poll operation duration in seconds" + ); + + describe_gauge!( + names::CONNECTION_STATUS, + "Iggy connection status (1 = connected, 0 = disconnected)" + ); + describe_gauge!( + names::CIRCUIT_BREAKER_STATE, + "Circuit breaker state (0 = closed, 1 = half-open, 2 = open)" + ); + + info!(addr = %metrics_addr, "Prometheus metrics endpoint started"); + Ok(()) +} + +/// Try to initialize metrics, logging any errors but not failing. +/// +/// This is useful for cases where metrics are optional. 
+pub fn try_init_metrics(metrics_addr: SocketAddr) { + if let Err(e) = init_metrics(metrics_addr) { + error!(error = %e, "Failed to initialize metrics, continuing without metrics"); + } +} + +// ============================================================================= +// Counter Recording Functions +// ============================================================================= + +/// Record a message sent event. +pub fn record_message_sent(stream: &str, topic: &str, status: &str) { + counter!(names::MESSAGES_SENT_TOTAL, "stream" => stream.to_string(), "topic" => topic.to_string(), "status" => status.to_string()) + .increment(1); +} + +/// Record messages sent in batch. +pub fn record_messages_sent_batch(stream: &str, topic: &str, status: &str, count: u64) { + counter!(names::MESSAGES_SENT_TOTAL, "stream" => stream.to_string(), "topic" => topic.to_string(), "status" => status.to_string()) + .increment(count); +} + +/// Record messages polled. +pub fn record_messages_polled(stream: &str, topic: &str, count: u64) { + counter!(names::MESSAGES_POLLED_TOTAL, "stream" => stream.to_string(), "topic" => topic.to_string()) + .increment(count); +} + +/// Record a reconnection attempt. +pub fn record_reconnect_attempt() { + counter!(names::CONNECTION_RECONNECTS_TOTAL).increment(1); +} + +/// Record circuit breaker opening. +pub fn record_circuit_breaker_open() { + counter!(names::CIRCUIT_BREAKER_OPENS_TOTAL).increment(1); +} + +/// Record circuit breaker rejection. +pub fn record_circuit_breaker_rejection() { + counter!(names::CIRCUIT_BREAKER_REJECTIONS_TOTAL).increment(1); +} + +// ============================================================================= +// Histogram Recording Functions +// ============================================================================= + +/// Record HTTP request duration. 
+pub fn record_request_duration(endpoint: &str, method: &str, status: &str, duration_secs: f64) { + histogram!(names::REQUEST_DURATION_SECONDS, "endpoint" => endpoint.to_string(), "method" => method.to_string(), "status" => status.to_string()) + .record(duration_secs); +} + +/// Record message send duration. +pub fn record_send_duration(stream: &str, topic: &str, duration_secs: f64) { + histogram!(names::SEND_DURATION_SECONDS, "stream" => stream.to_string(), "topic" => topic.to_string()) + .record(duration_secs); +} + +/// Record message poll duration. +pub fn record_poll_duration(stream: &str, topic: &str, duration_secs: f64) { + histogram!(names::POLL_DURATION_SECONDS, "stream" => stream.to_string(), "topic" => topic.to_string()) + .record(duration_secs); +} + +// ============================================================================= +// Gauge Recording Functions +// ============================================================================= + +/// Update connection status gauge. +pub fn set_connection_status(connected: bool) { + gauge!(names::CONNECTION_STATUS).set(if connected { 1.0 } else { 0.0 }); +} + +/// Update circuit breaker state gauge. +/// +/// States: 0 = closed, 1 = half-open, 2 = open +pub fn set_circuit_breaker_state(state: u8) { + gauge!(names::CIRCUIT_BREAKER_STATE).set(f64::from(state)); +} + +#[cfg(test)] +mod tests { + use super::*; + + // Note: These tests verify the functions don't panic. + // Full metrics testing requires integration tests with a Prometheus scraper. 
+ + #[test] + fn test_record_message_sent() { + // Should not panic even without metrics initialized + record_message_sent("test-stream", "test-topic", "success"); + } + + #[test] + fn test_record_messages_polled() { + record_messages_polled("test-stream", "test-topic", 10); + } + + #[test] + fn test_record_request_duration() { + record_request_duration("/messages", "POST", "200", 0.1); + } + + #[test] + fn test_set_connection_status() { + set_connection_status(true); + set_connection_status(false); + } + + #[test] + fn test_set_circuit_breaker_state() { + set_circuit_breaker_state(0); // closed + set_circuit_breaker_state(1); // half-open + set_circuit_breaker_state(2); // open + } +} diff --git a/src/middleware/ip.rs b/src/middleware/ip.rs index 6ed2e63..3020178 100644 --- a/src/middleware/ip.rs +++ b/src/middleware/ip.rs @@ -132,6 +132,11 @@ enum ExtractedIp<'a> { /// - Returns borrowed `&str` slices pointing into the request headers /// - No allocations in this function /// - Caller is responsible for `.to_string()` if ownership is needed +/// +/// # Empty Header Handling +/// +/// Empty or whitespace-only headers are treated as `NotFound` to prevent +/// creating a separate rate-limit bucket for each empty-header request. 
#[inline] fn extract_ip_from_headers(req: &Request) -> ExtractedIp<'_> { // Check X-Forwarded-For first (maybe set by reverse proxy) @@ -140,14 +145,21 @@ fn extract_ip_from_headers(req: &Request) -> ExtractedIp<'_> { && let Ok(value) = forwarded.to_str() && let Some(first_ip) = value.split(',').next() { - return ExtractedIp::FromXff(first_ip.trim()); + let trimmed = first_ip.trim(); + // Treat empty strings as not found to avoid creating separate rate-limit buckets + if !trimmed.is_empty() { + return ExtractedIp::FromXff(trimmed); + } } // Check X-Real-IP (alternative header used by some proxies) if let Some(real_ip) = req.headers().get("x-real-ip") && let Ok(value) = real_ip.to_str() { - return ExtractedIp::FromRealIp(value.trim()); + let trimmed = value.trim(); + if !trimmed.is_empty() { + return ExtractedIp::FromRealIp(trimmed); + } } ExtractedIp::NotFound @@ -368,26 +380,48 @@ mod tests { #[test] fn test_extract_ip_empty_xff_header() { - // Empty header value should fall back to unknown + // Empty header value should fall back to unknown to prevent + // creating separate rate-limit buckets for empty-header requests let req = Request::builder() .header("x-forwarded-for", "") .body(Body::empty()) .unwrap(); - // Empty string after trim is still returned (split returns one empty element) - // This matches real-world behavior where empty headers are sometimes sent - assert_eq!(extract_client_ip(&req), ""); + assert_eq!(extract_client_ip(&req), "unknown"); } #[test] fn test_extract_ip_whitespace_only_xff() { - // Whitespace-only header should return empty string after trim + // Whitespace-only header should fall back to unknown let req = Request::builder() .header("x-forwarded-for", " ") .body(Body::empty()) .unwrap(); - assert_eq!(extract_client_ip(&req), ""); + assert_eq!(extract_client_ip(&req), "unknown"); + } + + #[test] + fn test_extract_ip_empty_xff_falls_back_to_real_ip() { + // Empty XFF should fall through to X-Real-IP + let req = Request::builder() + 
.header("x-forwarded-for", "") + .header("x-real-ip", "192.168.1.1") + .body(Body::empty()) + .unwrap(); + + assert_eq!(extract_client_ip(&req), "192.168.1.1"); + } + + #[test] + fn test_extract_ip_empty_real_ip_header() { + // Empty X-Real-IP should fall back to unknown + let req = Request::builder() + .header("x-real-ip", "") + .body(Body::empty()) + .unwrap(); + + assert_eq!(extract_client_ip(&req), "unknown"); } #[test] diff --git a/src/middleware/mod.rs b/src/middleware/mod.rs index 36aaf57..8961bfd 100644 --- a/src/middleware/mod.rs +++ b/src/middleware/mod.rs @@ -5,14 +5,15 @@ //! - **Rate Limiting**: Token bucket algorithm with configurable RPS and burst //! - **API Key Authentication**: Constant-time comparison for security //! - **Request ID**: Automatic generation and propagation for distributed tracing +//! - **Request Timeout**: Client-specified timeout propagation //! - **Trusted Proxy Validation**: CIDR-based proxy source validation //! //! # Architecture //! //! ```text -//! Request → Rate Limiter → Auth → Request ID → Handler → Response -//! ↓ ↓ ↓ -//! 429 Too Many 401 Unauth X-Request-Id header +//! Request → Rate Limiter → Auth → Timeout → Request ID → Handler → Response +//! ↓ ↓ ↓ ↓ +//! 429 Too Many 401 Unauth ext X-Request-Id header //! ``` //! //! # Security Considerations @@ -21,13 +22,19 @@ //! - Rate limiting prevents abuse and DoS attacks //! - Trusted proxy configuration mitigates IP spoofing attacks //! - Request IDs enable audit trails and debugging +//! 
- Request timeout bounds prevent abuse via extreme values pub mod auth; pub mod ip; pub mod rate_limit; pub mod request_id; +pub mod timeout; pub use auth::ApiKeyAuth; pub use ip::{UNKNOWN_IP, extract_client_ip, extract_client_ip_with_validation}; pub use rate_limit::{RateLimitError, RateLimitLayer, TrustedProxyConfig}; pub use request_id::RequestIdLayer; +pub use timeout::{ + MAX_REQUEST_TIMEOUT_MS, MIN_REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_HEADER, RequestTimeout, + RequestTimeoutExt, extract_request_timeout, +}; diff --git a/src/middleware/timeout.rs b/src/middleware/timeout.rs new file mode 100644 index 0000000..0e1b02d --- /dev/null +++ b/src/middleware/timeout.rs @@ -0,0 +1,183 @@ +//! Request timeout propagation middleware. +//! +//! This module provides middleware for propagating client-specified request timeouts, +//! allowing clients to specify how long they're willing to wait for a response. +//! +//! # Usage +//! +//! Clients can specify a timeout via the `X-Request-Timeout` header: +//! ```text +//! X-Request-Timeout: 5000 # 5 seconds in milliseconds +//! ``` +//! +//! Handlers can then extract this timeout: +//! ```rust,ignore +//! async fn handler( +//! timeout: Option<Extension<RequestTimeout>>, +//! // ... +//! ) -> impl IntoResponse { +//! let effective_timeout = timeout +//! .map(|t| t.0.duration) +//! .unwrap_or(default_timeout); +//! // Use effective_timeout for operations +//! } +//! ``` +//! +//! # Benefits +//! +//! - Prevents wasted work when clients have already given up +//! - Allows clients to specify shorter timeouts for time-sensitive operations +//! - Enables deadline propagation in distributed systems +//! +//! # Security Considerations +//! +//! - Minimum and maximum timeout bounds are enforced to prevent abuse +//! - Invalid values are ignored (fall back to server default) +//! 
- Zero or negative values are rejected + +use std::time::Duration; + +use axum::extract::Request; +use axum::middleware::Next; +use axum::response::Response; +use tracing::debug; + +/// Minimum allowed request timeout (100ms). +/// +/// Prevents clients from requesting unreasonably short timeouts that +/// would cause operations to always fail. +pub const MIN_REQUEST_TIMEOUT_MS: u64 = 100; + +/// Maximum allowed request timeout (5 minutes). +/// +/// Prevents clients from requesting excessively long timeouts that +/// could tie up resources. Adjust based on your longest expected operation. +pub const MAX_REQUEST_TIMEOUT_MS: u64 = 300_000; + +/// Header name for client-specified request timeout. +pub const REQUEST_TIMEOUT_HEADER: &str = "x-request-timeout"; + +/// Extracted request timeout from client header. +/// +/// This is stored in request extensions and can be extracted by handlers. +#[derive(Debug, Clone, Copy)] +pub struct RequestTimeout { + /// The timeout duration specified by the client. + pub duration: Duration, + /// The original value from the header (for logging). + pub original_ms: u64, +} + +impl RequestTimeout { + /// Create a new RequestTimeout from milliseconds. + /// + /// Returns `None` if the value is outside the allowed range. + pub fn from_millis(ms: u64) -> Option<Self> { + if !(MIN_REQUEST_TIMEOUT_MS..=MAX_REQUEST_TIMEOUT_MS).contains(&ms) { + return None; + } + Some(Self { + duration: Duration::from_millis(ms), + original_ms: ms, + }) + } +} + +/// Middleware that extracts and validates the `X-Request-Timeout` header. +/// +/// If a valid timeout is present, it's stored in request extensions +/// for handlers to use. 
+pub async fn extract_request_timeout(mut request: Request, next: Next) -> Response { + // Try to extract and parse the timeout header + if let Some(timeout_value) = request.headers().get(REQUEST_TIMEOUT_HEADER) + && let Ok(value_str) = timeout_value.to_str() + { + if let Ok(ms) = value_str.trim().parse::<u64>() { + if let Some(timeout) = RequestTimeout::from_millis(ms) { + debug!( + timeout_ms = ms, + "Client specified request timeout via header" + ); + request.extensions_mut().insert(timeout); + } else { + debug!( + timeout_ms = ms, + min = MIN_REQUEST_TIMEOUT_MS, + max = MAX_REQUEST_TIMEOUT_MS, + "Client timeout outside allowed range, ignoring" + ); + } + } else { + debug!( + value = value_str, + "Invalid X-Request-Timeout header value, ignoring" + ); + } + } + + next.run(request).await +} + +/// Extension trait for extracting request timeout from request extensions. +pub trait RequestTimeoutExt { + /// Get the client-specified timeout, or fall back to the provided default. + fn effective_timeout(&self, default: Duration) -> Duration; +} + +impl<B> RequestTimeoutExt for axum::http::Request<B> { + fn effective_timeout(&self, default: Duration) -> Duration { + self.extensions() + .get::<RequestTimeout>() + .map(|t| t.duration) + .unwrap_or(default) + } +} + +#[cfg(test)] +#[allow(clippy::unwrap_used, clippy::expect_used)] +mod tests { + use super::*; + + #[test] + fn test_request_timeout_from_millis_valid() { + let timeout = RequestTimeout::from_millis(5000).unwrap(); + assert_eq!(timeout.duration, Duration::from_millis(5000)); + assert_eq!(timeout.original_ms, 5000); + } + + #[test] + fn test_request_timeout_from_millis_minimum() { + let timeout = RequestTimeout::from_millis(MIN_REQUEST_TIMEOUT_MS).unwrap(); + assert_eq!( + timeout.duration, + Duration::from_millis(MIN_REQUEST_TIMEOUT_MS) + ); + } + + #[test] + fn test_request_timeout_from_millis_maximum() { + let timeout = RequestTimeout::from_millis(MAX_REQUEST_TIMEOUT_MS).unwrap(); + assert_eq!( + timeout.duration, + 
Duration::from_millis(MAX_REQUEST_TIMEOUT_MS) + ); + } + + #[test] + fn test_request_timeout_from_millis_too_low() { + let timeout = RequestTimeout::from_millis(MIN_REQUEST_TIMEOUT_MS - 1); + assert!(timeout.is_none()); + } + + #[test] + fn test_request_timeout_from_millis_too_high() { + let timeout = RequestTimeout::from_millis(MAX_REQUEST_TIMEOUT_MS + 1); + assert!(timeout.is_none()); + } + + #[test] + fn test_request_timeout_from_millis_zero() { + let timeout = RequestTimeout::from_millis(0); + assert!(timeout.is_none()); + } +} diff --git a/src/routes.rs b/src/routes.rs index 2a60a4d..74c5df6 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -43,13 +43,16 @@ use axum::Router; use axum::extract::DefaultBodyLimit; +use axum::middleware; use axum::routing::{delete, get, post}; use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; use tracing::info; use crate::handlers; -use crate::middleware::{ApiKeyAuth, RateLimitError, RateLimitLayer, RequestIdLayer}; +use crate::middleware::{ + ApiKeyAuth, RateLimitError, RateLimitLayer, RequestIdLayer, extract_request_timeout, +}; use crate::state::AppState; /// Build the application router with all routes and middleware configured. @@ -133,10 +136,14 @@ pub fn build_router(state: AppState) -> Result { // 3. Tracing router = router.layer(TraceLayer::new_for_http()); - // 4. Request ID + // 4. Request Timeout propagation + // Extracts X-Request-Timeout header and stores in request extensions + router = router.layer(middleware::from_fn(extract_request_timeout)); + + // 5. Request ID router = router.layer(RequestIdLayer::new()); - // 5. Authentication (if enabled) + // 6. Authentication (if enabled) let auth_layer = ApiKeyAuth::new(config.api_key.clone(), config.auth_bypass_paths.clone()); if auth_layer.is_enabled() { info!("API key authentication enabled"); @@ -145,7 +152,7 @@ pub fn build_router(state: AppState) -> Result { info!("API key authentication disabled (no API_KEY set)"); } - // 6. 
Rate Limiting (if enabled) - applied first, runs last in request pipeline + // 7. Rate Limiting (if enabled) - applied first, runs last in request pipeline if config.rate_limiting_enabled() { info!( rps = config.rate_limit_rps, diff --git a/src/validation.rs b/src/validation.rs index 36313cb..795b5ca 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -27,6 +27,13 @@ pub const MAX_PARTITIONS: u32 = 1000; /// Minimum number of partitions per topic. pub const MIN_PARTITIONS: u32 = 1; +/// Maximum consumer ID value. +/// +/// While Iggy uses u32 for consumer IDs, we set a reasonable upper bound +/// to detect likely misconfigurations (e.g., passing garbage data). +/// This value (1 billion) is high enough for any realistic use case. +pub const MAX_CONSUMER_ID: u32 = 1_000_000_000; + /// Validate a resource name (stream or topic). /// /// Rules: @@ -155,14 +162,21 @@ pub fn validate_partition_id(_partition_id: u32) -> AppResult<()> { /// Validate a consumer ID for polling. /// -/// Consumer IDs must be positive (at least 1). -/// A consumer_id of 0 is invalid. +/// Consumer IDs must be between 1 and [`MAX_CONSUMER_ID`] (inclusive). +/// A consumer_id of 0 is invalid, and values above MAX_CONSUMER_ID +/// likely indicate a misconfiguration. 
pub fn validate_consumer_id(consumer_id: u32) -> AppResult<()> { if consumer_id == 0 { return Err(AppError::BadRequest( "Consumer ID must be at least 1".to_string(), )); } + if consumer_id > MAX_CONSUMER_ID { + return Err(AppError::BadRequest(format!( + "Consumer ID {} exceeds maximum of {}", + consumer_id, MAX_CONSUMER_ID + ))); + } Ok(()) } @@ -328,7 +342,8 @@ mod tests { fn test_valid_consumer_ids() { assert!(validate_consumer_id(1).is_ok()); assert!(validate_consumer_id(100).is_ok()); - assert!(validate_consumer_id(u32::MAX).is_ok()); + assert!(validate_consumer_id(1_000_000).is_ok()); + assert!(validate_consumer_id(MAX_CONSUMER_ID).is_ok()); } #[test] @@ -337,4 +352,18 @@ mod tests { assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("at least 1")); } + + #[test] + fn test_invalid_consumer_id_too_high() { + let result = validate_consumer_id(MAX_CONSUMER_ID + 1); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("exceeds maximum")); + } + + #[test] + fn test_invalid_consumer_id_max_u32() { + let result = validate_consumer_id(u32::MAX); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("exceeds maximum")); + } } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index ae2f117..6852ecd 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -159,6 +159,10 @@ impl TestFixture { reconnect_max_delay: Duration::from_secs(1), health_check_interval: Duration::from_secs(30), operation_timeout: Duration::from_secs(30), + // Circuit breaker (default settings for tests) + circuit_breaker_failure_threshold: 5, + circuit_breaker_success_threshold: 2, + circuit_breaker_open_duration: Duration::from_secs(30), // Rate limiting (disabled for tests) rate_limit_rps: 0, rate_limit_burst: 50, @@ -174,6 +178,7 @@ impl TestFixture { // Observability log_level: "warn".to_string(), stats_cache_ttl: Duration::from_secs(5), + metrics_port: 0, // Disabled for tests }; let 
iggy_client = IggyClientWrapper::new(config.clone()) @@ -1060,6 +1065,10 @@ impl SecureTestFixture { reconnect_max_delay: Duration::from_secs(1), health_check_interval: Duration::from_secs(30), operation_timeout: Duration::from_secs(30), + // Circuit breaker (default settings for tests) + circuit_breaker_failure_threshold: 5, + circuit_breaker_success_threshold: 2, + circuit_breaker_open_duration: Duration::from_secs(30), // Rate limiting enabled - 5 RPS with burst of 2 for testing rate_limit_rps: 5, rate_limit_burst: 2, @@ -1073,6 +1082,7 @@ impl SecureTestFixture { trusted_proxies: vec![], // Empty = trust all (test mode) log_level: "warn".to_string(), stats_cache_ttl: Duration::from_secs(5), + metrics_port: 0, // Disabled for tests }; let iggy_client = IggyClientWrapper::new(config.clone())