Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions crates/chat-cli/src/mcp_client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::collections::HashMap;
use std::process::Stdio;

use regex::Regex;
use reqwest::Client;
use rmcp::model::{
CallToolRequestParam,
CallToolResult,
Expand All @@ -25,7 +24,6 @@ use rmcp::service::{
DynService,
NotificationContext,
};
use rmcp::transport::auth::AuthClient;
use rmcp::transport::{
ConfigureCommandExt,
TokioChildProcess,
Expand All @@ -52,7 +50,7 @@ use tracing::{
use super::messenger::Messenger;
use super::oauth_util::HttpTransport;
use super::{
AuthClientDropGuard,
AuthClientWrapper,
OauthUtilError,
get_http_transport,
};
Expand Down Expand Up @@ -175,10 +173,11 @@ macro_rules! decorate_with_auth_retry {
Err(e) => {
// TODO: discern error type prior to retrying
// Not entirely sure what is thrown when auth is required
if let Some(auth_client) = self.get_auth_client() {
let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await;
if let Some(auth_client) = self.auth_client.as_ref() {
let refresh_result = auth_client.refresh_token().await;
match refresh_result {
Ok(_) => {
info!("Token refreshed");
// Retry the operation after token refresh
match &self.inner_service {
InnerService::Original(rs) => rs.$method_name(param).await,
Expand Down Expand Up @@ -245,20 +244,14 @@ impl Clone for InnerService {
#[derive(Debug)]
pub struct RunningService {
pub inner_service: InnerService,
auth_dropguard: Option<AuthClientDropGuard>,
auth_client: Option<AuthClientWrapper>,
}

impl Clone for RunningService {
fn clone(&self) -> Self {
let auth_dropguard = self.auth_dropguard.as_ref().map(|dg| {
let mut dg = dg.clone();
dg.should_write = false;
dg
});

RunningService {
inner_service: self.inner_service.clone(),
auth_dropguard,
auth_client: self.auth_client.clone(),
}
}
}
Expand All @@ -267,10 +260,6 @@ impl RunningService {
decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult);

decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult);

pub fn get_auth_client(&self) -> Option<AuthClient<Client>> {
self.auth_dropguard.as_ref().map(|a| a.auth_client.clone())
}
}

pub type StdioTransport = (TokioChildProcess, Option<ChildStderr>);
Expand Down Expand Up @@ -341,32 +330,30 @@ impl McpClientService {
},
Transport::Http(http_transport) => {
match http_transport {
HttpTransport::WithAuth((transport, mut auth_dg)) => {
HttpTransport::WithAuth((transport, mut auth_client)) => {
// The crate does not automatically refresh tokens when they expire. We
// would need to handle that here
let url = self.config.url.clone();
let service = match self.into_dyn().serve(transport).await.map_err(Box::new) {
Ok(service) => service,
Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => {
debug!("## mcp: first hand shake attempt failed: {:?}", e);
let refresh_res =
auth_dg.auth_client.auth_manager.lock().await.refresh_token().await;
let refresh_res = auth_client.refresh_token().await;
let new_self = McpClientService::new(
server_name.clone(),
backup_config,
messenger_clone.clone(),
);

let new_transport =
get_http_transport(&os_clone, true, &url, Some(auth_dg.auth_client.clone()), &*messenger_dup).await?;
get_http_transport(&os_clone, true, &url, Some(auth_client.auth_client.clone()), &*messenger_dup).await?;

match new_transport {
HttpTransport::WithAuth((new_transport, new_auth_dg)) => {
auth_dg.should_write = false;
auth_dg = new_auth_dg;
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
auth_client = new_auth_client;

match refresh_res {
Ok(_token) => {
Ok(_) => {
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
},
Err(e) => {
Expand All @@ -379,9 +366,8 @@ impl McpClientService {
get_http_transport(&os_clone, true, &url, None, &*messenger_dup).await?;

match new_transport {
HttpTransport::WithAuth((new_transport, new_auth_dg)) => {
auth_dg = new_auth_dg;
auth_dg.should_write = false;
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
auth_client = new_auth_client;
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
},
HttpTransport::WithoutAuth(new_transport) => {
Expand All @@ -398,7 +384,7 @@ impl McpClientService {
Err(e) => return Err(e.into()),
};

(service, None, Some(auth_dg))
(service, None, Some(auth_client))
},
HttpTransport::WithoutAuth(transport) => {
let service = self.into_dyn().serve(transport).await.map_err(Box::new)?;
Expand Down Expand Up @@ -496,7 +482,7 @@ impl McpClientService {

Ok(RunningService {
inner_service: InnerService::Original(service),
auth_dropguard,
auth_client: auth_dropguard,
})
});

Expand Down
82 changes: 28 additions & 54 deletions crates/chat-cli/src/mcp_client/oauth_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pub enum OauthUtilError {
Reqwest(#[from] reqwest::Error),
#[error("Malformed directory")]
MalformDirectory,
#[error("Missing credential")]
MissingCredentials,
}

/// A guard that automatically cancels the cancellation token when dropped.
Expand Down Expand Up @@ -107,68 +109,36 @@ impl From<OAuthClientConfig> for Registration {
}
}

/// A guard that manages the lifecycle of an authenticated MCP client and automatically
/// persists OAuth credentials when dropped.
/// A wrapper that manages an authenticated MCP client.
///
/// This struct wraps an `AuthClient` and ensures that OAuth tokens are written to disk
/// when the guard goes out of scope, unless explicitly disabled via `should_write`.
/// This provides automatic credential caching for MCP server connections that require
/// OAuth authentication.
/// This struct wraps an `AuthClient` and provides access to OAuth credentials
/// for MCP server connections that require authentication. The credentials
/// are managed separately from this wrapper's lifecycle.
#[derive(Clone, Debug)]
pub struct AuthClientDropGuard {
pub should_write: bool,
pub struct AuthClientWrapper {
pub cred_full_path: PathBuf,
pub auth_client: AuthClient<Client>,
}

impl AuthClientDropGuard {
impl AuthClientWrapper {
pub fn new(cred_full_path: PathBuf, auth_client: AuthClient<Client>) -> Self {
Self {
should_write: true,
cred_full_path,
auth_client,
}
}
}

impl Drop for AuthClientDropGuard {
fn drop(&mut self) {
if !self.should_write {
return;
}
/// Refreshes token in memory using the registration read from when the auth client was
/// spawned. This also persists the retrieved token
pub async fn refresh_token(&self) -> Result<(), OauthUtilError> {
let cred = self.auth_client.auth_manager.lock().await.refresh_token().await?;
let parent_path = self.cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?;
tokio::fs::create_dir_all(parent_path).await?;

let auth_client_clone = self.auth_client.clone();
let path = self.cred_full_path.clone();
let cred_as_bytes = serde_json::to_string_pretty(&cred)?;
tokio::fs::write(&self.cred_full_path, &cred_as_bytes).await?;

tokio::spawn(async move {
let Ok((client_id, cred)) = auth_client_clone.auth_manager.lock().await.get_credentials().await else {
error!("Failed to retrieve credentials in drop routine");
return;
};
let Some(cred) = cred else {
error!("Failed to retrieve credentials in drop routine from {client_id}");
return;
};
let Some(parent_path) = path.parent() else {
error!("Failed to retrieve parent path for token in drop routine for {client_id}");
return;
};
if let Err(e) = tokio::fs::create_dir_all(parent_path).await {
error!("Error making parent directory for token cache in drop routine for {client_id}: {e}");
return;
}

let serialized_cred = match serde_json::to_string_pretty(&cred) {
Ok(cred) => cred,
Err(e) => {
error!("Failed to serialize credentials for {client_id}: {e}");
return;
},
};
if let Err(e) = tokio::fs::write(path, &serialized_cred).await {
error!("Error making writing token cache in drop routine: {e}");
}
});
Ok(())
}
}

Expand All @@ -186,7 +156,7 @@ pub enum HttpTransport {
WithAuth(
(
WorkerTransport<StreamableHttpClientWorker<AuthClient<Client>>>,
AuthClientDropGuard,
AuthClientWrapper,
),
),
WithoutAuth(WorkerTransport<StreamableHttpClientWorker<Client>>),
Expand Down Expand Up @@ -233,7 +203,7 @@ pub async fn get_http_transport(
..Default::default()
});

let auth_dg = AuthClientDropGuard::new(cred_full_path, auth_client);
let auth_dg = AuthClientWrapper::new(cred_full_path, auth_client);
debug!("## mcp: transport obtained");

Ok(HttpTransport::WithAuth((transport, auth_dg)))
Expand Down Expand Up @@ -276,11 +246,8 @@ async fn get_auth_manager(

// Client registration is done in [start_authorization]
// If we have gotten past that point that means we have the info to persist the
// registration on disk. These are info that we need to refresh stake
// tokens. This is in contrast to tokens, which we only persist when we drop
// the client (because that way we can write once and ensure what is on the
// disk always the most up to date)
let (client_id, _credentials) = am.get_credentials().await?;
// registration on disk.
let (client_id, credentials) = am.get_credentials().await?;
let reg = Registration {
client_id,
client_secret: None,
Expand All @@ -292,6 +259,13 @@ async fn get_auth_manager(
tokio::fs::create_dir_all(reg_parent_path).await?;
tokio::fs::write(reg_full_path, &reg_as_str).await?;

let credentials = credentials.ok_or(OauthUtilError::MissingCredentials)?;

let cred_parent_path = cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?;
tokio::fs::create_dir_all(cred_parent_path).await?;
let reg_as_str = serde_json::to_string_pretty(&credentials)?;
tokio::fs::write(cred_full_path, &reg_as_str).await?;

Ok(am)
},
}
Expand Down
Loading