Skip to content

Commit 5a5b89c

Browse files
feat: implement authentication middleware with shared token state
Add transparent authentication middleware to ResilientClient using the challenge-response protocol with proper concurrency control. This implementation uses the shared auth types from the core library. Signed-off-by: Sergio Correia <[email protected]>
1 parent 85937c3 commit 5a5b89c

File tree

1 file changed

+266
-0
lines changed

1 file changed

+266
-0
lines changed

keylime/src/resilient_client.rs

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::auth::{AuthConfig, SessionToken};
2+
use anyhow;
13
use async_trait::async_trait;
24
use http::Extensions;
35
use log::{debug, warn};
@@ -13,7 +15,9 @@ use reqwest_retry::{
1315
Retryable, RetryableStrategy,
1416
};
1517
use serde::Serialize;
18+
use std::sync::Arc;
1619
use std::time::Duration;
20+
use tokio::sync::{Mutex, RwLock};
1721

1822
// We define a default maximum delay for retries, which in the pracitical sense
1923
// is set to 1 hour. This can be adjusted based on the application's needs.
@@ -82,6 +86,153 @@ impl RetryableStrategy for StopOnSuccessStrategy {
8286
}
8387
}
8488

89+
/// Shared state for authentication tokens with proper concurrency control
90+
#[derive(Debug)]
91+
struct TokenState {
92+
/// RwLock for the actual token - allows concurrent reads
93+
token: RwLock<Option<SessionToken>>,
94+
/// Mutex for refresh operations - ensures single writer
95+
refresh_lock: Mutex<()>,
96+
/// Authentication configuration
97+
auth_config: AuthConfig,
98+
}
99+
100+
impl TokenState {
101+
fn new(auth_config: AuthConfig) -> Self {
102+
Self {
103+
token: RwLock::new(None),
104+
refresh_lock: Mutex::new(()),
105+
auth_config,
106+
}
107+
}
108+
109+
async fn get_valid_token(
110+
&self,
111+
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
112+
// Fast path: try to read existing valid token
113+
{
114+
let token_guard = self.token.read().await;
115+
if let Some(ref token) = *token_guard {
116+
if token
117+
.is_valid(self.auth_config.token_refresh_buffer_minutes)
118+
{
119+
debug!("Using existing valid token from middleware");
120+
return Ok(token.token.clone());
121+
}
122+
}
123+
}
124+
125+
// Slow path: token is invalid or missing, need to refresh
126+
self.refresh_token().await
127+
}
128+
129+
async fn refresh_token(
130+
&self,
131+
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
132+
// Acquire refresh lock to ensure only one refresh at a time
133+
let _refresh_guard = self.refresh_lock.lock().await;
134+
135+
// Double-check: another request might have refreshed while we waited
136+
{
137+
let token_guard = self.token.read().await;
138+
if let Some(ref token) = *token_guard {
139+
if token
140+
.is_valid(self.auth_config.token_refresh_buffer_minutes)
141+
{
142+
debug!("Token was refreshed by another request");
143+
return Ok(token.token.clone());
144+
}
145+
}
146+
}
147+
148+
// TODO: Next, we'll integrate with the actual AuthenticationClient
149+
// For now, this is a placeholder that will be replaced
150+
warn!("Token refresh not yet implemented");
151+
Err("Authentication not yet integrated".into())
152+
}
153+
154+
async fn clear_token(&self) {
155+
let mut token_guard = self.token.write().await;
156+
*token_guard = None;
157+
debug!("Authentication token cleared from shared state");
158+
}
159+
}
160+
161+
/// Middleware for transparent authentication using challenge-response protocol
162+
#[derive(Debug)]
163+
pub struct AuthenticationMiddleware {
164+
token_state: Arc<TokenState>,
165+
}
166+
167+
impl AuthenticationMiddleware {
168+
pub fn new(auth_config: AuthConfig) -> Self {
169+
let token_state = Arc::new(TokenState::new(auth_config));
170+
Self { token_state }
171+
}
172+
173+
fn is_auth_endpoint(&self, req: &reqwest::Request) -> bool {
174+
let path = req.url().path();
175+
// Skip authentication for auth endpoints to prevent infinite loops
176+
path.contains("/sessions")
177+
}
178+
}
179+
180+
#[async_trait]
181+
impl Middleware for AuthenticationMiddleware {
182+
async fn handle(
183+
&self,
184+
mut req: reqwest::Request,
185+
extensions: &mut Extensions,
186+
next: Next<'_>,
187+
) -> Result<Response, Error> {
188+
// Skip authentication for auth endpoints to prevent infinite loops
189+
if self.is_auth_endpoint(&req) {
190+
debug!(
191+
"Skipping auth for authentication endpoint: {}",
192+
req.url().path()
193+
);
194+
return next.run(req, extensions).await;
195+
}
196+
197+
// Add Authorization header if not present
198+
if !req.headers().contains_key("Authorization") {
199+
match self.token_state.get_valid_token().await {
200+
Ok(token) => {
201+
debug!("Adding authentication token to request");
202+
req.headers_mut().insert(
203+
"Authorization",
204+
format!("Bearer {}", token).parse().map_err(|e| {
205+
Error::Middleware(anyhow::anyhow!(
206+
"Invalid token format: {}",
207+
e
208+
))
209+
})?,
210+
);
211+
}
212+
Err(e) => {
213+
warn!("Failed to get auth token: {}", e);
214+
return Err(Error::Middleware(anyhow::anyhow!(
215+
"Authentication failed: {}",
216+
e
217+
)));
218+
}
219+
}
220+
}
221+
222+
let response = next.run(req, extensions).await?;
223+
224+
// Handle 401 responses by clearing token
225+
if response.status() == StatusCode::UNAUTHORIZED {
226+
warn!("Received 401, clearing token for future requests");
227+
self.token_state.clear_token().await;
228+
// Note: We don't retry here to avoid infinite loops
229+
// The retry will happen naturally on the next request
230+
}
231+
232+
Ok(response)
233+
}
234+
}
235+
85236
/// A client that transparently handles retries with exponential backoff.
86237
#[derive(Debug, Clone)]
87238
pub struct ResilientClient {
@@ -120,6 +271,46 @@ impl ResilientClient {
120271
}
121272
}
122273

