Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/chat-cli/src/cli/chat/cli/mcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ impl McpArgs {
let msg = msg
.iter()
.map(|record| match record {
LoadingRecord::Err(content) | LoadingRecord::Warn(content) | LoadingRecord::Success(content) => {
content.clone()
},
LoadingRecord::Err(timestamp, content)
| LoadingRecord::Warn(timestamp, content)
| LoadingRecord::Success(timestamp, content) => format!("[{timestamp}]: {content}"),
})
.collect::<Vec<_>>()
.join("\n--- tools refreshed ---\n");
Expand Down
64 changes: 46 additions & 18 deletions crates/chat-cli/src/cli/chat/tool_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,26 @@ enum LoadingMsg {
/// surface (since we would only want to surface fatal errors in non-interactive mode).
#[derive(Clone, Debug)]
pub enum LoadingRecord {
Success(String),
Warn(String),
Err(String),
Success(String, String),
Warn(String, String),
Err(String, String),
}

impl LoadingRecord {
pub fn success(msg: String) -> Self {
let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string();
LoadingRecord::Success(timestamp, msg)
}

pub fn warn(msg: String) -> Self {
let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string();
LoadingRecord::Warn(timestamp, msg)
}

pub fn err(msg: String) -> Self {
let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string();
LoadingRecord::Err(timestamp, msg)
}
}

pub struct ToolManagerBuilder {
Expand Down Expand Up @@ -473,10 +490,11 @@ pub enum PromptQueryResult {
/// - `IllegalChar`: The tool name contains characters that are not allowed
/// - `EmptyDescription`: The tool description is empty or missing
#[allow(dead_code)]
enum OutOfSpecName {
enum ToolValidationViolation {
TooLong(String),
IllegalChar(String),
EmptyDescription(String),
DescriptionTooLong(String),
}

#[derive(Clone, Default, Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -814,7 +832,7 @@ impl ToolManager {
.lock()
.await
.iter()
.any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(_))))
.any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(..))))
{
queue!(
stderr,
Expand Down Expand Up @@ -962,7 +980,7 @@ impl ToolManager {
if !conflicts.is_empty() {
let mut record_lock = self.mcp_load_record.lock().await;
for (server_name, msg) in conflicts {
let record = LoadingRecord::Err(msg);
let record = LoadingRecord::err(msg);
record_lock
.entry(server_name)
.and_modify(|v| v.push(record.clone()))
Expand Down Expand Up @@ -1494,9 +1512,9 @@ fn spawn_orchestrator_task(
drop(buf_writer);
let record = String::from_utf8_lossy(record_temp_buf).to_string();
let record = if process_result.is_err() {
LoadingRecord::Warn(record)
LoadingRecord::warn(record)
} else {
LoadingRecord::Success(record)
LoadingRecord::success(record)
};
load_record
.lock()
Expand All @@ -1522,7 +1540,7 @@ fn spawn_orchestrator_task(
let _ = buf_writer.flush();
drop(buf_writer);
let record = String::from_utf8_lossy(record_temp_buf).to_string();
let record = LoadingRecord::Err(record);
let record = LoadingRecord::err(record);
load_record
.lock()
.await
Expand Down Expand Up @@ -1606,7 +1624,7 @@ fn spawn_orchestrator_task(
let _ = buf_writer.flush();
drop(buf_writer);
let record = String::from_utf8_lossy(record_temp_buf).to_string();
let record = LoadingRecord::Err(record);
let record = LoadingRecord::err(record);
load_record
.lock()
.await
Expand All @@ -1626,7 +1644,7 @@ fn spawn_orchestrator_task(
let _ = buf_writer.flush();
drop(buf_writer);
let record_str = String::from_utf8_lossy(record_temp_buf).to_string();
let record = LoadingRecord::Warn(record_str.clone());
let record = LoadingRecord::warn(record_str.clone());
load_record
.lock()
.await
Expand Down Expand Up @@ -1720,7 +1738,7 @@ async fn process_tool_specs(
//
// For non-compliance due to point 1, we shall change it on behalf of the users.
// For the rest, we simply throw a warning and reject the tool.
let mut out_of_spec_tool_names = Vec::<OutOfSpecName>::new();
let mut out_of_spec_tool_names = Vec::<ToolValidationViolation>::new();
let mut hasher = DefaultHasher::new();
let mut number_of_tools = 0_usize;

Expand All @@ -1745,12 +1763,18 @@ async fn process_tool_specs(
}
});
if model_tool_name.len() > 64 {
out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone()));
out_of_spec_tool_names.push(ToolValidationViolation::TooLong(spec.name.clone()));
continue;
} else if spec.description.is_empty() {
out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone()));
out_of_spec_tool_names.push(ToolValidationViolation::EmptyDescription(spec.name.clone()));
continue;
}

if spec.description.len() > 10_004 {
spec.description.truncate(10_004);
out_of_spec_tool_names.push(ToolValidationViolation::DescriptionTooLong(spec.name.clone()));
}

tn_map.insert(model_tool_name.clone(), ToolInfo {
server_name: server_name.to_string(),
host_tool_name: spec.name.clone(),
Expand Down Expand Up @@ -1788,21 +1812,25 @@ async fn process_tool_specs(
if !out_of_spec_tool_names.is_empty() {
Err(eyre::eyre!(out_of_spec_tool_names.iter().fold(
String::from(
"The following tools are out of spec. They will be excluded from the list of available tools:\n",
"The following tools are out of spec. They may have been excluded from the list of available tools:\n",
),
|mut acc, name| {
let (tool_name, msg) = match name {
OutOfSpecName::TooLong(tool_name) => (
ToolValidationViolation::TooLong(tool_name) => (
tool_name.as_str(),
"tool name exceeds max length of 64 when combined with server name",
),
OutOfSpecName::IllegalChar(tool_name) => (
ToolValidationViolation::IllegalChar(tool_name) => (
tool_name.as_str(),
"tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$",
),
OutOfSpecName::EmptyDescription(tool_name) => {
ToolValidationViolation::EmptyDescription(tool_name) => {
(tool_name.as_str(), "tool schema contains empty description")
},
ToolValidationViolation::DescriptionTooLong(tool_name) => (
tool_name.as_str(),
"tool description is longer than 10024 characters and has been truncated",
),
};
acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str());
acc
Expand Down
14 changes: 12 additions & 2 deletions crates/chat-cli/src/mcp_client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ pub enum McpClientError {
Auth(#[from] crate::auth::AuthError),
}

/// Decorates the method passed in with retry logic, but only if the [RunningService] has an
/// instance of [AuthClientDropGuard].
/// The various methods to interact with the mcp server provided by RMCP supposedly does refresh
/// token once the token expires but that logic would require us to also note down the time at
/// which a token is obtained since the only time related information in the token is the duration
/// for which a token is valid. However, if we do solely rely on the internals of these methods to
/// refresh tokens, we would have no way of knowing when a token is obtained. (Maybe there is a
/// method that would allow us to configure what extra info to include in the token. If you find it,
/// feel free to remove this. That would also enable us to simplify the definition of
/// [RunningService])
macro_rules! decorate_with_auth_retry {
($param_type:ty, $method_name:ident, $return_type:ty) => {
pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> {
Expand All @@ -166,7 +176,7 @@ macro_rules! decorate_with_auth_retry {
// TODO: discern error type prior to retrying
// Not entirely sure what is thrown when auth is required
if let Some(auth_client) = self.get_auth_client() {
let refresh_result = auth_client.get_access_token().await;
let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await;
match refresh_result {
Ok(_) => {
// Retry the operation after token refresh
Expand Down Expand Up @@ -340,7 +350,7 @@ impl McpClientService {
Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => {
debug!("## mcp: first hand shake attempt failed: {:?}", e);
let refresh_res =
auth_dg.auth_client.get_access_token().await;
auth_dg.auth_client.auth_manager.lock().await.refresh_token().await;
let new_self = McpClientService::new(
server_name.clone(),
backup_config,
Expand Down
86 changes: 71 additions & 15 deletions crates/chat-cli/src/mcp_client/oauth_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use reqwest::Client;
use rmcp::serde_json;
use rmcp::transport::auth::{
AuthClient,
OAuthClientConfig,
OAuthState,
OAuthTokenResponse,
};
Expand All @@ -26,6 +27,10 @@ use rmcp::transport::{
StreamableHttpClientTransport,
WorkerTransport,
};
use serde::{
Deserialize,
Serialize,
};
use sha2::{
Digest,
Sha256,
Expand Down Expand Up @@ -64,6 +69,8 @@ pub enum OauthUtilError {
Directory(#[from] DirectoryError),
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
#[error("Malformed directory")]
MalformDirectory,
}

/// A guard that automatically cancels the cancellation token when dropped.
Expand All @@ -79,6 +86,27 @@ impl Drop for LoopBackDropGuard {
}
}

/// This is modeled after [OAuthClientConfig]
/// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct Registration {
pub client_id: String,
pub client_secret: Option<String>,
pub scopes: Vec<String>,
pub redirect_uri: String,
}

impl From<OAuthClientConfig> for Registration {
fn from(value: OAuthClientConfig) -> Self {
Self {
client_id: value.client_id,
client_secret: value.client_secret,
scopes: value.scopes,
redirect_uri: value.redirect_uri,
}
}
}

/// A guard that manages the lifecycle of an authenticated MCP client and automatically
/// persists OAuth credentials when dropped.
///
Expand Down Expand Up @@ -164,6 +192,10 @@ pub enum HttpTransport {
WithoutAuth(WorkerTransport<StreamableHttpClientWorker<Client>>),
}

fn get_scopes() -> &'static [&'static str] {
&["openid", "mcp", "email", "profile"]
}

pub async fn get_http_transport(
os: &Os,
delete_cache: bool,
Expand All @@ -175,6 +207,7 @@ pub async fn get_http_transport(
let url = Url::from_str(url)?;
let key = compute_key(&url);
let cred_full_path = cred_dir.join(format!("{key}.token.json"));
let reg_full_path = cred_dir.join(format!("{key}.registration.json"));

if delete_cache && cred_full_path.is_file() {
tokio::fs::remove_file(&cred_full_path).await?;
Expand All @@ -188,7 +221,8 @@ pub async fn get_http_transport(
let auth_client = match auth_client {
Some(auth_client) => auth_client,
None => {
let am = get_auth_manager(url.clone(), cred_full_path.clone(), messenger).await?;
let am =
get_auth_manager(url.clone(), cred_full_path.clone(), reg_full_path.clone(), messenger).await?;
AuthClient::new(reqwest_client, am)
},
};
Expand All @@ -215,45 +249,67 @@ pub async fn get_http_transport(
async fn get_auth_manager(
url: Url,
cred_full_path: PathBuf,
reg_full_path: PathBuf,
messenger: &dyn Messenger,
) -> Result<AuthorizationManager, OauthUtilError> {
let content_as_bytes = tokio::fs::read(&cred_full_path).await;
let cred_as_bytes = tokio::fs::read(&cred_full_path).await;
let reg_as_bytes = tokio::fs::read(&reg_full_path).await;
let mut oauth_state = OAuthState::new(url, None).await?;

match content_as_bytes {
Ok(bytes) => {
let token = serde_json::from_slice::<OAuthTokenResponse>(&bytes)?;
match (cred_as_bytes, reg_as_bytes) {
(Ok(cred_as_bytes), Ok(reg_as_bytes)) => {
let token = serde_json::from_slice::<OAuthTokenResponse>(&cred_as_bytes)?;
let reg = serde_json::from_slice::<Registration>(&reg_as_bytes)?;

oauth_state.set_credentials("id", token).await?;
oauth_state.set_credentials(&reg.client_id, token).await?;

debug!("## mcp: credentials set with cache");

Ok(oauth_state
.into_authorization_manager()
.ok_or(OauthUtilError::MissingAuthorizationManager)?)
},
Err(e) => {
info!("Error reading cached credentials: {e}");
_ => {
info!("Error reading cached credentials");
debug!("## mcp: cache read failed. constructing auth manager from scratch");
get_auth_manager_impl(oauth_state, messenger).await
let (am, redirect_uri) = get_auth_manager_impl(oauth_state, messenger).await?;

// Client registration is done in [start_authorization]
// If we have gotten past that point that means we have the info to persist the
// registration on disk. These are info that we need to refresh stake
// tokens. This is in contrast to tokens, which we only persist when we drop
// the client (because that way we can write once and ensure what is on the
// disk always the most up to date)
let (client_id, _credentials) = am.get_credentials().await?;
let reg = Registration {
client_id,
client_secret: None,
scopes: get_scopes().iter().map(|s| (*s).to_string()).collect::<Vec<_>>(),
redirect_uri,
};
let reg_as_str = serde_json::to_string_pretty(&reg)?;
let reg_parent_path = reg_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?;
tokio::fs::create_dir(reg_parent_path).await?;
tokio::fs::write(reg_full_path, &reg_as_str).await?;

Ok(am)
},
}
}

async fn get_auth_manager_impl(
mut oauth_state: OAuthState,
messenger: &dyn Messenger,
) -> Result<AuthorizationManager, OauthUtilError> {
) -> Result<(AuthorizationManager, String), OauthUtilError> {
let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0));
let cancellation_token = tokio_util::sync::CancellationToken::new();
let (tx, rx) = tokio::sync::oneshot::channel::<String>();

let (actual_addr, _dg) = make_svc(tx, socket_addr, cancellation_token).await?;
info!("Listening on local host port {:?} for oauth", actual_addr);

oauth_state
.start_authorization(&["mcp", "profile", "email"], &format!("http://{}", actual_addr))
.await?;
let redirect_uri = format!("http://{}", actual_addr);
oauth_state.start_authorization(get_scopes(), &redirect_uri).await?;

let auth_url = oauth_state.get_authorization_url().await?;
_ = messenger.send_oauth_link(auth_url).await;
Expand All @@ -264,7 +320,7 @@ async fn get_auth_manager_impl(
.into_authorization_manager()
.ok_or(OauthUtilError::MissingAuthorizationManager)?;

Ok(am)
Ok((am, redirect_uri))
}

pub fn compute_key(rs: &Url) -> String {
Expand Down Expand Up @@ -320,7 +376,7 @@ async fn make_svc(
{
sender.send(code).map_err(LoopBackError::Send)?;
}
mk_response("Auth code sent".to_string())
mk_response("You can close this page now".to_string())
})
}
}
Expand Down
Loading