Skip to content

Commit 6c2e6e2

Browse files
authored
feat(oauth): fixes + cache client credentials (#157)
credentials optimize
1 parent 52c0651 commit 6c2e6e2

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

crates/rmcp/src/transport/auth.rs

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::{
2+
collections::HashMap,
23
sync::Arc,
34
time::{Duration, Instant},
45
};
@@ -70,6 +71,9 @@ pub struct AuthorizationMetadata {
7071
pub issuer: Option<String>,
7172
pub jwks_uri: Option<String>,
7273
pub scopes_supported: Option<Vec<String>>,
74+
// allow additional fields
75+
#[serde(flatten)]
76+
pub additional_fields: HashMap<String, serde_json::Value>,
7377
}
7478

7579
/// oauth2 client config
@@ -100,6 +104,7 @@ type OAuthClient = oauth2::Client<
100104
oauth2::EndpointNotSet,
101105
oauth2::EndpointSet,
102106
>;
107+
type Credentials = (String, Option<OAuthTokenResponse>);
103108

104109
/// oauth2 auth manager
105110
pub struct AuthorizationManager {
@@ -124,9 +129,12 @@ pub struct ClientRegistrationRequest {
124129
#[derive(Debug, Clone, Serialize, Deserialize)]
125130
pub struct ClientRegistrationResponse {
126131
pub client_id: String,
127-
pub client_secret: String,
132+
pub client_secret: Option<String>,
128133
pub client_name: String,
129134
pub redirect_uris: Vec<String>,
135+
// allow additional fields
136+
#[serde(flatten)]
137+
pub additional_fields: HashMap<String, serde_json::Value>,
130138
}
131139

132140
impl AuthorizationManager {
@@ -191,10 +199,22 @@ impl AuthorizationManager {
191199
issuer: None,
192200
jwks_uri: None,
193201
scopes_supported: None,
202+
additional_fields: HashMap::new(),
194203
})
195204
}
196205
}
197206

207+
/// get client id and credentials
208+
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
209+
let credentials = self.credentials.read().await;
210+
let client_id = self
211+
.oauth_client
212+
.as_ref()
213+
.ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?
214+
.client_id();
215+
Ok((client_id.to_string(), credentials.clone()))
216+
}
217+
198218
/// configure oauth2 client with client credentials
199219
pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> {
200220
if self.metadata.is_none() {
@@ -287,6 +307,7 @@ impl AuthorizationManager {
287307
status, error_text
288308
)));
289309
}
310+
290311
debug!("registration response: {:?}", response);
291312
let reg_response = match response.json::<ClientRegistrationResponse>().await {
292313
Ok(response) => response,
@@ -301,7 +322,7 @@ impl AuthorizationManager {
301322

302323
let config = OAuthClientConfig {
303324
client_id: reg_response.client_id,
304-
client_secret: Some(reg_response.client_secret),
325+
client_secret: reg_response.client_secret,
305326
redirect_uri: redirect_uri.to_string(),
306327
scopes: vec![],
307328
};
@@ -310,6 +331,18 @@ impl AuthorizationManager {
310331
Ok(config)
311332
}
312333

334+
/// use provided client id to configure oauth2 client instead of dynamic registration
335+
/// this is useful when you have a stored client id from previous registration
336+
pub fn configure_client_id(&mut self, client_id: &str) -> Result<(), AuthError> {
337+
let config = OAuthClientConfig {
338+
client_id: client_id.to_string(),
339+
client_secret: None,
340+
scopes: vec![],
341+
redirect_uri: self.base_url.to_string(),
342+
};
343+
self.configure_client(config)
344+
}
345+
313346
/// generate authorization url
314347
pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result<String, AuthError> {
315348
let oauth_client = self
@@ -513,6 +546,11 @@ impl AuthorizationSession {
513546
})
514547
}
515548

549+
/// get client_id and credentials
550+
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
551+
self.auth_manager.get_credentials().await
552+
}
553+
516554
/// get authorization url
517555
pub fn get_authorization_url(&self) -> &str {
518556
&self.auth_url
@@ -590,9 +628,54 @@ impl OAuthState {
590628
if let Some(client) = client {
591629
manager.with_client(client)?;
592630
}
631+
593632
Ok(OAuthState::Unauthorized(manager))
594633
}
595634

635+
/// Get client_id and OAuth credentials
636+
pub async fn get_credentials(&self) -> Result<Credentials, AuthError> {
637+
// return client_id and credentials
638+
match self {
639+
OAuthState::Unauthorized(manager) | OAuthState::Authorized(manager) => {
640+
manager.get_credentials().await
641+
}
642+
OAuthState::Session(session) => session.get_credentials().await,
643+
OAuthState::AuthorizedHttpClient(client) => client.auth_manager.get_credentials().await,
644+
}
645+
}
646+
647+
/// Manually set credentials and move into authorized state
648+
/// Useful if you're caching credentials externally and wish to reuse them
649+
pub async fn set_credentials(
650+
&mut self,
651+
client_id: &str,
652+
credentials: OAuthTokenResponse,
653+
) -> Result<(), AuthError> {
654+
if let OAuthState::Unauthorized(manager) = self {
655+
let mut manager = std::mem::replace(
656+
manager,
657+
AuthorizationManager::new("http://localhost").await?,
658+
);
659+
660+
// write credentials
661+
*manager.credentials.write().await = Some(credentials);
662+
663+
// discover metadata
664+
let metadata = manager.discover_metadata().await?;
665+
manager.metadata = Some(metadata);
666+
667+
// set client id and secret
668+
manager.configure_client_id(client_id)?;
669+
670+
*self = OAuthState::Authorized(manager);
671+
Ok(())
672+
} else {
673+
Err(AuthError::InternalError(
674+
"Cannot set credentials in this state".to_string(),
675+
))
676+
}
677+
}
678+
596679
/// start authorization
597680
pub async fn start_authorization(
598681
&mut self,

examples/servers/src/mcp_oauth_server.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ async fn oauth_authorization_server() -> impl IntoResponse {
525525
registration_endpoint: format!("http://{}/oauth/register", BIND_ADDRESS),
526526
issuer: Some(BIND_ADDRESS.to_string()),
527527
jwks_uri: Some(format!("http://{}/oauth/jwks", BIND_ADDRESS)),
528+
additional_fields: HashMap::new(),
528529
};
529530
debug!("metadata: {:?}", metadata);
530531
(StatusCode::OK, Json(metadata))
@@ -567,9 +568,10 @@ async fn oauth_register(
567568
// return client information
568569
let response = ClientRegistrationResponse {
569570
client_id,
570-
client_secret,
571+
client_secret: Some(client_secret),
571572
client_name: req.client_name,
572573
redirect_uris: req.redirect_uris,
574+
additional_fields: HashMap::new(),
573575
};
574576

575577
(StatusCode::CREATED, Json(response)).into_response()

0 commit comments

Comments
 (0)