274+
/// Creates a new client with optional authentication middleware
275+
pub fn new_with_auth(
276+
client: Option<Client>,
277+
auth_config: Option<AuthConfig>,
278+
initial_delay: std::time::Duration,
279+
max_retries: u32,
280+
success_codes: &[StatusCode],
281+
max_delay: Option<std::time::Duration>,
282+
) -> Self {
283+
let base_client = client.unwrap_or_default();
284+
let final_max_delay = max_delay.unwrap_or(DEFAULT_MAX_DELAY);
285+
286+
let retry_policy = ExponentialBackoff::builder()
287+
.retry_bounds(initial_delay, final_max_delay)
288+
.jitter(Jitter::None)
289+
.build_with_max_retries(max_retries);
290+
291+
let mut builder = ClientBuilder::new(base_client).with(
292+
RetryTransientMiddleware::new_with_policy_and_strategy(
293+
retry_policy,
294+
StopOnSuccessStrategy {
295+
success_codes: success_codes.to_vec(),
296+
},
297+
),
298+
);
299+
300+
// Add authentication middleware if config is provided
301+
if let Some(auth_cfg) = auth_config {
302+
debug!("Adding authentication middleware to client");
303+
let auth_middleware = AuthenticationMiddleware::new(auth_cfg);
304+
builder = builder.with(auth_middleware);
305+
}
306+
307+
let client_with_middleware = builder.with(LoggingMiddleware).build();
308+
309+
Self {
310+
client: client_with_middleware,
311+
}
312+
}
313+
123314
/// Generates a six-character lowercase alphanumeric request ID.
124315
fn generate_request_id() -> String {
125316
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789";
@@ -502,4 +693,79 @@ mod tests {
502693
assert!(response.is_ok());
503694
assert_eq!(response.unwrap().status(), StatusCode::OK); //#[allow_ci]
504695
}
696+
697+
#[tokio::test]
698+
async fn test_resilient_client_with_auth_config() {
699+
use crate::auth::AuthConfig;
700+
701+
let auth_config = AuthConfig {
702+
verifier_base_url: "https://verifier.example.com".to_string(),
703+
agent_id: "test-agent".to_string(),
704+
avoid_tpm: true,
705+
timeout_ms: 5000,
706+
token_refresh_buffer_minutes: 5,
707+
max_auth_retries: 3,
708+
};
709+
710+
// Test with authentication
711+
let _client_with_auth = ResilientClient::new_with_auth(
712+
None,
713+
Some(auth_config),
714+
std::time::Duration::from_millis(10),
715+
3,
716+
&[StatusCode::OK],
717+
None,
718+
);
719+
720+
// Verify the client was created successfully
721+
}
722+
723+
#[tokio::test]
724+
async fn test_resilient_client_without_auth_config() {
725+
// Test without authentication (should behave like the original client)
726+
let _client_without_auth = ResilientClient::new_with_auth(
727+
None,
728+
None, // No auth config
729+
std::time::Duration::from_millis(10),
730+
3,
731+
&[StatusCode::OK],
732+
None,
733+
);
734+
735+
// Verify the client was created successfully
736+
}
737+
738+
#[tokio::test]
739+
async fn test_authentication_middleware_path_detection() {
740+
use crate::auth::AuthConfig;
741+
742+
let auth_config = AuthConfig {
743+
verifier_base_url: "https://verifier.example.com".to_string(),
744+
agent_id: "test-agent".to_string(),
745+
avoid_tpm: true,
746+
timeout_ms: 5000,
747+
token_refresh_buffer_minutes: 5,
748+
max_auth_retries: 3,
749+
};
750+
751+
let middleware = AuthenticationMiddleware::new(auth_config);
752+
753+
// Mock a request to a sessions endpoint (should be detected as auth endpoint)
754+
let mock_request = reqwest::Request::new(
755+
Method::POST,
756+
"https://verifier.example.com/v3.0/sessions"
757+
.parse()
758+
.unwrap(), //#[allow_ci]
759+
);
760+
assert!(middleware.is_auth_endpoint(&mock_request));
761+
762+
// Mock a request to a non-auth endpoint
763+
let mock_request2 = reqwest::Request::new(
764+
Method::GET,
765+
"https://verifier.example.com/v3.0/agents/123/attestations"
766+
.parse()
767+
.unwrap(), //#[allow_ci]
768+
);
769+
assert!(!middleware.is_auth_endpoint(&mock_request2));
770+
}
505771
}

0 commit comments

Comments
 (0)