Skip to content

Commit a5b92f7

Browse files
committed
accommodates open remote mcp servers
1 parent 536fad8 commit a5b92f7

File tree

5 files changed

+222
-73
lines changed

5 files changed

+222
-73
lines changed

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -875,17 +875,19 @@ impl ToolManager {
875875
});
876876
};
877877

878-
let running_service = (*client.get_running_service().await.map_err(|e| ToolResult {
878+
let running_service = client.get_running_service().await.map_err(|e| ToolResult {
879879
tool_use_id: value.id.clone(),
880880
content: vec![ToolResultContentBlock::Text(format!("Mcp tool client not ready: {e}"))],
881881
status: ToolResultStatus::Error,
882-
})?)
883-
.clone();
882+
})?;
883+
884+
let auth_client = running_service.get_auth_client();
884885

885886
Tool::Custom(CustomTool {
886887
name: tool_name.to_owned(),
887888
server_name: server_name.to_owned(),
888-
client: running_service,
889+
client: (*running_service).clone(),
890+
auth_client,
889891
params: value.args.as_object().cloned(),
890892
})
891893
},

crates/chat-cli/src/cli/chat/tools/custom_tool.rs

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,24 @@ use std::collections::HashMap;
33
use std::io::Write;
44

55
use crossterm::{
6+
execute,
67
queue,
78
style,
89
};
910
use eyre::Result;
11+
use reqwest::Client;
1012
use rmcp::RoleClient;
1113
use rmcp::model::CallToolRequestParam;
14+
use rmcp::transport::auth::AuthClient;
1215
use schemars::JsonSchema;
1316
use serde::{
1417
Deserialize,
1518
Serialize,
1619
};
17-
use tracing::warn;
20+
use tracing::{
21+
info,
22+
warn,
23+
};
1824

