Skip to content

Commit bf183f0

Browse files
authored
fix(mcp): hardcodes client id for oauth (#2976)
1 parent 13fedbd commit bf183f0

File tree

1 file changed

+81
-6
lines changed

1 file changed

+81
-6
lines changed

crates/chat-cli/src/mcp_client/oauth_util.rs

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use rmcp::transport::streamable_http_client::{
2828
};
2929
use rmcp::transport::{
3030
AuthorizationManager,
31+
AuthorizationSession,
3132
StreamableHttpClientTransport,
3233
WorkerTransport,
3334
};
@@ -194,10 +195,12 @@ pub async fn get_http_transport(
194195
};
195196
let reqwest_client = client_builder.build()?;
196197

197-
let probe_resp = reqwest_client.get(url.clone()).send().await?;
198+
// The probe request, like all other request, should adhere to the standards as per https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
199+
let mut probe_request = reqwest_client.post(url.clone());
200+
probe_request = probe_request.header("Accept", "application/json, text/event-stream");
201+
let probe_resp = probe_request.send().await?;
198202
match probe_resp.status() {
199203
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
200-
debug!("## mcp: requires auth, auth client passed in is {:?}", auth_client);
201204
let auth_client = match auth_client {
202205
Some(auth_client) => auth_client,
203206
None => {
@@ -215,20 +218,19 @@ pub async fn get_http_transport(
215218
let transport =
216219
StreamableHttpClientTransport::with_client(auth_client.clone(), StreamableHttpClientTransportConfig {
217220
uri: url.as_str().into(),
218-
allow_stateless: false,
221+
allow_stateless: true,
219222
..Default::default()
220223
});
221224

222225
let auth_dg = AuthClientWrapper::new(cred_full_path, auth_client);
223-
debug!("## mcp: transport obtained");
224226

225227
Ok(HttpTransport::WithAuth((transport, auth_dg)))
226228
},
227229
_ => {
228230
let transport =
229231
StreamableHttpClientTransport::with_client(reqwest_client, StreamableHttpClientTransportConfig {
230232
uri: url.as_str().into(),
231-
allow_stateless: false,
233+
allow_stateless: true,
232234
..Default::default()
233235
});
234236

@@ -311,7 +313,7 @@ async fn get_auth_manager_impl(
311313
let redirect_uri = format!("http://{}", actual_addr);
312314
let scopes_as_str = scopes.iter().map(String::as_str).collect::<Vec<_>>();
313315
let scopes_as_slice = scopes_as_str.as_slice();
314-
oauth_state.start_authorization(scopes_as_slice, &redirect_uri).await?;
316+
start_authorization(&mut oauth_state, scopes_as_slice, &redirect_uri).await?;
315317

316318
let auth_url = oauth_state.get_authorization_url().await?;
317319
_ = messenger.send_oauth_link(auth_url).await;
@@ -332,6 +334,79 @@ pub fn compute_key(rs: &Url) -> String {
332334
format!("{:x}", hasher.finalize())
333335
}
334336

337+
/// This is our own implementation of [OAuthState::start_authorization].
338+
/// This differs from [OAuthState::start_authorization] by assigning our own client_id for DCR.
339+
/// We need this because the SDK hardcodes their own client id. And some servers will use client_id
340+
/// to identify if a client is even allowed to perform the auth handshake.
341+
async fn start_authorization(
342+
oauth_state: &mut OAuthState,
343+
scopes: &[&str],
344+
redirect_uri: &str,
345+
) -> Result<(), OauthUtilError> {
346+
// DO NOT CHANGE THIS
347+
// This string has significance as it is used for remote servers to identify us
348+
const CLIENT_ID: &str = "Q DEV CLI";
349+
350+
let stub_cred = get_stub_credentials()?;
351+
oauth_state.set_credentials(CLIENT_ID, stub_cred).await?;
352+
353+
// The setting of credentials would put the oauth state into authorize.
354+
if let OAuthState::Authorized(auth_manager) = oauth_state {
355+
// set redirect uri
356+
let config = OAuthClientConfig {
357+
client_id: CLIENT_ID.to_string(),
358+
client_secret: None,
359+
scopes: scopes.iter().map(|s| (*s).to_string()).collect(),
360+
redirect_uri: redirect_uri.to_string(),
361+
};
362+
363+
// try to dynamic register client
364+
let config = match auth_manager.register_client(CLIENT_ID, redirect_uri).await {
365+
Ok(config) => config,
366+
Err(e) => {
367+
eprintln!("Dynamic registration failed: {}", e);
368+
// fallback to default config
369+
config
370+
},
371+
};
372+
// reset client config
373+
auth_manager.configure_client(config)?;
374+
let auth_url = auth_manager.get_authorization_url(scopes).await?;
375+
376+
let mut stub_auth_manager = AuthorizationManager::new("http://localhost").await?;
377+
std::mem::swap(auth_manager, &mut stub_auth_manager);
378+
379+
let session = AuthorizationSession {
380+
auth_manager: stub_auth_manager,
381+
auth_url,
382+
redirect_uri: redirect_uri.to_string(),
383+
};
384+
385+
let mut new_oauth_state = OAuthState::Session(session);
386+
std::mem::swap(oauth_state, &mut new_oauth_state);
387+
} else {
388+
unreachable!()
389+
}
390+
391+
Ok(())
392+
}
393+
394+
/// This looks silly but [rmcp::transport::auth::OAuthTokenResponse] is private and there is no
395+
/// other way to create this directly
396+
fn get_stub_credentials() -> Result<OAuthTokenResponse, serde_json::Error> {
397+
const STUB_TOKEN: &str = r#"
398+
{
399+
"access_token": "stub",
400+
"token_type": "bearer",
401+
"expires_in": 3600,
402+
"refresh_token": "stub",
403+
"scope": "stub"
404+
}
405+
"#;
406+
407+
serde_json::from_str::<OAuthTokenResponse>(STUB_TOKEN)
408+
}
409+
335410
async fn make_svc(
336411
one_shot_sender: Sender<String>,
337412
socket_addr: SocketAddr,

0 commit comments

Comments
 (0)