Skip to content

Commit da747e1

Browse files
committed
refreshes token on expiration
1 parent a5b92f7 commit da747e1

File tree

4 files changed

+92
-34
lines changed

4 files changed

+92
-34
lines changed

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ use crate::database::settings::Setting;
9090
use crate::mcp_client::messenger::Messenger;
9191
use crate::mcp_client::{
9292
InitializedMcpClient,
93+
InnerService,
9394
McpClientService,
9495
};
9596
use crate::os::Os;
@@ -631,7 +632,9 @@ impl ToolManager {
631632
tokio::spawn(async move {
632633
match handle.await {
633634
Ok(Ok(client)) => {
634-
let client = client.inner_service;
635+
let InnerService::Original(client) = client.inner_service else {
636+
unreachable!();
637+
};
635638
match client.cancel().await {
636639
Ok(_) => info!("Server {server_name_clone} evicted due to agent swap"),
637640
Err(e) => error!("Server {server_name_clone} has failed to cancel: {e}"),
@@ -644,7 +647,9 @@ impl ToolManager {
644647
});
645648
},
646649
InitializedMcpClient::Ready(running_service) => {
647-
let client = running_service.inner_service;
650+
let InnerService::Original(client) = running_service.inner_service else {
651+
unreachable!();
652+
};
648653
match client.cancel().await {
649654
Ok(_) => info!("Server {server_name} evicted due to agent swap"),
650655
Err(e) => error!("Server {server_name} has failed to cancel: {e}"),
@@ -886,7 +891,7 @@ impl ToolManager {
886891
Tool::Custom(CustomTool {
887892
name: tool_name.to_owned(),
888893
server_name: server_name.to_owned(),
889-
client: (*running_service).clone(),
894+
client: running_service.clone(),
890895
auth_client,
891896
params: value.args.as_object().cloned(),
892897
})

crates/chat-cli/src/cli/chat/tools/custom_tool.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use crossterm::{
99
};
1010
use eyre::Result;
1111
use reqwest::Client;
12-
use rmcp::RoleClient;
1312
use rmcp::model::CallToolRequestParam;
1413
use rmcp::transport::auth::AuthClient;
1514
use schemars::JsonSchema;
@@ -29,6 +28,7 @@ use crate::cli::agent::{
2928
};
3029
use crate::cli::chat::CONTINUATION_LINE;
3130
use crate::cli::chat::token_counter::TokenCounter;
31+
use crate::mcp_client::RunningService;
3232
use crate::os::Os;
3333
use crate::util::MCP_SERVER_TOOL_DELIMITER;
3434
use crate::util::pattern_matching::matches_any_pattern;
@@ -94,7 +94,7 @@ pub struct CustomTool {
9494
/// prefixed to the tool name when presented to the model for disambiguation.
9595
pub server_name: String,
9696
/// Reference to the client that manages communication with the tool's server process.
97-
pub client: rmcp::Peer<RoleClient>,
97+
pub client: RunningService,
9898
/// Optional authentication client for handling authentication with HTTP-based MCP servers.
9999
/// This is used when the MCP server requires authentication for tool invocation.
100100
pub auth_client: Option<AuthClient<Client>>,

crates/chat-cli/src/mcp_client/client.rs

Lines changed: 81 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,11 @@
11
use std::borrow::Cow;
22
use std::collections::HashMap;
3-
use std::ops::{
4-
Deref,
5-
DerefMut,
6-
};
73
use std::process::Stdio;
84

95
use regex::Regex;
106
use reqwest::Client;
117
use rmcp::model::{
12-
ErrorCode,
13-
Implementation,
14-
InitializeRequestParam,
15-
ListPromptsResult,
16-
ListToolsResult,
17-
LoggingLevel,
18-
LoggingMessageNotificationParam,
19-
PaginatedRequestParam,
20-
ServerNotification,
21-
ServerRequest,
8+
CallToolRequestParam, CallToolResult, ErrorCode, GetPromptRequestParam, GetPromptResult, Implementation, InitializeRequestParam, ListPromptsResult, ListToolsResult, LoggingLevel, LoggingMessageNotificationParam, PaginatedRequestParam, ServerNotification, ServerRequest
229
};
2310
use rmcp::service::{
2411
ClientInitializeError,
@@ -151,30 +138,95 @@ pub enum McpClientError {
151138
Auth(#[from] crate::auth::AuthError),
152139
}
153140

154-
pub struct RunningService {
155-
pub inner_service: rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>,
156-
#[allow(dead_code)]
157-
pub auth_dropguard: Option<AuthClientDropGuard>,
141+
macro_rules! decorate_with_auth_retry {
142+
($param_type:ty, $method_name:ident, $return_type:ty) => {
143+
pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> {
144+
let first_attempt = match &self.inner_service {
145+
InnerService::Original(rs) => rs.$method_name(param.clone()).await,
146+
InnerService::Peer(peer) => peer.$method_name(param.clone()).await,
147+
};
148+
149+
match first_attempt {
150+
Ok(result) => Ok(result),
151+
Err(e) => {
152+
// TODO: discern error type prior to retrying
153+
// Not entirely sure what is thrown when auth is required
154+
if let Some(auth_client) = self.get_auth_client() {
155+
let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await;
156+
match refresh_result {
157+
Ok(_) => {
158+
// Retry the operation after token refresh
159+
match &self.inner_service {
160+
InnerService::Original(rs) => rs.$method_name(param).await,
161+
InnerService::Peer(peer) => peer.$method_name(param).await,
162+
}
163+
},
164+
Err(_) => {
165+
// If refresh fails, return the original error
166+
Err(e)
167+
}
168+
}
169+
} else {
170+
// No auth client available, return original error
171+
Err(e)
172+
}
173+
},
174+
}
175+
}
176+
};
158177
}
159178

160-
impl RunningService {
161-
pub fn get_auth_client(&self) -> Option<AuthClient<Client>> {
162-
self.auth_dropguard.as_ref().map(|a| a.auth_client.clone())
179+
pub enum InnerService {
180+
Original(rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>),
181+
Peer(rmcp::service::Peer<RoleClient>),
182+
}
183+
184+
impl std::fmt::Debug for InnerService {
185+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186+
match self {
187+
InnerService::Original(_) => f.debug_tuple("Original").field(&"RunningService<..>").finish(),
188+
InnerService::Peer(peer) => f.debug_tuple("Peer").field(peer).finish(),
189+
}
163190
}
164191
}
165192

166-
impl Deref for RunningService {
167-
type Target = rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>;
193+
impl Clone for InnerService {
194+
fn clone(&self) -> Self {
195+
match self {
196+
InnerService::Original(rs) => InnerService::Peer((*rs).clone()),
197+
InnerService::Peer(peer) => InnerService::Peer(peer.clone())
198+
}
199+
}
200+
}
168201

169-
fn deref(&self) -> &Self::Target {
170-
&self.inner_service
202+
#[derive(Debug)]
203+
pub struct RunningService {
204+
pub inner_service: InnerService,
205+
auth_dropguard: Option<AuthClientDropGuard>,
206+
}
207+
208+
impl Clone for RunningService {
209+
fn clone(&self) -> Self {
210+
let auth_dropguard = self.auth_dropguard.as_ref().map(|dg| {
211+
let mut dg = dg.clone();
212+
dg.should_write = false;
213+
dg
214+
});
215+
216+
RunningService {
217+
inner_service: self.inner_service.clone(),
218+
auth_dropguard
219+
}
171220
}
172221
}
173222

174-
impl DerefMut for RunningService {
175-
fn deref_mut(&mut self) -> &mut Self::Target {
176-
&mut self.inner_service
223+
impl RunningService {
224+
pub fn get_auth_client(&self) -> Option<AuthClient<Client>> {
225+
self.auth_dropguard.as_ref().map(|a| a.auth_client.clone())
177226
}
227+
228+
decorate_with_auth_retry!(CallToolRequestParam, call_tool, CallToolResult);
229+
decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult);
178230
}
179231

180232
pub type StdioTransport = (TokioChildProcess, Option<ChildStderr>);
@@ -397,7 +449,7 @@ impl McpClientService {
397449
});
398450

399451
Ok(RunningService {
400-
inner_service: service,
452+
inner_service: InnerService::Original(service),
401453
auth_dropguard,
402454
})
403455
});

crates/chat-cli/src/mcp_client/oauth_util.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ impl Drop for LoopBackDropGuard {
7878
}
7979
}
8080

81+
#[derive(Clone, Debug)]
8182
pub struct AuthClientDropGuard {
8283
pub should_write: bool,
8384
pub cred_full_path: PathBuf,

0 commit comments

Comments
 (0)