Skip to content

Commit a9f170c

Browse files
authored
fixes remote mcp creds not being written when obtained (#2878)
1 parent f1b48c9 commit a9f170c

File tree

2 files changed

+44
-84
lines changed

2 files changed

+44
-84
lines changed

crates/chat-cli/src/mcp_client/client.rs

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::collections::HashMap;
33
use std::process::Stdio;
44

55
use regex::Regex;
6-
use reqwest::Client;
76
use rmcp::model::{
87
CallToolRequestParam,
98
CallToolResult,
@@ -25,7 +24,6 @@ use rmcp::service::{
2524
DynService,
2625
NotificationContext,
2726
};
28-
use rmcp::transport::auth::AuthClient;
2927
use rmcp::transport::{
3028
ConfigureCommandExt,
3129
TokioChildProcess,
@@ -52,7 +50,7 @@ use tracing::{
5250
use super::messenger::Messenger;
5351
use super::oauth_util::HttpTransport;
5452
use super::{
55-
AuthClientDropGuard,
53+
AuthClientWrapper,
5654
OauthUtilError,
5755
get_http_transport,
5856
};
@@ -175,10 +173,11 @@ macro_rules! decorate_with_auth_retry {
175173
Err(e) => {
176174
// TODO: discern error type prior to retrying
177175
// Not entirely sure what is thrown when auth is required
178-
if let Some(auth_client) = self.get_auth_client() {
179-
let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await;
176+
if let Some(auth_client) = self.auth_client.as_ref() {
177+
let refresh_result = auth_client.refresh_token().await;
180178
match refresh_result {
181179
Ok(_) => {
180+
info!("Token refreshed");
182181
// Retry the operation after token refresh
183182
match &self.inner_service {
184183
InnerService::Original(rs) => rs.$method_name(param).await,
@@ -245,20 +244,14 @@ impl Clone for InnerService {
245244
#[derive(Debug)]
246245
pub struct RunningService {
247246
pub inner_service: InnerService,
248-
auth_dropguard: Option<AuthClientDropGuard>,
247+
auth_client: Option<AuthClientWrapper>,
249248
}
250249

251250
impl Clone for RunningService {
252251
fn clone(&self) -> Self {
253-
let auth_dropguard = self.auth_dropguard.as_ref().map(|dg| {
254-
let mut dg = dg.clone();
255-
dg.should_write = false;
256-
dg
257-
});
258-
259252
RunningService {
260253
inner_service: self.inner_service.clone(),
261-
auth_dropguard,
254+
auth_client: self.auth_client.clone(),
262255
}
263256
}
264257
}
@@ -267,10 +260,6 @@ impl RunningService {
267260
decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult);
268261

269262
decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult);
270-
271-
pub fn get_auth_client(&self) -> Option<AuthClient<Client>> {
272-
self.auth_dropguard.as_ref().map(|a| a.auth_client.clone())
273-
}
274263
}
275264

276265
pub type StdioTransport = (TokioChildProcess, Option<ChildStderr>);
@@ -341,32 +330,30 @@ impl McpClientService {
341330
},
342331
Transport::Http(http_transport) => {
343332
match http_transport {
344-
HttpTransport::WithAuth((transport, mut auth_dg)) => {
333+
HttpTransport::WithAuth((transport, mut auth_client)) => {
345334
// The crate does not automatically refresh tokens when they expire. We
346335
// would need to handle that here
347336
let url = self.config.url.clone();
348337
let service = match self.into_dyn().serve(transport).await.map_err(Box::new) {
349338
Ok(service) => service,
350339
Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => {
351340
debug!("## mcp: first hand shake attempt failed: {:?}", e);
352-
let refresh_res =
353-
auth_dg.auth_client.auth_manager.lock().await.refresh_token().await;
341+
let refresh_res = auth_client.refresh_token().await;
354342
let new_self = McpClientService::new(
355343
server_name.clone(),
356344
backup_config,
357345
messenger_clone.clone(),
358346
);
359347

360348
let new_transport =
361-
get_http_transport(&os_clone, true, &url, Some(auth_dg.auth_client.clone()), &*messenger_dup).await?;
349+
get_http_transport(&os_clone, true, &url, Some(auth_client.auth_client.clone()), &*messenger_dup).await?;
362350

363351
match new_transport {
364-
HttpTransport::WithAuth((new_transport, new_auth_dg)) => {
365-
auth_dg.should_write = false;
366-
auth_dg = new_auth_dg;
352+
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
353+
auth_client = new_auth_client;
367354

368355
match refresh_res {
369-
Ok(_token) => {
356+
Ok(_) => {
370357
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
371358
},
372359
Err(e) => {
@@ -379,9 +366,8 @@ impl McpClientService {
379366
get_http_transport(&os_clone, true, &url, None, &*messenger_dup).await?;
380367

381368
match new_transport {
382-
HttpTransport::WithAuth((new_transport, new_auth_dg)) => {
383-
auth_dg = new_auth_dg;
384-
auth_dg.should_write = false;
369+
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
370+
auth_client = new_auth_client;
385371
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
386372
},
387373
HttpTransport::WithoutAuth(new_transport) => {
@@ -398,7 +384,7 @@ impl McpClientService {
398384
Err(e) => return Err(e.into()),
399385
};
400386

401-
(service, None, Some(auth_dg))
387+
(service, None, Some(auth_client))
402388
},
403389
HttpTransport::WithoutAuth(transport) => {
404390
let service = self.into_dyn().serve(transport).await.map_err(Box::new)?;
@@ -496,7 +482,7 @@ impl McpClientService {
496482

497483
Ok(RunningService {
498484
inner_service: InnerService::Original(service),
499-
auth_dropguard,
485+
auth_client: auth_dropguard,
500486
})
501487
});
502488

crates/chat-cli/src/mcp_client/oauth_util.rs

Lines changed: 28 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ pub enum OauthUtilError {
7171
Reqwest(#[from] reqwest::Error),
7272
#[error("Malformed directory")]
7373
MalformDirectory,
74+
#[error("Missing credential")]
75+
MissingCredentials,
7476
}
7577

7678
/// A guard that automatically cancels the cancellation token when dropped.
@@ -107,68 +109,36 @@ impl From<OAuthClientConfig> for Registration {
107109
}
108110
}
109111

110-
/// A guard that manages the lifecycle of an authenticated MCP client and automatically
111-
/// persists OAuth credentials when dropped.
112+
/// A wrapper that manages an authenticated MCP client.
112113
///
113-
/// This struct wraps an `AuthClient` and ensures that OAuth tokens are written to disk
114-
/// when the guard goes out of scope, unless explicitly disabled via `should_write`.
115-
/// This provides automatic credential caching for MCP server connections that require
116-
/// OAuth authentication.
114+
/// This struct wraps an `AuthClient` and provides access to OAuth credentials
115+
/// for MCP server connections that require authentication. The credentials
116+
/// are managed separately from this wrapper's lifecycle.
117117
#[derive(Clone, Debug)]
118-
pub struct AuthClientDropGuard {
119-
pub should_write: bool,
118+
pub struct AuthClientWrapper {
120119
pub cred_full_path: PathBuf,
121120
pub auth_client: AuthClient<Client>,
122121
}
123122

124-
impl AuthClientDropGuard {
123+
impl AuthClientWrapper {
125124
pub fn new(cred_full_path: PathBuf, auth_client: AuthClient<Client>) -> Self {
126125
Self {
127-
should_write: true,
128126
cred_full_path,
129127
auth_client,
130128
}
131129
}
132-
}
133130

134-
impl Drop for AuthClientDropGuard {
135-
fn drop(&mut self) {
136-
if !self.should_write {
137-
return;
138-
}
131+
/// Refreshes token in memory using the registration read from when the auth client was
132+
/// spawned. This also persists the retrieved token
133+
pub async fn refresh_token(&self) -> Result<(), OauthUtilError> {
134+
let cred = self.auth_client.auth_manager.lock().await.refresh_token().await?;
135+
let parent_path = self.cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?;
136+
tokio::fs::create_dir_all(parent_path).await?;
139137

140-
let auth_client_clone = self.auth_client.clone();
141-
let path = self.cred_full_path.clone();
138+
let cred_as_bytes = serde_json::to_string_pretty(&cred)?;
139+
tokio::fs::write(&self.cred_full_path, &cred_as_bytes).await?;
142140

143-
tokio::spawn(async move {
144-
let Ok((client_id, cred)) = auth_client_clone.auth_manager.lock().await.get_credentials().await else {
145-
error!("Failed to retrieve credentials in drop routine");
146-
return;
147-
};
148-
let Some(cred) = cred else {
149-
error!("Failed to retrieve credentials in drop routine from {client_id}");
150-
return;
151-
};
152-
let Some(parent_path) = path.parent() else {
153-
error!("Failed to retrieve parent path for token in drop routine for {client_id}");
154-
return;
155-
};
156-
if let Err(e) = tokio::fs::create_dir_all(parent_path).await {
157-
error!("Error making parent directory for token cache in drop routine for {client_id}: {e}");
158-
return;
159-
}
160-
161-
let serialized_cred = match serde_json::to_string_pretty(&cred) {
162-
Ok(cred) => cred,
163-
Err(e) => {
164-
error!("Failed to serialize credentials for {client_id}: {e}");
165-
return;
166-
},
167-
};
168-
if let Err(e) = tokio::fs::write(path, &serialized_cred).await {
169-
error!("Error making writing token cache in drop routine: {e}");
170-
}
171-
});
141+
Ok(())
172142
}
173143
}
174144

@@ -186,7 +156,7 @@ pub enum HttpTransport {
186156
WithAuth(
187157
(
188158
WorkerTransport<StreamableHttpClientWorker<AuthClient<Client>>>,
189-
AuthClientDropGuard,
159+
AuthClientWrapper,
190160
),
191161
),
192162
WithoutAuth(WorkerTransport<StreamableHttpClientWorker<Client>>),
@@ -233,7 +203,7 @@ pub async fn get_http_transport(
233203
..Default::default()
234204
});
235205

236-
let auth_dg = AuthClientDropGuard::new(cred_full_path, auth_client);
206+
let auth_dg = AuthClientWrapper::new(cred_full_path, auth_client);
237207
debug!("## mcp: transport obtained");
238208

239209
Ok(HttpTransport::WithAuth((transport, auth_dg)))
@@ -276,11 +246,8 @@ async fn get_auth_manager(
276246

277247
// Client registration is done in [start_authorization]
278248
// If we have gotten past that point that means we have the info to persist the
279-
// registration on disk. These are info that we need to refresh stake
280-
// tokens. This is in contrast to tokens, which we only persist when we drop
281-
// the client (because that way we can write once and ensure what is on the
282-
// disk always the most up to date)
283-
let (client_id, _credentials) = am.get_credentials().await?;
249+
// registration on disk.
250+
let (client_id, credentials) = am.get_credentials().await?;
284251
let reg = Registration {
285252
client_id,
286253
client_secret: None,
@@ -292,6 +259,13 @@ async fn get_auth_manager(
292259
tokio::fs::create_dir_all(reg_parent_path).await?;
293260
tokio::fs::write(reg_full_path, &reg_as_str).await?;
294261

262+
let credentials = credentials.ok_or(OauthUtilError::MissingCredentials)?;
263+
264+
let cred_parent_path = cred_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?;
265+
tokio::fs::create_dir_all(cred_parent_path).await?;
266+
let reg_as_str = serde_json::to_string_pretty(&credentials)?;
267+
tokio::fs::write(cred_full_path, &reg_as_str).await?;
268+
295269
Ok(am)
296270
},
297271
}

0 commit comments

Comments
 (0)