Skip to content

Commit e3321cd

Browse files
committed
Add OAuth token caching and automatic refresh
- Add TokenCache for persisting OAuth tokens to disk - Add TokenProvider trait for dynamic token management - Implement OAuthTokenProvider with automatic token refresh - Update HttpTransport to use dynamic token provider - Add token refresh support in OAuthDiscovery - Fix control flow for OAuth re-authentication when refresh fails - Upgrade all dependencies to latest versions
1 parent a8ab770 commit e3321cd

File tree

11 files changed

+833
-405
lines changed

11 files changed

+833
-405
lines changed

Cargo.lock

Lines changed: 256 additions & 370 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/mcp-client/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "mcp-client"
3-
version = "0.1.0"
3+
version = "0.2.0"
44
edition = "2021"
55
authors = ["Rakibul Yeasin <ryeasin03@gmail.com>"]
66

crates/mcp-client/src/auth.rs

Lines changed: 285 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ use crate::error::{ClientError, Result};
22
use reqwest::Client;
33
use serde::{Deserialize, Serialize};
44
use sha2::{Digest, Sha256};
5+
use std::path::PathBuf;
56
use std::sync::Arc;
67
use std::time::{SystemTime, UNIX_EPOCH};
78
use tokio::sync::RwLock;
8-
use tracing::{debug, info};
9+
use tracing::{debug, info, warn};
910

1011
#[derive(Debug, Clone, Serialize, Deserialize)]
1112
pub struct OAuthClientConfig {
@@ -384,6 +385,289 @@ impl OAuthClient {
384385
}
385386
}
386387

