Skip to content

Commit 5d0c62b

Browse files
authored
fix(mcp): oauth issues (#2925)
* fix incorrect scope for mcp oauth * reverts custom tool config enum change * fixes display task overriding sign in notice * updates schema
1 parent dd04793 commit 5d0c62b

File tree

7 files changed

+125
-31
lines changed

7 files changed

+125
-31
lines changed

crates/chat-cli/src/cli/chat/prompt.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -745,10 +745,7 @@ mod tests {
745745
let key_event = KeyEvent(KeyCode::Char(key), Modifiers::CTRL);
746746

747747
// Try to bind and get the previous handler
748-
let previous_handler = test_editor.bind_sequence(
749-
key_event,
750-
EventHandler::Simple(Cmd::Noop)
751-
);
748+
let previous_handler = test_editor.bind_sequence(key_event, EventHandler::Simple(Cmd::Noop));
752749

753750
// If there was a previous handler, it means the key was already bound
754751
// (which could be our custom binding overriding Emacs)

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,7 @@ fn spawn_display_task(
12101210
terminal::Clear(terminal::ClearType::CurrentLine),
12111211
)?;
12121212
queue_oauth_message(&name, &mut output)?;
1213+
queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?;
12131214
},
12141215
},
12151216
Err(_e) => {

crates/chat-cli/src/cli/chat/tools/custom_tool.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ use crate::cli::agent::{
2222
};
2323
use crate::cli::chat::CONTINUATION_LINE;
2424
use crate::cli::chat::token_counter::TokenCounter;
25-
use crate::mcp_client::RunningService;
25+
use crate::mcp_client::{
26+
RunningService,
27+
oauth_util,
28+
};
2629
use crate::os::Os;
2730
use crate::util::MCP_SERVER_TOOL_DELIMITER;
2831
use crate::util::pattern_matching::matches_any_pattern;
@@ -43,17 +46,20 @@ impl Default for TransportType {
4346
}
4447

4548
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
49+
#[serde(rename_all = "camelCase", deny_unknown_fields)]
4650
pub struct CustomToolConfig {
47-
/// The type of transport the mcp server is expecting. For http transport, only url (for now)
48-
/// is taken into account.
51+
/// The transport type to use for communication with the MCP server
4952
#[serde(default)]
5053
pub r#type: TransportType,
51-
/// The URL endpoint for HTTP-based MCP servers
54+
/// The URL for HTTP-based MCP server communication
5255
#[serde(default)]
5356
pub url: String,
5457
/// HTTP headers to include when communicating with HTTP-based MCP servers
5558
#[serde(default)]
5659
pub headers: HashMap<String, String>,
60+
/// Scopes with which oauth is done
61+
#[serde(default = "get_default_scopes")]
62+
pub oauth_scopes: Vec<String>,
5763
/// The command string used to initialize the mcp server
5864
#[serde(default)]
5965
pub command: String,
@@ -74,6 +80,13 @@ pub struct CustomToolConfig {
7480
pub is_from_legacy_mcp_json: bool,
7581
}
7682

83+
pub fn get_default_scopes() -> Vec<String> {
84+
oauth_util::get_default_scopes()
85+
.iter()
86+
.map(|s| (*s).to_string())
87+
.collect::<Vec<_>>()
88+
}
89+
7790
pub fn default_timeout() -> u64 {
7891
120 * 1000
7992
}

crates/chat-cli/src/cli/mcp.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ impl StatusArgs {
406406
style::Print(format!("Disabled: {}\n", cfg.disabled)),
407407
style::Print(format!(
408408
"Env Vars: {}\n",
409-
cfg.env.as_ref().map_or_else(
409+
cfg.env.map_or_else(
410410
|| "(none)".into(),
411411
|e| e
412412
.iter()

crates/chat-cli/src/mcp_client/client.rs

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ pub enum McpClientError {
151151
Parse(#[from] url::ParseError),
152152
#[error(transparent)]
153153
Auth(#[from] crate::auth::AuthError),
154+
#[error("{0}")]
155+
MalformedConfig(&'static str),
154156
}
155157

156158
/// Decorates the method passed in with retry logic, but only if the [RunningService] has an
@@ -336,20 +338,23 @@ impl McpClientService {
336338
HttpTransport::WithAuth((transport, mut auth_client)) => {
337339
// The crate does not automatically refresh tokens when they expire. We
338340
// would need to handle that here
339-
let url = self.config.url.clone();
341+
let url = &backup_config.url;
340342
let service = match self.into_dyn().serve(transport).await.map_err(Box::new) {
341343
Ok(service) => service,
342344
Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => {
343345
debug!("## mcp: first hand shake attempt failed: {:?}", e);
344346
let refresh_res = auth_client.refresh_token().await;
345347
let new_self = McpClientService::new(
346348
server_name.clone(),
347-
backup_config,
349+
backup_config.clone(),
348350
messenger_clone.clone(),
349351
);
350352

353+
let scopes = &backup_config.oauth_scopes;
354+
let timeout = backup_config.timeout;
355+
let headers = &backup_config.headers;
351356
let new_transport =
352-
get_http_transport(&os_clone, &url, Some(auth_client.auth_client.clone()), &*messenger_dup).await?;
357+
get_http_transport(&os_clone, url, timeout, scopes, headers,Some(auth_client.auth_client.clone()), &*messenger_dup).await?;
353358

354359
match new_transport {
355360
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
@@ -367,8 +372,8 @@ impl McpClientService {
367372
// again. We do this by deleting the cred
368373
// and discarding the client to trigger a full auth flow
369374
tokio::fs::remove_file(&auth_client.cred_full_path).await?;
370-
let new_transport =
371-
get_http_transport(&os_clone, &url, None, &*messenger_dup).await?;
375+
let new_transport =
376+
get_http_transport(&os_clone, url, timeout, scopes,headers,None, &*messenger_dup).await?;
372377

373378
match new_transport {
374379
HttpTransport::WithAuth((new_transport, new_auth_client)) => {
@@ -495,17 +500,32 @@ impl McpClientService {
495500
}
496501

497502
async fn get_transport(&mut self, os: &Os, messenger: &dyn Messenger) -> Result<Transport, McpClientError> {
498-
// TODO: figure out what to do with headers
499503
let CustomToolConfig {
500-
r#type: transport_type,
504+
r#type,
501505
url,
506+
headers,
507+
oauth_scopes: scopes,
502508
command: command_as_str,
503509
args,
504510
env: config_envs,
511+
timeout,
505512
..
506513
} = &mut self.config;
507514

508-
match transport_type {
515+
let is_malformed_http = matches!(r#type, TransportType::Http) && url.is_empty();
516+
let is_malformed_stdio = matches!(r#type, TransportType::Stdio) && command_as_str.is_empty();
517+
518+
if is_malformed_http {
519+
return Err(McpClientError::MalformedConfig(
520+
"MCP config is malformed: transport type is specified to be http but url is empty",
521+
));
522+
} else if is_malformed_stdio {
523+
return Err(McpClientError::MalformedConfig(
524+
"MCP config is malformed: transport type is specified to be stdio but command is empty",
525+
));
526+
}
527+
528+
match r#type {
509529
TransportType::Stdio => {
510530
let expanded_cmd = canonicalizes_path(os, command_as_str)?;
511531
let command = Command::new(expanded_cmd).configure(|cmd| {
@@ -525,7 +545,7 @@ impl McpClientService {
525545
Ok(Transport::Stdio((tokio_child_process, child_stderr)))
526546
},
527547
TransportType::Http => {
528-
let http_transport = get_http_transport(os, url, None, messenger).await?;
548+
let http_transport = get_http_transport(os, url, *timeout, scopes, headers, None, messenger).await?;
529549

530550
Ok(Transport::Http(http_transport))
531551
},
@@ -562,7 +582,6 @@ impl McpClientService {
562582

563583
async fn on_tool_list_changed(&self, context: NotificationContext<RoleClient>) {
564584
let NotificationContext { peer, .. } = context;
565-
let _timeout = self.config.timeout;
566585

567586
paginated_fetch! {
568587
final_result_type: ListToolsResult,
@@ -578,7 +597,6 @@ impl McpClientService {
578597

579598
async fn on_prompt_list_changed(&self, context: NotificationContext<RoleClient>) {
580599
let NotificationContext { peer, .. } = context;
581-
let _timeout = self.config.timeout;
582600

583601
paginated_fetch! {
584602
final_result_type: ListPromptsResult,

crates/chat-cli/src/mcp_client/oauth_util.rs

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
use std::collections::HashMap;
12
use std::net::SocketAddr;
23
use std::path::PathBuf;
34
use std::pin::Pin;
45
use std::str::FromStr;
56
use std::sync::Arc;
67

7-
use http::StatusCode;
8+
use http::{
9+
HeaderMap,
10+
StatusCode,
11+
};
812
use http_body_util::Full;
913
use hyper::Response;
1014
use hyper::body::Bytes;
@@ -69,6 +73,8 @@ pub enum OauthUtilError {
6973
Directory(#[from] DirectoryError),
7074
#[error(transparent)]
7175
Reqwest(#[from] reqwest::Error),
76+
#[error("{0}")]
77+
Http(String),
7278
#[error("Malformed directory")]
7379
MalformDirectory,
7480
#[error("Missing credential")]
@@ -162,13 +168,16 @@ pub enum HttpTransport {
162168
WithoutAuth(WorkerTransport<StreamableHttpClientWorker<Client>>),
163169
}
164170

165-
fn get_scopes() -> &'static [&'static str] {
166-
&["openid", "mcp", "email", "profile"]
171+
pub fn get_default_scopes() -> &'static [&'static str] {
172+
&["openid", "email", "profile", "offline_access"]
167173
}
168174

169175
pub async fn get_http_transport(
170176
os: &Os,
171177
url: &str,
178+
timeout: u64,
179+
scopes: &[String],
180+
headers: &HashMap<String, String>,
172181
auth_client: Option<AuthClient<Client>>,
173182
messenger: &dyn Messenger,
174183
) -> Result<HttpTransport, OauthUtilError> {
@@ -178,16 +187,28 @@ pub async fn get_http_transport(
178187
let cred_full_path = cred_dir.join(format!("{key}.token.json"));
179188
let reg_full_path = cred_dir.join(format!("{key}.registration.json"));
180189

181-
let reqwest_client = reqwest::Client::default();
190+
let mut client_builder = reqwest::ClientBuilder::new().timeout(std::time::Duration::from_millis(timeout));
191+
if !headers.is_empty() {
192+
let headers = HeaderMap::try_from(headers).map_err(|e| OauthUtilError::Http(e.to_string()))?;
193+
client_builder = client_builder.default_headers(headers);
194+
};
195+
let reqwest_client = client_builder.build()?;
196+
182197
let probe_resp = reqwest_client.get(url.clone()).send().await?;
183198
match probe_resp.status() {
184199
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
185200
debug!("## mcp: requires auth, auth client passed in is {:?}", auth_client);
186201
let auth_client = match auth_client {
187202
Some(auth_client) => auth_client,
188203
None => {
189-
let am =
190-
get_auth_manager(url.clone(), cred_full_path.clone(), reg_full_path.clone(), messenger).await?;
204+
let am = get_auth_manager(
205+
url.clone(),
206+
cred_full_path.clone(),
207+
reg_full_path.clone(),
208+
scopes,
209+
messenger,
210+
)
211+
.await?;
191212
AuthClient::new(reqwest_client, am)
192213
},
193214
};
@@ -204,7 +225,12 @@ pub async fn get_http_transport(
204225
Ok(HttpTransport::WithAuth((transport, auth_dg)))
205226
},
206227
_ => {
207-
let transport = StreamableHttpClientTransport::from_uri(url.as_str());
228+
let transport =
229+
StreamableHttpClientTransport::with_client(reqwest_client, StreamableHttpClientTransportConfig {
230+
uri: url.as_str().into(),
231+
allow_stateless: false,
232+
..Default::default()
233+
});
208234

209235
Ok(HttpTransport::WithoutAuth(transport))
210236
},
@@ -215,6 +241,7 @@ async fn get_auth_manager(
215241
url: Url,
216242
cred_full_path: PathBuf,
217243
reg_full_path: PathBuf,
244+
scopes: &[String],
218245
messenger: &dyn Messenger,
219246
) -> Result<AuthorizationManager, OauthUtilError> {
220247
let cred_as_bytes = tokio::fs::read(&cred_full_path).await;
@@ -237,7 +264,7 @@ async fn get_auth_manager(
237264
_ => {
238265
info!("Error reading cached credentials");
239266
debug!("## mcp: cache read failed. constructing auth manager from scratch");
240-
let (am, redirect_uri) = get_auth_manager_impl(oauth_state, messenger).await?;
267+
let (am, redirect_uri) = get_auth_manager_impl(oauth_state, scopes, messenger).await?;
241268

242269
// Client registration is done in [start_authorization]
243270
// If we have gotten past that point that means we have the info to persist the
@@ -246,7 +273,10 @@ async fn get_auth_manager(
246273
let reg = Registration {
247274
client_id,
248275
client_secret: None,
249-
scopes: get_scopes().iter().map(|s| (*s).to_string()).collect::<Vec<_>>(),
276+
scopes: get_default_scopes()
277+
.iter()
278+
.map(|s| (*s).to_string())
279+
.collect::<Vec<_>>(),
250280
redirect_uri,
251281
};
252282
let reg_as_str = serde_json::to_string_pretty(&reg)?;
@@ -268,6 +298,7 @@ async fn get_auth_manager(
268298

269299
async fn get_auth_manager_impl(
270300
mut oauth_state: OAuthState,
301+
scopes: &[String],
271302
messenger: &dyn Messenger,
272303
) -> Result<(AuthorizationManager, String), OauthUtilError> {
273304
let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0));
@@ -278,7 +309,9 @@ async fn get_auth_manager_impl(
278309
info!("Listening on local host port {:?} for oauth", actual_addr);
279310

280311
let redirect_uri = format!("http://{}", actual_addr);
281-
oauth_state.start_authorization(get_scopes(), &redirect_uri).await?;
312+
let scopes_as_str = scopes.iter().map(String::as_str).collect::<Vec<_>>();
313+
let scopes_as_slice = scopes_as_str.as_slice();
314+
oauth_state.start_authorization(scopes_as_slice, &redirect_uri).await?;
282315

283316
let auth_url = oauth_state.get_authorization_url().await?;
284317
_ = messenger.send_oauth_link(auth_url).await;
@@ -333,9 +366,19 @@ async fn make_svc(
333366
let query = uri.query().unwrap_or("");
334367
let params: std::collections::HashMap<String, String> =
335368
url::form_urlencoded::parse(query.as_bytes()).into_owned().collect();
369+
debug!("## mcp: uri: {}, query: {}, params: {:?}", uri, query, params);
336370

337371
let self_clone = self.clone();
338372
Box::pin(async move {
373+
let error = params.get("error");
374+
let resp = if let Some(err) = error {
375+
mk_response(format!(
376+
"Oauth failed. Check url for precise reasons. Possible reasons: {err}.\nIf this is scope related. You can try configuring the server scopes to be an empty array via adding oauth_scopes: []"
377+
))
378+
} else {
379+
mk_response("You can close this page now".to_string())
380+
};
381+
339382
let code = params.get("code").cloned().unwrap_or_default();
340383
if let Some(sender) = self_clone
341384
.one_shot_sender
@@ -345,7 +388,8 @@ async fn make_svc(
345388
{
346389
sender.send(code).map_err(LoopBackError::Send)?;
347390
}
348-
mk_response("You can close this page now".to_string())
391+
392+
resp
349393
})
350394
}
351395
}

schemas/agent-v1.json

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,27 @@
7373
"type": "string",
7474
"default": ""
7575
},
76+
"headers": {
77+
"description": "HTTP headers to include when communicating with HTTP-based MCP servers",
78+
"type": "object",
79+
"additionalProperties": {
80+
"type": "string"
81+
},
82+
"default": {}
83+
},
84+
"oauthScopes": {
85+
"description": "Scopes with which oauth is done",
86+
"type": "array",
87+
"items": {
88+
"type": "string"
89+
},
90+
"default": [
91+
"openid",
92+
"email",
93+
"profile",
94+
"offline_access"
95+
]
96+
},
7697
"command": {
7798
"description": "The command string used to initialize the mcp server",
7899
"type": "string",

0 commit comments

Comments
 (0)