1925
use super::InvokeOutput;
2026
use crate::cli::agent::{
@@ -89,19 +95,47 @@ pub struct CustomTool {
8995
pub server_name: String,
9096
/// Reference to the client that manages communication with the tool's server process.
9197
pub client: rmcp::Peer<RoleClient>,
98+
/// Optional authentication client for handling authentication with HTTP-based MCP servers.
99+
/// This is used when the MCP server requires authentication for tool invocation.
100+
pub auth_client: Option<AuthClient<Client>>,
92101
/// Optional parameters to pass to the tool when invoking the method.
93102
/// Structured as a JSON value to accommodate various parameter types and structures.
94103
pub params: Option<serde_json::Map<String, serde_json::Value>>,
95104
}
96105

97106
impl CustomTool {
98-
pub async fn invoke(&self, _os: &Os, _updates: impl Write) -> Result<InvokeOutput> {
107+
pub async fn invoke(&self, _os: &Os, updates: &mut impl Write) -> Result<InvokeOutput> {
99108
let params = CallToolRequestParam {
100109
name: Cow::from(self.name.clone()),
101110
arguments: self.params.clone(),
102111
};
103112

104-
let resp = self.client.call_tool(params).await?;
113+
let mut has_retried = false;
114+
let resp = loop {
115+
match self.client.call_tool(params.clone()).await {
116+
Ok(resp) => break resp,
117+
Err(_e) if !has_retried => {
118+
// TODO: discern error type prior to retrying
119+
has_retried = true;
120+
if let Some(auth_client) = &self.auth_client {
121+
let auth_res = auth_client.auth_manager.lock().await.refresh_token().await;
122+
if let Err(e) = auth_res {
123+
let _ = execute!(
124+
updates,
125+
style::SetForegroundColor(style::Color::Red),
126+
style::Print(format!(
127+
"Authentication token has expired: {e}. Please reauthenticate by restarting session.\n"
128+
)),
129+
style::ResetColor,
130+
);
131+
} else {
132+
info!("Token refreshed for {}", self.server_name);
133+
}
134+
}
135+
},
136+
Err(e) => return Err(e.into()),
137+
}
138+
};
105139

106140
if resp.is_error.is_none_or(|v| !v) {
107141
Ok(InvokeOutput {

crates/chat-cli/src/mcp_client/client.rs

Lines changed: 87 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use std::ops::{
55
DerefMut,
66
};
77
use std::process::Stdio;
8-
use std::str::FromStr;
98

109
use regex::Regex;
1110
use reqwest::Client;
@@ -27,15 +26,9 @@ use rmcp::service::{
2726
NotificationContext,
2827
};
2928
use rmcp::transport::auth::AuthClient;
30-
use rmcp::transport::streamable_http_client::{
31-
StreamableHttpClientTransportConfig,
32-
StreamableHttpClientWorker,
33-
};
3429
use rmcp::transport::{
3530
ConfigureCommandExt,
36-
StreamableHttpClientTransport,
3731
TokioChildProcess,
38-
WorkerTransport,
3932
};
4033
use rmcp::{
4134
ErrorData,
@@ -50,26 +43,25 @@ use tokio::process::{
5043
Command,
5144
};
5245
use tokio::task::JoinHandle;
53-
use tracing::error;
54-
use url::Url;
46+
use tracing::{
47+
error,
48+
info,
49+
};
5550

5651
use super::messenger::Messenger;
52+
use super::oauth_util::HttpTransport;
5753
use super::{
5854
AuthClientDropGuard,
5955
OauthUtilError,
60-
compute_key,
56+
get_http_transport,
6157
};
6258
use crate::cli::chat::server_messenger::ServerMessenger;
6359
use crate::cli::chat::tools::custom_tool::{
6460
CustomToolConfig,
6561
TransportType,
6662
};
67-
use crate::mcp_client::get_auth_manager;
6863
use crate::os::Os;
69-
use crate::util::directories::{
70-
DirectoryError,
71-
get_mcp_auth_dir,
72-
};
64+
use crate::util::directories::DirectoryError;
7365

7466
/// Fetches all pages of specified resources from a server
7567
macro_rules! paginated_fetch {
@@ -152,9 +144,11 @@ pub enum McpClientError {
152144
#[error(transparent)]
153145
Directory(#[from] DirectoryError),
154146
#[error(transparent)]
155-
Oauth(#[from] OauthUtilError),
147+
OauthUtil(#[from] OauthUtilError),
156148
#[error(transparent)]
157149
Parse(#[from] url::ParseError),
150+
#[error(transparent)]
151+
Auth(#[from] crate::auth::AuthError),
158152
}
159153

160154
pub struct RunningService {
@@ -163,6 +157,12 @@ pub struct RunningService {
163157
pub auth_dropguard: Option<AuthClientDropGuard>,
164158
}
165159

160+
impl RunningService {
161+
pub fn get_auth_client(&self) -> Option<AuthClient<Client>> {
162+
self.auth_dropguard.as_ref().map(|a| a.auth_client.clone())
163+
}
164+
}
165+
166166
impl Deref for RunningService {
167167
type Target = rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>;
168168

@@ -177,10 +177,6 @@ impl DerefMut for RunningService {
177177
}
178178
}
179179

180-
pub type HttpTransport = (
181-
WorkerTransport<StreamableHttpClientWorker<AuthClient<Client>>>,
182-
AuthClientDropGuard,
183-
);
184180
pub type StdioTransport = (TokioChildProcess, Option<ChildStderr>);
185181

186182
// TODO: add sse support (even though it's deprecated)
@@ -231,11 +227,13 @@ impl McpClientService {
231227
let os_clone = os.clone();
232228

233229
let handle: JoinHandle<Result<RunningService, McpClientError>> = tokio::spawn(async move {
234-
let messenger_clone = self.messenger.duplicate();
230+
let messenger_clone = self.messenger.clone();
235231
let server_name = self.server_name.clone();
232+
let backup_config = self.config.clone();
236233

237234
let result: Result<_, McpClientError> = async {
238-
let (service, stderr, auth_client) = match self.get_transport(&os_clone, &messenger_clone).await? {
235+
let messenger_dup = messenger_clone.duplicate();
236+
let (service, stderr, auth_client) = match self.get_transport(&os_clone, &*messenger_dup).await? {
239237
Transport::Stdio((child_process, stderr)) => {
240238
let service = self
241239
.into_dyn()
@@ -245,10 +243,71 @@ impl McpClientService {
245243

246244
(service, stderr, None)
247245
},
248-
Transport::Http((transport, auth_dg)) => {
249-
let service = self.into_dyn().serve(transport).await.map_err(Box::new)?;
246+
Transport::Http(http_transport) => {
247+
match http_transport {
248+
HttpTransport::WithAuth((transport, mut auth_dg)) => {
249+
// The crate does not automatically refresh tokens when they expire. We
250+
// would need to handle that here
251+
let url = self.config.url.clone();
252+
let service = match self.into_dyn().serve(transport).await.map_err(Box::new) {
253+
Ok(service) => service,
254+
Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => {
255+
let refresh_res =
256+
auth_dg.auth_client.auth_manager.lock().await.refresh_token().await;
257+
let new_self = McpClientService::new(
258+
server_name.clone(),
259+
backup_config,
260+
messenger_clone.clone(),
261+
);
262+
263+
let new_transport =
264+
get_http_transport(&os_clone, true, &url, &*messenger_dup).await?;
265+
266+
match new_transport {
267+
HttpTransport::WithAuth((new_transport, new_auth_dg)) => {
268+
auth_dg.should_write = false;
269+
auth_dg = new_auth_dg;
270+
271+
match refresh_res {
272+
Ok(_token) => {
273+
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
274+
},
275+
Err(e) => {
276+
info!("Retry for http transport failed {e}. Possible reauth needed");
277+
// This could be because the refresh token is expired, in which
278+
// case we would need to have user go through the auth flow
279+
// again
280+
let new_transport =
281+
get_http_transport(&os_clone, true, &url, &*messenger_dup).await?;
282+
283+
match new_transport {
284+
HttpTransport::WithAuth((new_transport, new_auth_dg)) => {
285+
auth_dg = new_auth_dg;
286+
auth_dg.should_write = false;
287+
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
288+
},
289+
HttpTransport::WithoutAuth(new_transport) => {
290+
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
291+
},
292+
}
293+
},
294+
}
295+
},
296+
HttpTransport::WithoutAuth(new_transport) =>
297+
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?,
298+
}
299+
},
300+
Err(e) => return Err(e.into()),
301+
};
302+
303+
(service, None, Some(auth_dg))
304+
},
305+
HttpTransport::WithoutAuth(transport) => {
306+
let service = self.into_dyn().serve(transport).await.map_err(Box::new)?;
250307

251-
(service, None, Some(auth_dg))
308+
(service, None, None)
309+
},
310+
}
252311
},
253312
};
254313

@@ -346,7 +405,7 @@ impl McpClientService {
346405
Ok(InitializedMcpClient::Pending(handle))
347406
}
348407

349-
async fn get_transport(&mut self, os: &Os, messenger: &Box<dyn Messenger>) -> Result<Transport, McpClientError> {
408+
async fn get_transport(&mut self, os: &Os, messenger: &dyn Messenger) -> Result<Transport, McpClientError> {
350409
// TODO: figure out what to do with headers
351410
let CustomToolConfig {
352411
r#type: transport_type,
@@ -376,26 +435,9 @@ impl McpClientService {
376435
Ok(Transport::Stdio((tokio_child_process, child_stderr)))
377436
},
378437
TransportType::Http => {
379-
let cred_dir = get_mcp_auth_dir(os)?;
380-
let url = Url::from_str(url)?;
381-
let key = compute_key(&url);
382-
let cred_full_path = cred_dir.join(format!("{key}.token.json"));
383-
384-
let am = get_auth_manager(url.clone(), cred_full_path.clone(), messenger).await?;
385-
let client = AuthClient::new(reqwest::Client::default(), am);
386-
let transport =
387-
StreamableHttpClientTransport::with_client(client.clone(), StreamableHttpClientTransportConfig {
388-
uri: url.as_str().into(),
389-
allow_stateless: false,
390-
..Default::default()
391-
});
392-
393-
let auth_dg = AuthClientDropGuard {
394-
path: cred_full_path,
395-
auth_client: client,
396-
};
438+
let http_transport = get_http_transport(os, false, url, messenger).await?;
397439

398-
Ok(Transport::Http((transport, auth_dg)))
440+
Ok(Transport::Http(http_transport))
399441
},
400442
}
401443
}

crates/chat-cli/src/mcp_client/messenger.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ impl Messenger for NullMessenger {
111111
Ok(())
112112
}
113113

114-
async fn send_oauth_link(&self, link: String) -> MessengerResult {
114+
async fn send_oauth_link(&self, _link: String) -> MessengerResult {
115115
Ok(())
116116
}
117117

0 commit comments

Comments
 (0)