388+
/// Trait for providing OAuth tokens dynamically
389+
#[async_trait::async_trait]
390+
pub trait TokenProvider: Send + Sync {
391+
/// Get a valid access token, refreshing if necessary
392+
async fn get_token(&self) -> Result<String>;
393+
}
394+
395+
/// Static token provider - just returns a fixed token string
396+
pub struct StaticTokenProvider {
397+
token: String,
398+
}
399+
400+
impl StaticTokenProvider {
401+
pub fn new(token: String) -> Self {
402+
Self { token }
403+
}
404+
}
405+
406+
#[async_trait::async_trait]
407+
impl TokenProvider for StaticTokenProvider {
408+
async fn get_token(&self) -> Result<String> {
409+
Ok(self.token.clone())
410+
}
411+
}
412+
413+
/// OAuth token provider - manages token lifecycle with automatic refresh
414+
pub struct OAuthTokenProvider {
415+
server_name: String,
416+
oauth_client: Arc<OAuthClient>,
417+
token_cache: Arc<TokenCache>,
418+
current_token: Arc<RwLock<Option<ClientToken>>>,
419+
}
420+
421+
impl OAuthTokenProvider {
422+
pub fn new(
423+
server_name: String,
424+
oauth_client: Arc<OAuthClient>,
425+
token_cache: Arc<TokenCache>,
426+
initial_token: Option<ClientToken>,
427+
) -> Self {
428+
Self {
429+
server_name,
430+
oauth_client,
431+
token_cache,
432+
current_token: Arc::new(RwLock::new(initial_token)),
433+
}
434+
}
435+
436+
/// Check if token needs refresh (expired or expiring soon)
437+
fn needs_refresh(token: &ClientToken) -> bool {
438+
if let Some(expires_at) = token.expires_at {
439+
let now = SystemTime::now()
440+
.duration_since(UNIX_EPOCH)
441+
.map(|d| d.as_secs())
442+
.unwrap_or(0);
443+
// Refresh 60 seconds before expiry
444+
now + 60 >= expires_at
445+
} else {
446+
false
447+
}
448+
}
449+
}
450+
451+
#[async_trait::async_trait]
452+
impl TokenProvider for OAuthTokenProvider {
453+
async fn get_token(&self) -> Result<String> {
454+
// Check current token
455+
{
456+
let token_guard = self.current_token.read().await;
457+
if let Some(ref token) = *token_guard {
458+
if !Self::needs_refresh(token) {
459+
return Ok(token.access_token.clone());
460+
}
461+
}
462+
}
463+
464+
// Need to refresh - acquire write lock
465+
let mut token_guard = self.current_token.write().await;
466+
467+
// Double-check after acquiring write lock
468+
if let Some(ref token) = *token_guard {
469+
if !Self::needs_refresh(token) {
470+
return Ok(token.access_token.clone());
471+
}
472+
473+
// Try to refresh
474+
if token.refresh_token.is_some() {
475+
debug!("Refreshing expired token for: {}", self.server_name);
476+
self.oauth_client.set_token(token.clone()).await;
477+
478+
match self.oauth_client.refresh_token().await {
479+
Ok(new_token) => {
480+
info!("Token refreshed for: {}", self.server_name);
481+
// Update cache
482+
if let Err(e) = self.token_cache.save(&self.server_name, &new_token) {
483+
warn!("Failed to update token cache: {}", e);
484+
}
485+
let access_token = new_token.access_token.clone();
486+
*token_guard = Some(new_token);
487+
return Ok(access_token);
488+
}
489+
Err(e) => {
490+
warn!("Token refresh failed for '{}': {}", self.server_name, e);
491+
}
492+
}
493+
}
494+
}
495+
496+
// No valid token available
497+
Err(ClientError::OAuthError(
498+
"No valid token available and refresh failed".to_string(),
499+
))
500+
}
501+
}
502+
503+
/// Cached token entry with metadata
504+
#[derive(Debug, Clone, Serialize, Deserialize)]
505+
pub struct CachedToken {
506+
pub token: ClientToken,
507+
pub server_name: String,
508+
pub created_at: u64,
509+
}
510+
511+
/// Token cache for persisting OAuth tokens to disk
512+
pub struct TokenCache {
513+
cache_dir: PathBuf,
514+
}
515+
516+
impl TokenCache {
517+
/// Create a new token cache using the default cache directory
518+
pub fn new() -> Result<Self> {
519+
let cache_dir = Self::default_cache_dir()?;
520+
Self::with_dir(cache_dir)
521+
}
522+
523+
/// Create a new token cache with a custom directory
524+
pub fn with_dir(cache_dir: PathBuf) -> Result<Self> {
525+
std::fs::create_dir_all(&cache_dir)
526+
.map_err(|e| ClientError::OAuthError(format!("Failed to create cache directory: {}", e)))?;
527+
Ok(Self { cache_dir })
528+
}
529+
530+
/// Get the default cache directory
531+
fn default_cache_dir() -> Result<PathBuf> {
532+
let base = dirs::cache_dir()
533+
.or_else(dirs::home_dir)
534+
.ok_or_else(|| ClientError::OAuthError("Cannot determine cache directory".to_string()))?;
535+
Ok(base.join("mcp-connect").join("tokens"))
536+
}
537+
538+
/// Generate a cache key from server name
539+
fn cache_key(server_name: &str) -> String {
540+
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
541+
let mut hasher = Sha256::new();
542+
hasher.update(server_name.as_bytes());
543+
let hash = hasher.finalize();
544+
URL_SAFE_NO_PAD.encode(&hash[..16])
545+
}
546+
547+
/// Get the cache file path for a server
548+
fn cache_path(&self, server_name: &str) -> PathBuf {
549+
let key = Self::cache_key(server_name);
550+
self.cache_dir.join(format!("{}.json", key))
551+
}
552+
553+
/// Load a cached token for a server
554+
pub fn load(&self, server_name: &str) -> Option<CachedToken> {
555+
let path = self.cache_path(server_name);
556+
match std::fs::read_to_string(&path) {
557+
Ok(content) => {
558+
match serde_json::from_str::<CachedToken>(&content) {
559+
Ok(cached) => {
560+
debug!("Loaded cached token for server: {}", server_name);
561+
Some(cached)
562+
}
563+
Err(e) => {
564+
warn!("Failed to parse cached token: {}", e);
565+
None
566+
}
567+
}
568+
}
569+
Err(_) => None,
570+
}
571+
}
572+
573+
/// Save a token to the cache
574+
pub fn save(&self, server_name: &str, token: &ClientToken) -> Result<()> {
575+
let now = SystemTime::now()
576+
.duration_since(UNIX_EPOCH)
577+
.map(|d| d.as_secs())
578+
.unwrap_or(0);
579+
580+
let cached = CachedToken {
581+
token: token.clone(),
582+
server_name: server_name.to_string(),
583+
created_at: now,
584+
};
585+
586+
let path = self.cache_path(server_name);
587+
let content = serde_json::to_string_pretty(&cached)
588+
.map_err(|e| ClientError::OAuthError(format!("Failed to serialize token: {}", e)))?;
589+
590+
std::fs::write(&path, content)
591+
.map_err(|e| ClientError::OAuthError(format!("Failed to write token cache: {}", e)))?;
592+
593+
info!("Saved token to cache for server: {}", server_name);
594+
Ok(())
595+
}
596+
597+
/// Remove a cached token
598+
pub fn remove(&self, server_name: &str) -> Result<()> {
599+
let path = self.cache_path(server_name);
600+
if path.exists() {
601+
std::fs::remove_file(&path)
602+
.map_err(|e| ClientError::OAuthError(format!("Failed to remove cached token: {}", e)))?;
603+
debug!("Removed cached token for server: {}", server_name);
604+
}
605+
Ok(())
606+
}
607+
608+
/// Check if a cached token is still valid (not expired)
609+
pub fn is_token_valid(token: &ClientToken) -> bool {
610+
if let Some(expires_at) = token.expires_at {
611+
let now = SystemTime::now()
612+
.duration_since(UNIX_EPOCH)
613+
.map(|d| d.as_secs())
614+
.unwrap_or(0);
615+
// Consider token expired 60 seconds before actual expiry for safety
616+
now + 60 < expires_at
617+
} else {
618+
true // No expiration means valid
619+
}
620+
}
621+
622+
/// Load a valid token, or return None if expired/missing
623+
pub fn load_valid(&self, server_name: &str) -> Option<ClientToken> {
624+
self.load(server_name).and_then(|cached| {
625+
if Self::is_token_valid(&cached.token) {
626+
Some(cached.token)
627+
} else {
628+
debug!("Cached token for '{}' is expired", server_name);
629+
None
630+
}
631+
})
632+
}
633+
634+
/// Load token and refresh if expired (requires OAuthClient)
635+
pub async fn load_or_refresh(
636+
&self,
637+
server_name: &str,
638+
oauth_client: &OAuthClient,
639+
) -> Result<ClientToken> {
640+
if let Some(cached) = self.load(server_name) {
641+
if Self::is_token_valid(&cached.token) {
642+
debug!("Using valid cached token for: {}", server_name);
643+
return Ok(cached.token);
644+
}
645+
646+
// Token expired, try to refresh
647+
if cached.token.refresh_token.is_some() {
648+
debug!("Attempting to refresh expired token for: {}", server_name);
649+
oauth_client.set_token(cached.token).await;
650+
match oauth_client.refresh_token().await {
651+
Ok(new_token) => {
652+
self.save(server_name, &new_token)?;
653+
return Ok(new_token);
654+
}
655+
Err(e) => {
656+
warn!("Failed to refresh token for '{}': {}", server_name, e);
657+
// Remove invalid cached token
658+
let _ = self.remove(server_name);
659+
}
660+
}
661+
} else {
662+
debug!("No refresh token available for: {}", server_name);
663+
let _ = self.remove(server_name);
664+
}
665+
}
666+
667+
Err(ClientError::OAuthError("No valid cached token available".to_string()))
668+
}
669+
}
670+
387671
#[cfg(test)]
388672
mod tests {
389673
use super::*;

crates/mcp-client/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,8 @@ pub mod oauth_discovery;
77
pub use client::McpRemoteClient;
88
pub use error::ClientError;
99
pub use transport::{HttpTransport, StdioTransport, TcpTransport};
10-
pub use auth::{OAuthClient, OAuthClientConfig, ClientToken};
10+
pub use auth::{
11+
OAuthClient, OAuthClientConfig, ClientToken, TokenCache, CachedToken,
12+
TokenProvider, StaticTokenProvider, OAuthTokenProvider,
13+
};
1114
pub use oauth_discovery::{OAuthDiscovery, OAuthMetadata, OAuthRequirement};

0 commit comments

Comments
 (0)