Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ winnow = "=0.6.2"
winreg = "0.55.0"
schemars = "1.0.4"
jsonschema = "0.30.0"
rmcp = { version = "0.6.3", features = ["client", "transport-sse-client-reqwest", "reqwest", "transport-streamable-http-client-reqwest", "transport-child-process", "tower", "auth"] }
rmcp = { version = "0.7.0", features = ["client", "transport-sse-client-reqwest", "reqwest", "transport-streamable-http-client-reqwest", "transport-child-process", "tower", "auth"] }

[workspace.lints.rust]
future_incompatible = "warn"
Expand Down
179 changes: 46 additions & 133 deletions crates/chat-cli/src/mcp_client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use regex::Regex;
use rmcp::model::{
CallToolRequestParam,
CallToolResult,
ClientResult,
ErrorCode,
GetPromptRequestParam,
GetPromptResult,
Expand Down Expand Up @@ -42,17 +43,15 @@ use tokio::process::{
};
use tokio::task::JoinHandle;
use tracing::{
debug,
error,
info,
};

use super::messenger::Messenger;
use super::oauth_util::HttpTransport;
use super::{
AuthClientWrapper,
HttpServiceBuilder,
OauthUtilError,
get_http_transport,
};
use crate::cli::chat::server_messenger::ServerMessenger;
use crate::cli::chat::tools::custom_tool::{
Expand Down Expand Up @@ -266,37 +265,10 @@ impl RunningService {
decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult);
}

pub type StdioTransport = (TokioChildProcess, Option<ChildStderr>);

// TODO: add sse support (even though it's deprecated)
/// Represents the different transport mechanisms available for MCP (Model Context Protocol)
/// communication.
///
/// This enum encapsulates the two primary ways to communicate with MCP servers:
/// - HTTP-based transport for remote servers
/// - Standard I/O transport for local process-based servers
pub enum Transport {
/// HTTP transport for communicating with remote MCP servers over network protocols.
/// Uses a streamable HTTP client with authentication support.
Http(HttpTransport),
/// Standard I/O transport for communicating with local MCP servers via child processes.
/// Communication happens through stdin/stdout pipes.
Stdio(StdioTransport),
}

impl std::fmt::Debug for Transport {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Transport::Http(_) => f.debug_tuple("Http").field(&"HttpTransport").finish(),
Transport::Stdio(_) => f.debug_tuple("Stdio").field(&"TokioChildProcess").finish(),
}
}
}

/// This struct implements the [Service] trait from rmcp. It is within this trait the logic of
/// server driven data flow (i.e. requests and notifications that are sent from the server) are
/// handled.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct McpClientService {
pub config: CustomToolConfig,
server_name: String,
Expand All @@ -312,103 +284,14 @@ impl McpClientService {
}
}

