1+ use crate :: auth:: { AuthConfig , SessionToken } ;
2+ use anyhow;
13use async_trait:: async_trait;
24use http:: Extensions ;
35use log:: { debug, warn} ;
@@ -13,7 +15,9 @@ use reqwest_retry::{
1315 Retryable , RetryableStrategy ,
1416} ;
1517use serde:: Serialize ;
18+ use std:: sync:: Arc ;
1619use std:: time:: Duration ;
20+ use tokio:: sync:: { Mutex , RwLock } ;
1721
1822// We define a default maximum delay for retries, which in the pracitical sense
1923// is set to 1 hour. This can be adjusted based on the application's needs.
@@ -82,6 +86,153 @@ impl RetryableStrategy for StopOnSuccessStrategy {
8286 }
8387}
8488
89+ /// Shared state for authentication tokens with proper concurrency control
90+ #[ derive( Debug ) ]
91+ struct TokenState {
92+ /// RwLock for the actual token - allows concurrent reads
93+ token : RwLock < Option < SessionToken > > ,
94+ /// Mutex for refresh operations - ensures single writer
95+ refresh_lock : Mutex < ( ) > ,
96+ /// Authentication configuration
97+ auth_config : AuthConfig ,
98+ }
99+
100+ impl TokenState {
101+ fn new ( auth_config : AuthConfig ) -> Self {
102+ Self {
103+ token : RwLock :: new ( None ) ,
104+ refresh_lock : Mutex :: new ( ( ) ) ,
105+ auth_config,
106+ }
107+ }
108+
109+ async fn get_valid_token (
110+ & self ,
111+ ) -> Result < String , Box < dyn std:: error:: Error + Send + Sync > > {
112+ // Fast path: try to read existing valid token
113+ {
114+ let token_guard = self . token . read ( ) . await ;
115+ if let Some ( ref token) = * token_guard {
116+ if token
117+ . is_valid ( self . auth_config . token_refresh_buffer_minutes )
118+ {
119+ debug ! ( "Using existing valid token from middleware" ) ;
120+ return Ok ( token. token . clone ( ) ) ;
121+ }
122+ }
123+ }
124+
125+ // Slow path: token is invalid or missing, need to refresh
126+ self . refresh_token ( ) . await
127+ }
128+
129+ async fn refresh_token (
130+ & self ,
131+ ) -> Result < String , Box < dyn std:: error:: Error + Send + Sync > > {
132+ // Acquire refresh lock to ensure only one refresh at a time
133+ let _refresh_guard = self . refresh_lock . lock ( ) . await ;
134+
135+ // Double-check: another request might have refreshed while we waited
136+ {
137+ let token_guard = self . token . read ( ) . await ;
138+ if let Some ( ref token) = * token_guard {
139+ if token
140+ . is_valid ( self . auth_config . token_refresh_buffer_minutes )
141+ {
142+ debug ! ( "Token was refreshed by another request" ) ;
143+ return Ok ( token. token . clone ( ) ) ;
144+ }
145+ }
146+ }
147+
148+ // TODO: Next, we'll integrate with the actual AuthenticationClient
149+ // For now, this is a placeholder that will be replaced
150+ warn ! ( "Token refresh not yet implemented" ) ;
151+ Err ( "Authentication not yet integrated" . into ( ) )
152+ }
153+
154+ async fn clear_token ( & self ) {
155+ let mut token_guard = self . token . write ( ) . await ;
156+ * token_guard = None ;
157+ debug ! ( "Authentication token cleared from shared state" ) ;
158+ }
159+ }
160+
161+ /// Middleware for transparent authentication using challenge-response protocol
162+ #[ derive( Debug ) ]
163+ pub struct AuthenticationMiddleware {
164+ token_state : Arc < TokenState > ,
165+ }
166+
167+ impl AuthenticationMiddleware {
168+ pub fn new ( auth_config : AuthConfig ) -> Self {
169+ let token_state = Arc :: new ( TokenState :: new ( auth_config) ) ;
170+ Self { token_state }
171+ }
172+
173+ fn is_auth_endpoint ( & self , req : & reqwest:: Request ) -> bool {
174+ let path = req. url ( ) . path ( ) ;
175+ // Skip authentication for auth endpoints to prevent infinite loops
176+ path. contains ( "/sessions" )
177+ }
178+ }
179+
180+ #[ async_trait]
181+ impl Middleware for AuthenticationMiddleware {
182+ async fn handle (
183+ & self ,
184+ mut req : reqwest:: Request ,
185+ extensions : & mut Extensions ,
186+ next : Next < ' _ > ,
187+ ) -> Result < Response , Error > {
188+ // Skip authentication for auth endpoints to prevent infinite loops
189+ if self . is_auth_endpoint ( & req) {
190+ debug ! (
191+ "Skipping auth for authentication endpoint: {}" ,
192+ req. url( ) . path( )
193+ ) ;
194+ return next. run ( req, extensions) . await ;
195+ }
196+
197+ // Add Authorization header if not present
198+ if !req. headers ( ) . contains_key ( "Authorization" ) {
199+ match self . token_state . get_valid_token ( ) . await {
200+ Ok ( token) => {
201+ debug ! ( "Adding authentication token to request" ) ;
202+ req. headers_mut ( ) . insert (
203+ "Authorization" ,
204+ format ! ( "Bearer {}" , token) . parse ( ) . map_err ( |e| {
205+ Error :: Middleware ( anyhow:: anyhow!(
206+ "Invalid token format: {}" ,
207+ e
208+ ) )
209+ } ) ?,
210+ ) ;
211+ }
212+ Err ( e) => {
213+ warn ! ( "Failed to get auth token: {}" , e) ;
214+ return Err ( Error :: Middleware ( anyhow:: anyhow!(
215+ "Authentication failed: {}" ,
216+ e
217+ ) ) ) ;
218+ }
219+ }
220+ }
221+
222+ let response = next. run ( req, extensions) . await ?;
223+
224+ // Handle 401 responses by clearing token
225+ if response. status ( ) == StatusCode :: UNAUTHORIZED {
226+ warn ! ( "Received 401, clearing token for future requests" ) ;
227+ self . token_state . clear_token ( ) . await ;
228+ // Note: We don't retry here to avoid infinite loops
229+ // The retry will happen naturally on the next request
230+ }
231+
232+ Ok ( response)
233+ }
234+ }
235+
85236/// A client that transparently handles retries with exponential backoff.
86237#[ derive( Debug , Clone ) ]
87238pub struct ResilientClient {
@@ -120,6 +271,46 @@ impl ResilientClient {
120271 }
121272 }
122273
274+ /// Creates a new client with optional authentication middleware
275+ pub fn new_with_auth (
276+ client : Option < Client > ,
277+ auth_config : Option < AuthConfig > ,
278+ initial_delay : std:: time:: Duration ,
279+ max_retries : u32 ,
280+ success_codes : & [ StatusCode ] ,
281+ max_delay : Option < std:: time:: Duration > ,
282+ ) -> Self {
283+ let base_client = client. unwrap_or_default ( ) ;
284+ let final_max_delay = max_delay. unwrap_or ( DEFAULT_MAX_DELAY ) ;
285+
286+ let retry_policy = ExponentialBackoff :: builder ( )
287+ . retry_bounds ( initial_delay, final_max_delay)
288+ . jitter ( Jitter :: None )
289+ . build_with_max_retries ( max_retries) ;
290+
291+ let mut builder = ClientBuilder :: new ( base_client) . with (
292+ RetryTransientMiddleware :: new_with_policy_and_strategy (
293+ retry_policy,
294+ StopOnSuccessStrategy {
295+ success_codes : success_codes. to_vec ( ) ,
296+ } ,
297+ ) ,
298+ ) ;
299+
300+ // Add authentication middleware if config is provided
301+ if let Some ( auth_cfg) = auth_config {
302+ debug ! ( "Adding authentication middleware to client" ) ;
303+ let auth_middleware = AuthenticationMiddleware :: new ( auth_cfg) ;
304+ builder = builder. with ( auth_middleware) ;
305+ }
306+
307+ let client_with_middleware = builder. with ( LoggingMiddleware ) . build ( ) ;
308+
309+ Self {
310+ client : client_with_middleware,
311+ }
312+ }
313+
123314 /// Generates a six-character lowercase alphanumeric request ID.
124315 fn generate_request_id ( ) -> String {
125316 const CHARSET : & [ u8 ] = b"abcdefghijklmnopqrstuvwxyz0123456789" ;
@@ -502,4 +693,79 @@ mod tests {
502693 assert ! ( response. is_ok( ) ) ;
503694 assert_eq ! ( response. unwrap( ) . status( ) , StatusCode :: OK ) ; //#[allow_ci]
504695 }
696+
697+ #[ tokio:: test]
698+ async fn test_resilient_client_with_auth_config ( ) {
699+ use crate :: auth:: AuthConfig ;
700+
701+ let auth_config = AuthConfig {
702+ verifier_base_url : "https://verifier.example.com" . to_string ( ) ,
703+ agent_id : "test-agent" . to_string ( ) ,
704+ avoid_tpm : true ,
705+ timeout_ms : 5000 ,
706+ token_refresh_buffer_minutes : 5 ,
707+ max_auth_retries : 3 ,
708+ } ;
709+
710+ // Test with authentication
711+ let _client_with_auth = ResilientClient :: new_with_auth (
712+ None ,
713+ Some ( auth_config) ,
714+ std:: time:: Duration :: from_millis ( 10 ) ,
715+ 3 ,
716+ & [ StatusCode :: OK ] ,
717+ None ,
718+ ) ;
719+
720+ // Verify the client was created successfully
721+ }
722+
723+ #[ tokio:: test]
724+ async fn test_resilient_client_without_auth_config ( ) {
725+ // Test without authentication (should behave like the original client)
726+ let _client_without_auth = ResilientClient :: new_with_auth (
727+ None ,
728+ None , // No auth config
729+ std:: time:: Duration :: from_millis ( 10 ) ,
730+ 3 ,
731+ & [ StatusCode :: OK ] ,
732+ None ,
733+ ) ;
734+
735+ // Verify the client was created successfully
736+ }
737+
738+ #[ tokio:: test]
739+ async fn test_authentication_middleware_path_detection ( ) {
740+ use crate :: auth:: AuthConfig ;
741+
742+ let auth_config = AuthConfig {
743+ verifier_base_url : "https://verifier.example.com" . to_string ( ) ,
744+ agent_id : "test-agent" . to_string ( ) ,
745+ avoid_tpm : true ,
746+ timeout_ms : 5000 ,
747+ token_refresh_buffer_minutes : 5 ,
748+ max_auth_retries : 3 ,
749+ } ;
750+
751+ let middleware = AuthenticationMiddleware :: new ( auth_config) ;
752+
753+ // Mock a request to a sessions endpoint (should be detected as auth endpoint)
754+ let mock_request = reqwest:: Request :: new (
755+ Method :: POST ,
756+ "https://verifier.example.com/v3.0/sessions"
757+ . parse ( )
758+ . unwrap ( ) , //#[allow_ci]
759+ ) ;
760+ assert ! ( middleware. is_auth_endpoint( & mock_request) ) ;
761+
762+ // Mock a request to a non-auth endpoint
763+ let mock_request2 = reqwest:: Request :: new (
764+ Method :: GET ,
765+ "https://verifier.example.com/v3.0/agents/123/attestations"
766+ . parse ( )
767+ . unwrap ( ) , //#[allow_ci]
768+ ) ;
769+ assert ! ( !middleware. is_auth_endpoint( & mock_request2) ) ;
770+ }
505771}
0 commit comments