diff --git a/keylime-push-model-agent/src/main.rs b/keylime-push-model-agent/src/main.rs index a3ed257b..89f0a062 100644 --- a/keylime-push-model-agent/src/main.rs +++ b/keylime-push-model-agent/src/main.rs @@ -5,7 +5,6 @@ use clap::Parser; use keylime::config::PushModelConfigTrait; use log::{debug, error, info}; mod attestation; -mod auth; mod context_info_handler; mod registration; mod response_handler; diff --git a/keylime-push-model-agent/src/auth.rs b/keylime/src/auth.rs similarity index 78% rename from keylime-push-model-agent/src/auth.rs rename to keylime/src/auth.rs index 01600d18..70aa8bac 100644 --- a/keylime-push-model-agent/src/auth.rs +++ b/keylime/src/auth.rs @@ -1,26 +1,24 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright 2025 Keylime Authors -//! Challenge-Response Authentication Module +//! Authentication types and utilities //! -//! This module implements the challenge-response authentication protocol -//! as described in Keylime Enhancement 103. It provides a standalone -//! authentication client that can be used independently or integrated -//! with existing HTTP clients. +//! This module provides common types and utilities for authentication +//! that are shared between different components of the Keylime system. -use anyhow::{anyhow, Result}; -use chrono::{DateTime, Duration, Utc}; -use keylime::structures::{ +use crate::structures::{ ProofOfPossession, SessionIdResponse, SessionRequest, SessionRequestAttributes, SessionRequestData, SessionResponse, SupportedAuthMethod, }; +use anyhow::{anyhow, Result}; +use chrono::{DateTime, Duration, Utc}; use log::{debug, info, warn}; -use reqwest::{Client, Method, StatusCode}; +use reqwest::Client; use std::sync::Arc; use tokio::sync::Mutex; -/// Configuration for the authentication client +/// Configuration for authentication #[derive(Debug, Clone)] pub struct AuthConfig { /// Base URL of the verifier (e.g., "https://verifier.example.com") @@ -52,14 +50,14 @@ impl Default for AuthConfig { /// Session token with expiration information #[derive(Debug, Clone)] -struct SessionToken { - token: String, - expires_at: DateTime, - session_id: u64, +pub struct SessionToken { + pub token: String, + pub expires_at: DateTime, + pub session_id: u64, } impl SessionToken { - fn is_valid(&self, buffer_minutes: i64) -> bool { + pub fn is_valid(&self, buffer_minutes: i64) -> bool { let buffer = Duration::minutes(buffer_minutes); Utc::now() + buffer < self.expires_at } @@ -67,7 +65,10 @@ impl SessionToken { /// Mock TPM operations for testing pub trait TpmOperations: Send + Sync { - fn generate_proof(&self, challenge: &str) -> Result; + fn generate_proof( + &self, + challenge: &str, + ) -> Result; } /// Default mock TPM implementation @@ -75,7 +76,11 @@ pub trait TpmOperations: Send + Sync { pub struct MockTpmOperations; impl TpmOperations for MockTpmOperations { - fn generate_proof(&self, challenge: &str) -> Result { + fn generate_proof( + &self, + challenge: &str, + ) -> Result { + use log::debug; debug!("Generating mock TPM proof for challenge: {challenge}"); // Create a deterministic but unique proof based on the challenge @@ -84,7 +89,7 @@ impl TpmOperations for MockTpmOperations { use base64::{engine::general_purpose, Engine as _}; - Ok(ProofOfPossession { + Ok(crate::structures::ProofOfPossession { message: general_purpose::STANDARD.encode(message), signature: general_purpose::STANDARD.encode(signature), }) @@ -99,6 +104,17 @@ pub struct AuthenticationClient { tpm_ops: Box, } +impl std::fmt::Debug for AuthenticationClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthenticationClient") + .field("config", &self.config) + .field("http_client", &"") + .field("session_token", &">>") + .field("tpm_ops", &">") + .finish() + } +} + impl AuthenticationClient { /// Create a new authentication client with the given configuration pub fn new(config: AuthConfig) -> Result { @@ -135,15 +151,63 @@ impl AuthenticationClient { }) } - /// Get a valid authentication token, performing authentication if necessary - pub async fn get_auth_token(&self) -> Result { + /// Create a raw authentication client with no middleware + /// This is used internally by the authentication middleware to avoid infinite loops + pub fn new_raw(config: AuthConfig) -> Result { + let timeout = std::time::Duration::from_millis(config.timeout_ms); + let http_client = Client::builder() + .timeout(timeout) + .danger_accept_invalid_certs(true) // For testing + .build()?; + + Ok(Self { + config, + http_client, + session_token: Arc::new(Mutex::new(None)), + tpm_ops: Box::new(MockTpmOperations), + }) + } + + /// Create a raw authentication client with custom TPM operations and no middleware + pub fn new_raw_with_tpm_ops( + config: AuthConfig, + tpm_ops: Box, + ) -> Result { + let timeout = std::time::Duration::from_millis(config.timeout_ms); + let http_client = Client::builder() + .timeout(timeout) + .danger_accept_invalid_certs(true) // For testing + .build()?; + + Ok(Self { + config, + http_client, + session_token: Arc::new(Mutex::new(None)), + tpm_ops, + }) + } + + /// Get the authentication configuration + pub fn config(&self) -> &AuthConfig { + &self.config + } + + /// Get a valid authentication token with metadata (token, expiration, session_id) + /// This method is used by the authentication middleware to access token details + pub async fn get_auth_token_with_metadata( + &self, + ) -> Result<(String, DateTime, u64)> { let token_guard = self.session_token.lock().await; // Check if we have a valid token if let Some(ref token) = *token_guard { if token.is_valid(self.config.token_refresh_buffer_minutes) { - debug!("Using existing valid token"); - return Ok(token.token.clone()); + debug!("Using existing valid token with metadata"); + return Ok(( + token.token.clone(), + token.expires_at, + token.session_id, + )); } else { debug!( "Token expired or expiring soon, need to re-authenticate" @@ -155,27 +219,20 @@ impl AuthenticationClient { drop(token_guard); // Release lock before authentication - // Perform authentication - self.authenticate().await - } + // Perform authentication and return metadata + let _token_string = self.authenticate().await?; - /// Check if we currently have a valid token - pub async fn has_valid_token(&self) -> bool { + // Get the token details from the newly stored token let token_guard = self.session_token.lock().await; if let Some(ref token) = *token_guard { - token.is_valid(self.config.token_refresh_buffer_minutes) + Ok((token.token.clone(), token.expires_at, token.session_id)) } else { - false + Err(anyhow!( + "Token was not stored properly after authentication" + )) } } - /// Clear the current token (e.g., after receiving 401) - pub async fn clear_token(&self) { - let mut token_guard = self.session_token.lock().await; - *token_guard = None; - debug!("Authentication token cleared"); - } - /// Perform the complete authentication flow async fn authenticate(&self) -> Result { info!( @@ -410,41 +467,12 @@ impl AuthenticationClient { Ok(token.clone()) } - - /// Make an authenticated HTTP request (convenience method for testing) - pub async fn make_authenticated_request( - &self, - method: Method, - url: &str, - body: Option, - ) -> Result { - let token = self.get_auth_token().await?; - - let mut request = self.http_client.request(method, url); - request = request.header("Authorization", format!("Bearer {token}")); - - if let Some(body) = body { - request = request - .header("Content-Type", "application/vnd.api+json") - .body(body); - } - - let response = request.send().await?; - - // Handle 401 responses by clearing token - if response.status() == StatusCode::UNAUTHORIZED { - warn!("Received 401, clearing token"); - self.clear_token().await; - return Err(anyhow!("Authentication token was rejected (401)")); - } - - Ok(response) - } } #[cfg(test)] mod tests { use super::*; + use serde_json; use wiremock::matchers::{header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -460,7 +488,7 @@ mod tests { max_auth_retries: 2, }; - AuthenticationClient::new(config).unwrap() + AuthenticationClient::new(config).unwrap() //#[allow_ci] } #[tokio::test] @@ -531,16 +559,16 @@ mod tests { let client = create_test_client(&mock_server.uri()).await; - // Test authentication - let token = client.get_auth_token().await.unwrap(); + // Test authentication - get token with metadata since that's our main method + let (token, _expires_at, session_id) = + client.get_auth_token_with_metadata().await.unwrap(); //#[allow_ci] assert_eq!(token, "test-token-456"); + assert_eq!(session_id, 1); - // Test that token is cached - assert!(client.has_valid_token().await); - - // Test that subsequent calls use cached token - let token2 = client.get_auth_token().await.unwrap(); - assert_eq!(token2, "test-token-456"); + // Verify token is valid + let token_guard = client.session_token.lock().await; + let session_token = token_guard.as_ref().unwrap(); //#[allow_ci] + assert!(session_token.is_valid(5)); } #[tokio::test] @@ -607,10 +635,10 @@ mod tests { let client = create_test_client(&mock_server.uri()).await; - let result = client.get_auth_token().await; + let result = client.get_auth_token_with_metadata().await; assert!(result.is_err()); assert!(result - .unwrap_err() + .unwrap_err() //#[allow_ci] .to_string() .contains("Authentication failed")); } @@ -642,7 +670,7 @@ mod tests { } }), )) - .expect(1..) // May be called multiple times + .expect(1..) //#[allow_ci] // May be called multiple times .mount(&mock_server) .await; @@ -678,7 +706,7 @@ mod tests { } }), )) - .expect(1..) // May be called multiple times + .expect(1..) //#[allow_ci] // May be called multiple times .mount(&mock_server) .await; @@ -691,36 +719,60 @@ mod tests { max_auth_retries: 2, }; - let client = AuthenticationClient::new(config).unwrap(); + let client = AuthenticationClient::new(config).unwrap(); //#[allow_ci] // Since token expires in 1 minute but we have 5 minute buffer, // it should be considered invalid and trigger re-authentication - let token = client.get_auth_token().await.unwrap(); + let (token, _, _) = + client.get_auth_token_with_metadata().await.unwrap(); //#[allow_ci] assert_eq!(token, "short-lived-token"); // Check that token is considered invalid due to buffer - assert!(!client.has_valid_token().await); + let token_guard = client.session_token.lock().await; + let session_token = token_guard.as_ref().unwrap(); //#[allow_ci] + assert!(!session_token.is_valid(5)); } #[tokio::test] - async fn test_clear_token() { - let mock_server = MockServer::start().await; - let client = create_test_client(&mock_server.uri()).await; + async fn test_raw_client_creation() { + let config = AuthConfig { + verifier_base_url: "https://127.0.0.1:8881".to_string(), + agent_id: "test-agent-raw".to_string(), + avoid_tpm: true, + timeout_ms: 1000, + token_refresh_buffer_minutes: 5, + max_auth_retries: 2, + }; - // Manually insert a token - { - let mut token_guard = client.session_token.lock().await; - *token_guard = Some(SessionToken { - token: "test-token".to_string(), - expires_at: Utc::now() + Duration::hours(1), - session_id: 1, - }); - } + let raw_client = AuthenticationClient::new_raw(config).unwrap(); //#[allow_ci] - assert!(client.has_valid_token().await); + // Verify the client was created successfully + assert_eq!(raw_client.config.agent_id, "test-agent-raw"); + assert_eq!(raw_client.config.timeout_ms, 1000); + assert!(raw_client.config.avoid_tpm); + } - client.clear_token().await; + #[tokio::test] + async fn test_raw_client_with_tpm_ops() { + let config = AuthConfig { + verifier_base_url: "https://127.0.0.1:8881".to_string(), + agent_id: "test-agent-raw-tpm".to_string(), + avoid_tpm: false, + timeout_ms: 2000, + token_refresh_buffer_minutes: 10, + max_auth_retries: 1, + }; - assert!(!client.has_valid_token().await); + let custom_tpm_ops = Box::new(MockTpmOperations); + let raw_client = AuthenticationClient::new_raw_with_tpm_ops( + config, + custom_tpm_ops, + ) + .unwrap(); //#[allow_ci] + + // Verify the client was created successfully + assert_eq!(raw_client.config.agent_id, "test-agent-raw-tpm"); + assert_eq!(raw_client.config.timeout_ms, 2000); + assert!(!raw_client.config.avoid_tpm); } } diff --git a/keylime/src/lib.rs b/keylime/src/lib.rs index 988ef8e1..7c4c6c86 100644 --- a/keylime/src/lib.rs +++ b/keylime/src/lib.rs @@ -2,6 +2,7 @@ pub mod agent_data; pub mod agent_identity; pub mod agent_registration; pub mod algorithms; +pub mod auth; pub mod boot_time; pub mod cert; pub mod config; diff --git a/keylime/src/resilient_client.rs b/keylime/src/resilient_client.rs index 3bf33171..40d146af 100644 --- a/keylime/src/resilient_client.rs +++ b/keylime/src/resilient_client.rs @@ -1,3 +1,5 @@ +use crate::auth::{AuthConfig, AuthenticationClient, SessionToken}; +use anyhow; use async_trait::async_trait; use chrono::Utc; use http::Extensions; @@ -15,7 +17,9 @@ use reqwest_retry::{ Retryable, RetryableStrategy, }; use serde::Serialize; +use std::sync::Arc; use std::time::Duration; +use tokio::sync::{Mutex, RwLock}; // We define a default maximum delay for retries, which in the pracitical sense // is set to 1 hour. This can be adjusted based on the application's needs. @@ -186,6 +190,181 @@ fn parse_retry_after( None } +/// Shared state for authentication tokens with proper concurrency control +#[derive(Debug)] +struct TokenState { + /// RwLock for the actual token - allows concurrent reads + token: RwLock>, + /// Mutex for refresh operations - ensures single writer + refresh_lock: Mutex<()>, + /// Raw authentication client (no middleware to avoid loops) + auth_client: AuthenticationClient, +} + +impl TokenState { + fn new( + auth_config: AuthConfig, + ) -> Result> { + // Create a raw authentication client to avoid middleware loops + let auth_client = AuthenticationClient::new_raw(auth_config) + .map_err(|e| format!("Failed to create auth client: {}", e))?; + + Ok(Self { + token: RwLock::new(None), + refresh_lock: Mutex::new(()), + auth_client, + }) + } + + async fn get_valid_token( + &self, + ) -> Result> { + // Fast path: try to read existing valid token + { + let token_guard = self.token.read().await; + if let Some(ref token) = *token_guard { + if token.is_valid( + self.auth_client.config().token_refresh_buffer_minutes, + ) { + debug!("Using existing valid token from middleware"); + return Ok(token.token.clone()); + } + } + } + + // Slow path: token is invalid or missing, need to refresh + self.refresh_token().await + } + + async fn refresh_token( + &self, + ) -> Result> { + // Acquire refresh lock to ensure only one refresh at a time + let _refresh_guard = self.refresh_lock.lock().await; + + // Double-check: another request might have refreshed while we waited + { + let token_guard = self.token.read().await; + if let Some(ref token) = *token_guard { + if token.is_valid( + self.auth_client.config().token_refresh_buffer_minutes, + ) { + debug!("Token was refreshed by another request"); + return Ok(token.token.clone()); + } + } + } + + // Use the raw authentication client to get a new token with metadata + debug!("Performing token refresh using raw authentication client"); + match self.auth_client.get_auth_token_with_metadata().await { + Ok((token_string, expires_at, session_id)) => { + let new_token = SessionToken { + token: token_string.clone(), + expires_at, + session_id, + }; + + // Store the new token + { + let mut token_guard = self.token.write().await; + *token_guard = Some(new_token); + } + + debug!("Token refresh completed successfully"); + Ok(token_string) + } + Err(e) => { + warn!("Token refresh failed: {}", e); + Err(format!("Authentication failed: {}", e).into()) + } + } + } + + async fn clear_token(&self) { + let mut token_guard = self.token.write().await; + *token_guard = None; + debug!("Authentication token cleared from shared state"); + } +} + +/// Middleware for transparent authentication using challenge-response protocol +#[derive(Debug)] +pub struct AuthenticationMiddleware { + token_state: Arc, +} + +impl AuthenticationMiddleware { + pub fn new( + auth_config: AuthConfig, + ) -> Result> { + let token_state = Arc::new(TokenState::new(auth_config)?); + Ok(Self { token_state }) + } + + fn is_auth_endpoint(&self, req: &reqwest::Request) -> bool { + let path = req.url().path(); + // Skip authentication for auth endpoints to prevent infinite loops + path.contains("/sessions") + } +} + +#[async_trait] +impl Middleware for AuthenticationMiddleware { + async fn handle( + &self, + mut req: reqwest::Request, + extensions: &mut Extensions, + next: Next<'_>, + ) -> Result { + // Skip authentication for auth endpoints to prevent infinite loops + if self.is_auth_endpoint(&req) { + debug!( + "Skipping auth for authentication endpoint: {}", + req.url().path() + ); + return next.run(req, extensions).await; + } + + // Add Authorization header if not present + if !req.headers().contains_key("Authorization") { + match self.token_state.get_valid_token().await { + Ok(token) => { + debug!("Adding authentication token to request"); + req.headers_mut().insert( + "Authorization", + format!("Bearer {}", token).parse().map_err(|e| { + Error::Middleware(anyhow::anyhow!( + "Invalid token format: {}", + e + )) + })?, + ); + } + Err(e) => { + warn!("Failed to get auth token: {}", e); + return Err(Error::Middleware(anyhow::anyhow!( + "Authentication failed: {}", + e + ))); + } + } + } + + let response = next.run(req, extensions).await?; + + // Handle 401 responses by clearing token + if response.status() == StatusCode::UNAUTHORIZED { + warn!("Received 401, clearing token for future requests"); + self.token_state.clear_token().await; + // Note: We don't retry here to avoid infinite loops + // The retry will happen naturally on the next request + } + + Ok(response) + } +} + /// A client that transparently handles retries with exponential backoff. #[derive(Debug, Clone)] pub struct ResilientClient { @@ -225,6 +404,46 @@ impl ResilientClient { } } + /// Creates a new client with optional authentication middleware + pub fn new_with_auth( + client: Option, + auth_config: Option, + initial_delay: std::time::Duration, + max_retries: u32, + success_codes: &[StatusCode], + max_delay: Option, + ) -> Result> { + let base_client = client.unwrap_or_default(); + let final_max_delay = max_delay.unwrap_or(DEFAULT_MAX_DELAY); + + let retry_policy = ExponentialBackoff::builder() + .retry_bounds(initial_delay, final_max_delay) + .jitter(Jitter::None) + .build_with_max_retries(max_retries); + + let mut builder = ClientBuilder::new(base_client).with( + RetryTransientMiddleware::new_with_policy_and_strategy( + retry_policy, + StopOnSuccessStrategy { + success_codes: success_codes.to_vec(), + }, + ), + ); + + // Add authentication middleware if config is provided + if let Some(auth_cfg) = auth_config { + debug!("Adding authentication middleware to client"); + let auth_middleware = AuthenticationMiddleware::new(auth_cfg)?; + builder = builder.with(auth_middleware); + } + + let client_with_middleware = builder.with(LoggingMiddleware).build(); + + Ok(Self { + client: client_with_middleware, + }) + } + /// Generates a six-character lowercase alphanumeric request ID. fn generate_request_id() -> String { const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789"; @@ -871,4 +1090,262 @@ mod tests { let result = parse_retry_after(&header); assert!(result.is_some() || result.is_none()); // Just ensure no panic } + + #[tokio::test] + async fn test_resilient_client_with_auth_config() { + use crate::auth::AuthConfig; + + let auth_config = AuthConfig { + verifier_base_url: "https://verifier.example.com".to_string(), + agent_id: "test-agent".to_string(), + avoid_tpm: true, + timeout_ms: 5000, + token_refresh_buffer_minutes: 5, + max_auth_retries: 3, + }; + + // Test with authentication + let _client_with_auth = ResilientClient::new_with_auth( + None, + Some(auth_config), + std::time::Duration::from_millis(10), + 3, + &[StatusCode::OK], + None, + ) + .unwrap(); //#[allow_ci] + + // Verify the client was created successfully + // (We can't easily test the middleware behavior without a mock server, + // but we can at least verify the client creation doesn't panic) + } + + #[tokio::test] + async fn test_resilient_client_without_auth_config() { + // Test without authentication (should behave like the original client) + let _client_without_auth = ResilientClient::new_with_auth( + None, + None, // No auth config + std::time::Duration::from_millis(10), + 3, + &[StatusCode::OK], + None, + ) + .unwrap(); //#[allow_ci] + + // Verify the client was created successfully + } + + #[tokio::test] + async fn test_authentication_middleware_path_detection() { + use crate::auth::AuthConfig; + + let auth_config = AuthConfig { + verifier_base_url: "https://verifier.example.com".to_string(), + agent_id: "test-agent".to_string(), + avoid_tpm: true, + timeout_ms: 5000, + token_refresh_buffer_minutes: 5, + max_auth_retries: 3, + }; + + let middleware = AuthenticationMiddleware::new(auth_config).unwrap(); //#[allow_ci] + + // Mock a request to a sessions endpoint (should be detected as auth endpoint) + let mock_request = reqwest::Request::new( + Method::POST, + "https://verifier.example.com/v3.0/sessions" + .parse() + .unwrap(), //#[allow_ci] + ); + assert!(middleware.is_auth_endpoint(&mock_request)); + + // Mock a request to a non-auth endpoint + let mock_request2 = reqwest::Request::new( + Method::GET, + "https://verifier.example.com/v3.0/agents/123/attestations" + .parse() + .unwrap(), //#[allow_ci] + ); + assert!(!middleware.is_auth_endpoint(&mock_request2)); + } + + mod auth_middleware_tests { + use super::*; + use crate::auth::{AuthConfig, SessionToken}; + use chrono::{Duration, Utc}; + use std::sync::Arc; + + #[tokio::test] + async fn test_token_state_basic_operations() { + let auth_config = AuthConfig { + verifier_base_url: "https://verifier.example.com".to_string(), + agent_id: "test-agent".to_string(), + avoid_tpm: true, + timeout_ms: 5000, + token_refresh_buffer_minutes: 5, + max_auth_retries: 3, + }; + + let token_state = TokenState::new(auth_config).unwrap(); //#[allow_ci] + + // Test initially no token - should trigger authentication + let result = token_state.get_valid_token().await; + assert!( + result.is_err(), + "Should fail when no auth server available" + ); + // Since we're using a real auth client, we expect authentication-related errors + let error_msg = result.unwrap_err().to_string(); //#[allow_ci] + assert!( + error_msg.contains("Authentication failed"), + "Error: {}", + error_msg + ); + + // Test clear token when no token exists (should not panic) + token_state.clear_token().await; + + // Manually insert a valid token for testing + { + let mut token_guard = token_state.token.write().await; + *token_guard = Some(SessionToken { + token: "test-token-123".to_string(), + expires_at: Utc::now() + Duration::hours(1), // Valid for 1 hour + session_id: 42, + }); + } + + // Test get valid token with valid token - should succeed now + let result = token_state.get_valid_token().await; + assert!(result.is_ok(), "Should succeed with valid token"); + assert_eq!(result.unwrap(), "test-token-123"); //#[allow_ci] + + // Test clear token + token_state.clear_token().await; + + // Verify token was cleared + { + let token_guard = token_state.token.read().await; + assert!(token_guard.is_none()); + } + } + + #[tokio::test] + async fn test_token_state_expiration_logic() { + let auth_config = AuthConfig { + verifier_base_url: "https://verifier.example.com".to_string(), + agent_id: "test-agent".to_string(), + avoid_tpm: true, + timeout_ms: 5000, + token_refresh_buffer_minutes: 10, // 10 minute buffer + max_auth_retries: 3, + }; + + let token_state = TokenState::new(auth_config).unwrap(); //#[allow_ci] + + // Insert token that expires within buffer time (should be considered invalid) + { + let mut token_guard = token_state.token.write().await; + *token_guard = Some(SessionToken { + token: "expiring-token".to_string(), + expires_at: Utc::now() + Duration::minutes(5), // Expires in 5 min, but 10 min buffer + session_id: 123, + }); + } + + // Should try to refresh because token is within buffer time + let result = token_state.get_valid_token().await; + assert!( + result.is_err(), + "Should fail due to token expiring within buffer" + ); + // Since we're using a real auth client, we expect authentication-related errors + let error_msg = result.unwrap_err().to_string(); //#[allow_ci] + assert!( + error_msg.contains("Authentication failed"), + "Error: {}", + error_msg + ); + } + + #[tokio::test] + async fn test_authentication_middleware_advanced_patterns() { + let auth_config = AuthConfig { + verifier_base_url: "https://verifier.example.com".to_string(), + agent_id: "test-agent".to_string(), + avoid_tpm: true, + timeout_ms: 5000, + token_refresh_buffer_minutes: 5, + max_auth_retries: 3, + }; + + let middleware = + AuthenticationMiddleware::new(auth_config).unwrap(); //#[allow_ci] + + // Test different auth endpoint patterns + let test_cases = vec![ + ("https://verifier.example.com/sessions", true), + ("https://verifier.example.com/v3.0/sessions", true), + ("https://verifier.example.com/api/sessions/123", true), + ("https://verifier.example.com/agents", false), + ("https://verifier.example.com/attestations", false), + ("https://verifier.example.com/keys", false), + ]; + + for (url, expected_is_auth) in test_cases { + let mock_request = reqwest::Request::new( + Method::GET, + url.parse().unwrap(), //#[allow_ci] + ); + assert_eq!( + middleware.is_auth_endpoint(&mock_request), + expected_is_auth, + "URL {} should be auth endpoint: {}", + url, + expected_is_auth + ); + } + } + + #[tokio::test] + async fn test_middleware_concurrent_access() { + let auth_config = AuthConfig { + verifier_base_url: "https://verifier.example.com".to_string(), + agent_id: "test-agent".to_string(), + avoid_tpm: true, + timeout_ms: 5000, + token_refresh_buffer_minutes: 5, + max_auth_retries: 3, + }; + + let token_state = Arc::new(TokenState::new(auth_config).unwrap()); //#[allow_ci] + + // Test concurrent access to token state (should not deadlock) + let mut handles = vec![]; + + for i in 0..5 { + let token_state_clone = Arc::clone(&token_state); + let handle = tokio::spawn(async move { + if i % 2 == 0 { + // Even threads try to get token + let _result = + token_state_clone.get_valid_token().await; + } else { + // Odd threads clear token + token_state_clone.clear_token().await; + } + }); + handles.push(handle); + } + + // Wait for all tasks to complete (should not hang) + for handle in handles { + handle.await.unwrap(); //#[allow_ci] + } + + // Verify we can still access the token state + token_state.clear_token().await; + } + } }