pub async fn init(mut self, os: &Os) -> Result<InitializedMcpClient, McpClientError> {
pub async fn init(self, os: &Os) -> Result<InitializedMcpClient, McpClientError> {
let os_clone = os.clone();

let handle: JoinHandle<Result<RunningService, McpClientError>> = tokio::spawn(async move {
let messenger_clone = self.messenger.clone();
let server_name = self.server_name.clone();
let backup_config = self.config.clone();

let result: Result<_, McpClientError> = async {
let messenger_dup = messenger_clone.duplicate();
let (service, stderr, auth_client) = match self.get_transport(&os_clone, &*messenger_dup).await? {
Transport::Stdio((child_process, stderr)) => {
let service = self
.into_dyn()
.serve::<TokioChildProcess, _, _>(child_process)
.await
.map_err(Box::new)?;

(service, stderr, None)
},
Transport::Http(http_transport) => {
match http_transport {
HttpTransport::WithAuth((transport, mut auth_client)) => {
// The crate does not automatically refresh tokens when they expire. We
// would need to handle that here
let url = &backup_config.url;
let service = match self.into_dyn().serve(transport).await.map_err(Box::new) {
Ok(service) => service,
Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => {
debug!("## mcp: first hand shake attempt failed: {:?}", e);
let refresh_res = auth_client.refresh_token().await;
let new_self = McpClientService::new(
server_name.clone(),
backup_config.clone(),
messenger_clone.clone(),
);

let scopes = &backup_config.oauth_scopes;
let timeout = backup_config.timeout;
let headers = &backup_config.headers;
let new_transport =
get_http_transport(&os_clone, url, timeout, scopes, headers,Some(auth_client.auth_client.clone()), &*messenger_dup).await?;

match new_transport {
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
auth_client = new_auth_client;

match refresh_res {
Ok(_) => {
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
},
Err(e) => {
error!("## mcp: token refresh attempt failed: {:?}", e);
info!("Retry for http transport failed {e}. Possible reauth needed");
// This could be because the refresh token is expired, in which
// case we would need to have user go through the auth flow
// again. We do this by deleting the cred
// and discarding the client to trigger a full auth flow
tokio::fs::remove_file(&auth_client.cred_full_path).await?;
let new_transport =
get_http_transport(&os_clone, url, timeout, scopes,headers,None, &*messenger_dup).await?;

match new_transport {
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
auth_client = new_auth_client;
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
},
HttpTransport::WithoutAuth(new_transport) => {
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?
},
}
},
}
},
HttpTransport::WithoutAuth(new_transport) =>
new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?,
}
},
Err(e) => return Err(e.into()),
};

(service, None, Some(auth_client))
},
HttpTransport::WithoutAuth(transport) => {
let service = self.into_dyn().serve(transport).await.map_err(Box::new)?;

(service, None, None)
},
}
},
};

Ok((service, stderr, auth_client))
}
.await;

let (service, child_stderr, auth_dropguard) = match result {
let (service, child_stderr, auth_dropguard) = match self.into_service(&os_clone, &messenger_clone).await {
Ok((service, stderr, auth_dg)) => (service, stderr, auth_dg),
Err(e) => {
let msg = e.to_string();
Expand Down Expand Up @@ -498,18 +381,24 @@ impl McpClientService {
Ok(InitializedMcpClient::Pending(handle))
}

async fn get_transport(&mut self, os: &Os, messenger: &dyn Messenger) -> Result<Transport, McpClientError> {
async fn into_service(
mut self,
os: &Os,
messenger: &dyn Messenger,
) -> Result<
(
rmcp::service::RunningService<RoleClient, Box<dyn DynService<RoleClient>>>,
Option<ChildStderr>,
Option<AuthClientWrapper>,
),
McpClientError,
> {
let CustomToolConfig {
r#type,
url,
headers,
oauth_scopes: scopes,
command: command_as_str,
args,
env: config_envs,
timeout,
..
} = &mut self.config;
} = &self.config;

let is_malformed_http = matches!(r#type, TransportType::Http) && url.is_empty();
let is_malformed_stdio = matches!(r#type, TransportType::Stdio) && command_as_str.is_empty();
Expand All @@ -526,6 +415,13 @@ impl McpClientService {

match r#type {
TransportType::Stdio => {
let CustomToolConfig {
command: command_as_str,
args,
env: config_envs,
..
} = &mut self.config;

let context = |input: &str| Ok(os.env.get(input).ok());
let home_dir = || os.env.home().map(|p| p.to_string_lossy().to_string());
let expanded_cmd = shellexpand::full_with_context(command_as_str, home_dir, context)?;
Expand All @@ -544,12 +440,28 @@ impl McpClientService {
let (tokio_child_process, child_stderr) =
TokioChildProcess::builder(command).stderr(Stdio::piped()).spawn()?;

Ok(Transport::Stdio((tokio_child_process, child_stderr)))
let service = self
.into_dyn()
.serve::<TokioChildProcess, _, _>(tokio_child_process)
.await
.map_err(Box::new)?;

Ok((service, child_stderr, None))
},
TransportType::Http => {
let http_transport = get_http_transport(os, url, *timeout, scopes, headers, None, messenger).await?;
let CustomToolConfig {
url,
headers,
oauth_scopes: scopes,
timeout,
..
} = &self.config;

let http_service_builder = HttpServiceBuilder::new(url, os, url, *timeout, scopes, headers, messenger);

let (service, auth_client_wrapper) = http_service_builder.try_build(&self).await?;

Ok(Transport::Http(http_transport))
Ok((service, None, auth_client_wrapper))
},
}
}
Expand Down Expand Up @@ -620,7 +532,7 @@ impl Service<RoleClient> for McpClientService {
_context: rmcp::service::RequestContext<RoleClient>,
) -> Result<<RoleClient as rmcp::service::ServiceRole>::Resp, rmcp::ErrorData> {
match request {
ServerRequest::PingRequest(_) => Err(rmcp::ErrorData::method_not_found::<rmcp::model::PingRequestMethod>()),
ServerRequest::PingRequest(_) => Ok(ClientResult::empty(())),
ServerRequest::CreateMessageRequest(_) => Err(rmcp::ErrorData::method_not_found::<
rmcp::model::CreateMessageRequestMethod,
>()),
Expand Down Expand Up @@ -660,6 +572,7 @@ impl Service<RoleClient> for McpClientService {
client_info: Implementation {
name: "Q DEV CLI".to_string(),
version: "1.0.0".to_string(),
..Default::default()
},
}
}
Expand Down
Loading
Loading