@@ -2,10 +2,11 @@ use crate::error::{ClientError, Result};
22use reqwest:: Client ;
33use serde:: { Deserialize , Serialize } ;
44use sha2:: { Digest , Sha256 } ;
5+ use std:: path:: PathBuf ;
56use std:: sync:: Arc ;
67use std:: time:: { SystemTime , UNIX_EPOCH } ;
78use tokio:: sync:: RwLock ;
8- use tracing:: { debug, info} ;
9+ use tracing:: { debug, info, warn } ;
910
1011#[ derive( Debug , Clone , Serialize , Deserialize ) ]
1112pub struct OAuthClientConfig {
@@ -384,6 +385,289 @@ impl OAuthClient {
384385 }
385386}
386387
388+ /// Trait for providing OAuth tokens dynamically
389+ #[ async_trait:: async_trait]
390+ pub trait TokenProvider : Send + Sync {
391+ /// Get a valid access token, refreshing if necessary
392+ async fn get_token ( & self ) -> Result < String > ;
393+ }
394+
395+ /// Static token provider - just returns a fixed token string
396+ pub struct StaticTokenProvider {
397+ token : String ,
398+ }
399+
400+ impl StaticTokenProvider {
401+ pub fn new ( token : String ) -> Self {
402+ Self { token }
403+ }
404+ }
405+
406+ #[ async_trait:: async_trait]
407+ impl TokenProvider for StaticTokenProvider {
408+ async fn get_token ( & self ) -> Result < String > {
409+ Ok ( self . token . clone ( ) )
410+ }
411+ }
412+
413+ /// OAuth token provider - manages token lifecycle with automatic refresh
414+ pub struct OAuthTokenProvider {
415+ server_name : String ,
416+ oauth_client : Arc < OAuthClient > ,
417+ token_cache : Arc < TokenCache > ,
418+ current_token : Arc < RwLock < Option < ClientToken > > > ,
419+ }
420+
421+ impl OAuthTokenProvider {
422+ pub fn new (
423+ server_name : String ,
424+ oauth_client : Arc < OAuthClient > ,
425+ token_cache : Arc < TokenCache > ,
426+ initial_token : Option < ClientToken > ,
427+ ) -> Self {
428+ Self {
429+ server_name,
430+ oauth_client,
431+ token_cache,
432+ current_token : Arc :: new ( RwLock :: new ( initial_token) ) ,
433+ }
434+ }
435+
436+ /// Check if token needs refresh (expired or expiring soon)
437+ fn needs_refresh ( token : & ClientToken ) -> bool {
438+ if let Some ( expires_at) = token. expires_at {
439+ let now = SystemTime :: now ( )
440+ . duration_since ( UNIX_EPOCH )
441+ . map ( |d| d. as_secs ( ) )
442+ . unwrap_or ( 0 ) ;
443+ // Refresh 60 seconds before expiry
444+ now + 60 >= expires_at
445+ } else {
446+ false
447+ }
448+ }
449+ }
450+
451+ #[ async_trait:: async_trait]
452+ impl TokenProvider for OAuthTokenProvider {
453+ async fn get_token ( & self ) -> Result < String > {
454+ // Check current token
455+ {
456+ let token_guard = self . current_token . read ( ) . await ;
457+ if let Some ( ref token) = * token_guard {
458+ if !Self :: needs_refresh ( token) {
459+ return Ok ( token. access_token . clone ( ) ) ;
460+ }
461+ }
462+ }
463+
464+ // Need to refresh - acquire write lock
465+ let mut token_guard = self . current_token . write ( ) . await ;
466+
467+ // Double-check after acquiring write lock
468+ if let Some ( ref token) = * token_guard {
469+ if !Self :: needs_refresh ( token) {
470+ return Ok ( token. access_token . clone ( ) ) ;
471+ }
472+
473+ // Try to refresh
474+ if token. refresh_token . is_some ( ) {
475+ debug ! ( "Refreshing expired token for: {}" , self . server_name) ;
476+ self . oauth_client . set_token ( token. clone ( ) ) . await ;
477+
478+ match self . oauth_client . refresh_token ( ) . await {
479+ Ok ( new_token) => {
480+ info ! ( "Token refreshed for: {}" , self . server_name) ;
481+ // Update cache
482+ if let Err ( e) = self . token_cache . save ( & self . server_name , & new_token) {
483+ warn ! ( "Failed to update token cache: {}" , e) ;
484+ }
485+ let access_token = new_token. access_token . clone ( ) ;
486+ * token_guard = Some ( new_token) ;
487+ return Ok ( access_token) ;
488+ }
489+ Err ( e) => {
490+ warn ! ( "Token refresh failed for '{}': {}" , self . server_name, e) ;
491+ }
492+ }
493+ }
494+ }
495+
496+ // No valid token available
497+ Err ( ClientError :: OAuthError (
498+ "No valid token available and refresh failed" . to_string ( ) ,
499+ ) )
500+ }
501+ }
502+
503+ /// Cached token entry with metadata
504+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
505+ pub struct CachedToken {
506+ pub token : ClientToken ,
507+ pub server_name : String ,
508+ pub created_at : u64 ,
509+ }
510+
511+ /// Token cache for persisting OAuth tokens to disk
512+ pub struct TokenCache {
513+ cache_dir : PathBuf ,
514+ }
515+
516+ impl TokenCache {
517+ /// Create a new token cache using the default cache directory
518+ pub fn new ( ) -> Result < Self > {
519+ let cache_dir = Self :: default_cache_dir ( ) ?;
520+ Self :: with_dir ( cache_dir)
521+ }
522+
523+ /// Create a new token cache with a custom directory
524+ pub fn with_dir ( cache_dir : PathBuf ) -> Result < Self > {
525+ std:: fs:: create_dir_all ( & cache_dir)
526+ . map_err ( |e| ClientError :: OAuthError ( format ! ( "Failed to create cache directory: {}" , e) ) ) ?;
527+ Ok ( Self { cache_dir } )
528+ }
529+
530+ /// Get the default cache directory
531+ fn default_cache_dir ( ) -> Result < PathBuf > {
532+ let base = dirs:: cache_dir ( )
533+ . or_else ( dirs:: home_dir)
534+ . ok_or_else ( || ClientError :: OAuthError ( "Cannot determine cache directory" . to_string ( ) ) ) ?;
535+ Ok ( base. join ( "mcp-connect" ) . join ( "tokens" ) )
536+ }
537+
538+ /// Generate a cache key from server name
539+ fn cache_key ( server_name : & str ) -> String {
540+ use base64:: { engine:: general_purpose:: URL_SAFE_NO_PAD , Engine } ;
541+ let mut hasher = Sha256 :: new ( ) ;
542+ hasher. update ( server_name. as_bytes ( ) ) ;
543+ let hash = hasher. finalize ( ) ;
544+ URL_SAFE_NO_PAD . encode ( & hash[ ..16 ] )
545+ }
546+
547+ /// Get the cache file path for a server
548+ fn cache_path ( & self , server_name : & str ) -> PathBuf {
549+ let key = Self :: cache_key ( server_name) ;
550+ self . cache_dir . join ( format ! ( "{}.json" , key) )
551+ }
552+
553+ /// Load a cached token for a server
554+ pub fn load ( & self , server_name : & str ) -> Option < CachedToken > {
555+ let path = self . cache_path ( server_name) ;
556+ match std:: fs:: read_to_string ( & path) {
557+ Ok ( content) => {
558+ match serde_json:: from_str :: < CachedToken > ( & content) {
559+ Ok ( cached) => {
560+ debug ! ( "Loaded cached token for server: {}" , server_name) ;
561+ Some ( cached)
562+ }
563+ Err ( e) => {
564+ warn ! ( "Failed to parse cached token: {}" , e) ;
565+ None
566+ }
567+ }
568+ }
569+ Err ( _) => None ,
570+ }
571+ }
572+
573+ /// Save a token to the cache
574+ pub fn save ( & self , server_name : & str , token : & ClientToken ) -> Result < ( ) > {
575+ let now = SystemTime :: now ( )
576+ . duration_since ( UNIX_EPOCH )
577+ . map ( |d| d. as_secs ( ) )
578+ . unwrap_or ( 0 ) ;
579+
580+ let cached = CachedToken {
581+ token : token. clone ( ) ,
582+ server_name : server_name. to_string ( ) ,
583+ created_at : now,
584+ } ;
585+
586+ let path = self . cache_path ( server_name) ;
587+ let content = serde_json:: to_string_pretty ( & cached)
588+ . map_err ( |e| ClientError :: OAuthError ( format ! ( "Failed to serialize token: {}" , e) ) ) ?;
589+
590+ std:: fs:: write ( & path, content)
591+ . map_err ( |e| ClientError :: OAuthError ( format ! ( "Failed to write token cache: {}" , e) ) ) ?;
592+
593+ info ! ( "Saved token to cache for server: {}" , server_name) ;
594+ Ok ( ( ) )
595+ }
596+
597+ /// Remove a cached token
598+ pub fn remove ( & self , server_name : & str ) -> Result < ( ) > {
599+ let path = self . cache_path ( server_name) ;
600+ if path. exists ( ) {
601+ std:: fs:: remove_file ( & path)
602+ . map_err ( |e| ClientError :: OAuthError ( format ! ( "Failed to remove cached token: {}" , e) ) ) ?;
603+ debug ! ( "Removed cached token for server: {}" , server_name) ;
604+ }
605+ Ok ( ( ) )
606+ }
607+
608+ /// Check if a cached token is still valid (not expired)
609+ pub fn is_token_valid ( token : & ClientToken ) -> bool {
610+ if let Some ( expires_at) = token. expires_at {
611+ let now = SystemTime :: now ( )
612+ . duration_since ( UNIX_EPOCH )
613+ . map ( |d| d. as_secs ( ) )
614+ . unwrap_or ( 0 ) ;
615+ // Consider token expired 60 seconds before actual expiry for safety
616+ now + 60 < expires_at
617+ } else {
618+ true // No expiration means valid
619+ }
620+ }
621+
622+ /// Load a valid token, or return None if expired/missing
623+ pub fn load_valid ( & self , server_name : & str ) -> Option < ClientToken > {
624+ self . load ( server_name) . and_then ( |cached| {
625+ if Self :: is_token_valid ( & cached. token ) {
626+ Some ( cached. token )
627+ } else {
628+ debug ! ( "Cached token for '{}' is expired" , server_name) ;
629+ None
630+ }
631+ } )
632+ }
633+
634+ /// Load token and refresh if expired (requires OAuthClient)
635+ pub async fn load_or_refresh (
636+ & self ,
637+ server_name : & str ,
638+ oauth_client : & OAuthClient ,
639+ ) -> Result < ClientToken > {
640+ if let Some ( cached) = self . load ( server_name) {
641+ if Self :: is_token_valid ( & cached. token ) {
642+ debug ! ( "Using valid cached token for: {}" , server_name) ;
643+ return Ok ( cached. token ) ;
644+ }
645+
646+ // Token expired, try to refresh
647+ if cached. token . refresh_token . is_some ( ) {
648+ debug ! ( "Attempting to refresh expired token for: {}" , server_name) ;
649+ oauth_client. set_token ( cached. token ) . await ;
650+ match oauth_client. refresh_token ( ) . await {
651+ Ok ( new_token) => {
652+ self . save ( server_name, & new_token) ?;
653+ return Ok ( new_token) ;
654+ }
655+ Err ( e) => {
656+ warn ! ( "Failed to refresh token for '{}': {}" , server_name, e) ;
657+ // Remove invalid cached token
658+ let _ = self . remove ( server_name) ;
659+ }
660+ }
661+ } else {
662+ debug ! ( "No refresh token available for: {}" , server_name) ;
663+ let _ = self . remove ( server_name) ;
664+ }
665+ }
666+
667+ Err ( ClientError :: OAuthError ( "No valid cached token available" . to_string ( ) ) )
668+ }
669+ }
670+
387671#[ cfg( test) ]
388672mod tests {
389673 use super :: * ;
0 commit comments