@@ -14,6 +14,7 @@ use reqwest::Client;
1414use rmcp:: serde_json;
1515use rmcp:: transport:: auth:: {
1616 AuthClient ,
17+ OAuthClientConfig ,
1718 OAuthState ,
1819 OAuthTokenResponse ,
1920} ;
@@ -26,6 +27,10 @@ use rmcp::transport::{
2627 StreamableHttpClientTransport ,
2728 WorkerTransport ,
2829} ;
30+ use serde:: {
31+ Deserialize ,
32+ Serialize ,
33+ } ;
2934use sha2:: {
3035 Digest ,
3136 Sha256 ,
@@ -64,6 +69,8 @@ pub enum OauthUtilError {
6469 Directory ( #[ from] DirectoryError ) ,
6570 #[ error( transparent) ]
6671 Reqwest ( #[ from] reqwest:: Error ) ,
72+ #[ error( "Malformed directory" ) ]
73+ MalformDirectory ,
6774}
6875
6976/// A guard that automatically cancels the cancellation token when dropped.
@@ -79,6 +86,27 @@ impl Drop for LoopBackDropGuard {
7986 }
8087}
8188
89+ /// This is modeled after [OAuthClientConfig]
90+ /// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize
91+ #[ derive( Clone , Serialize , Deserialize , Debug ) ]
92+ pub struct Registration {
93+ pub client_id : String ,
94+ pub client_secret : Option < String > ,
95+ pub scopes : Vec < String > ,
96+ pub redirect_uri : String ,
97+ }
98+
99+ impl From < OAuthClientConfig > for Registration {
100+ fn from ( value : OAuthClientConfig ) -> Self {
101+ Self {
102+ client_id : value. client_id ,
103+ client_secret : value. client_secret ,
104+ scopes : value. scopes ,
105+ redirect_uri : value. redirect_uri ,
106+ }
107+ }
108+ }
109+
82110/// A guard that manages the lifecycle of an authenticated MCP client and automatically
83111/// persists OAuth credentials when dropped.
84112///
@@ -164,6 +192,10 @@ pub enum HttpTransport {
164192 WithoutAuth ( WorkerTransport < StreamableHttpClientWorker < Client > > ) ,
165193}
166194
195+ fn get_scopes ( ) -> & ' static [ & ' static str ] {
196+ & [ "openid" , "mcp" , "email" , "profile" ]
197+ }
198+
167199pub async fn get_http_transport (
168200 os : & Os ,
169201 delete_cache : bool ,
@@ -175,6 +207,7 @@ pub async fn get_http_transport(
175207 let url = Url :: from_str ( url) ?;
176208 let key = compute_key ( & url) ;
177209 let cred_full_path = cred_dir. join ( format ! ( "{key}.token.json" ) ) ;
210+ let reg_full_path = cred_dir. join ( format ! ( "{key}.registration.json" ) ) ;
178211
179212 if delete_cache && cred_full_path. is_file ( ) {
180213 tokio:: fs:: remove_file ( & cred_full_path) . await ?;
@@ -188,7 +221,8 @@ pub async fn get_http_transport(
188221 let auth_client = match auth_client {
189222 Some ( auth_client) => auth_client,
190223 None => {
191- let am = get_auth_manager ( url. clone ( ) , cred_full_path. clone ( ) , messenger) . await ?;
224+ let am =
225+ get_auth_manager ( url. clone ( ) , cred_full_path. clone ( ) , reg_full_path. clone ( ) , messenger) . await ?;
192226 AuthClient :: new ( reqwest_client, am)
193227 } ,
194228 } ;
@@ -215,45 +249,67 @@ pub async fn get_http_transport(
215249async fn get_auth_manager (
216250 url : Url ,
217251 cred_full_path : PathBuf ,
252+ reg_full_path : PathBuf ,
218253 messenger : & dyn Messenger ,
219254) -> Result < AuthorizationManager , OauthUtilError > {
220- let content_as_bytes = tokio:: fs:: read ( & cred_full_path) . await ;
255+ let cred_as_bytes = tokio:: fs:: read ( & cred_full_path) . await ;
256+ let reg_as_bytes = tokio:: fs:: read ( & reg_full_path) . await ;
221257 let mut oauth_state = OAuthState :: new ( url, None ) . await ?;
222258
223- match content_as_bytes {
224- Ok ( bytes) => {
225- let token = serde_json:: from_slice :: < OAuthTokenResponse > ( & bytes) ?;
259+ match ( cred_as_bytes, reg_as_bytes) {
260+ ( Ok ( cred_as_bytes) , Ok ( reg_as_bytes) ) => {
261+ let token = serde_json:: from_slice :: < OAuthTokenResponse > ( & cred_as_bytes) ?;
262+ let reg = serde_json:: from_slice :: < Registration > ( & reg_as_bytes) ?;
226263
227- oauth_state. set_credentials ( "id" , token) . await ?;
264+ oauth_state. set_credentials ( & reg . client_id , token) . await ?;
228265
229266 debug ! ( "## mcp: credentials set with cache" ) ;
230267
231268 Ok ( oauth_state
232269 . into_authorization_manager ( )
233270 . ok_or ( OauthUtilError :: MissingAuthorizationManager ) ?)
234271 } ,
235- Err ( e ) => {
236- info ! ( "Error reading cached credentials: {e} " ) ;
272+ _ => {
273+ info ! ( "Error reading cached credentials" ) ;
237274 debug ! ( "## mcp: cache read failed. constructing auth manager from scratch" ) ;
238- get_auth_manager_impl ( oauth_state, messenger) . await
275+ let ( am, redirect_uri) = get_auth_manager_impl ( oauth_state, messenger) . await ?;
276+
277+ // Client registration is done in [start_authorization]
278+ // If we have gotten past that point that means we have the info to persist the
279+ // registration on disk. These are info that we need to refresh stake
280+ // tokens. This is in contrast to tokens, which we only persist when we drop
281+ // the client (because that way we can write once and ensure what is on the
282+ // disk always the most up to date)
283+ let ( client_id, _credentials) = am. get_credentials ( ) . await ?;
284+ let reg = Registration {
285+ client_id,
286+ client_secret : None ,
287+ scopes : get_scopes ( ) . iter ( ) . map ( |s| ( * s) . to_string ( ) ) . collect :: < Vec < _ > > ( ) ,
288+ redirect_uri,
289+ } ;
290+ let reg_as_str = serde_json:: to_string_pretty ( & reg) ?;
291+ let reg_parent_path = reg_full_path. parent ( ) . ok_or ( OauthUtilError :: MalformDirectory ) ?;
292+ tokio:: fs:: create_dir ( reg_parent_path) . await ?;
293+ tokio:: fs:: write ( reg_full_path, & reg_as_str) . await ?;
294+
295+ Ok ( am)
239296 } ,
240297 }
241298}
242299
243300async fn get_auth_manager_impl (
244301 mut oauth_state : OAuthState ,
245302 messenger : & dyn Messenger ,
246- ) -> Result < AuthorizationManager , OauthUtilError > {
303+ ) -> Result < ( AuthorizationManager , String ) , OauthUtilError > {
247304 let socket_addr = SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , 0 ) ) ;
248305 let cancellation_token = tokio_util:: sync:: CancellationToken :: new ( ) ;
249306 let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < String > ( ) ;
250307
251308 let ( actual_addr, _dg) = make_svc ( tx, socket_addr, cancellation_token) . await ?;
252309 info ! ( "Listening on local host port {:?} for oauth" , actual_addr) ;
253310
254- oauth_state
255- . start_authorization ( & [ "mcp" , "profile" , "email" ] , & format ! ( "http://{}" , actual_addr) )
256- . await ?;
311+ let redirect_uri = format ! ( "http://{}" , actual_addr) ;
312+ oauth_state. start_authorization ( get_scopes ( ) , & redirect_uri) . await ?;
257313
258314 let auth_url = oauth_state. get_authorization_url ( ) . await ?;
259315 _ = messenger. send_oauth_link ( auth_url) . await ;
@@ -264,7 +320,7 @@ async fn get_auth_manager_impl(
264320 . into_authorization_manager ( )
265321 . ok_or ( OauthUtilError :: MissingAuthorizationManager ) ?;
266322
267- Ok ( am )
323+ Ok ( ( am , redirect_uri ) )
268324}
269325
270326pub fn compute_key ( rs : & Url ) -> String {
@@ -320,7 +376,7 @@ async fn make_svc(
320376 {
321377 sender. send ( code) . map_err ( LoopBackError :: Send ) ?;
322378 }
323- mk_response ( "Auth code sent " . to_string ( ) )
379+ mk_response ( "You can close this page now " . to_string ( ) )
324380 } )
325381 }
326382 }
0 commit comments