@@ -5,7 +5,6 @@ use std::ops::{
55 DerefMut ,
66} ;
77use std:: process:: Stdio ;
8- use std:: str:: FromStr ;
98
109use regex:: Regex ;
1110use reqwest:: Client ;
@@ -27,15 +26,9 @@ use rmcp::service::{
2726 NotificationContext ,
2827} ;
2928use rmcp:: transport:: auth:: AuthClient ;
30- use rmcp:: transport:: streamable_http_client:: {
31- StreamableHttpClientTransportConfig ,
32- StreamableHttpClientWorker ,
33- } ;
3429use rmcp:: transport:: {
3530 ConfigureCommandExt ,
36- StreamableHttpClientTransport ,
3731 TokioChildProcess ,
38- WorkerTransport ,
3932} ;
4033use rmcp:: {
4134 ErrorData ,
@@ -50,26 +43,25 @@ use tokio::process::{
5043 Command ,
5144} ;
5245use tokio:: task:: JoinHandle ;
53- use tracing:: error;
54- use url:: Url ;
46+ use tracing:: {
47+ error,
48+ info,
49+ } ;
5550
5651use super :: messenger:: Messenger ;
52+ use super :: oauth_util:: HttpTransport ;
5753use super :: {
5854 AuthClientDropGuard ,
5955 OauthUtilError ,
60- compute_key ,
56+ get_http_transport ,
6157} ;
6258use crate :: cli:: chat:: server_messenger:: ServerMessenger ;
6359use crate :: cli:: chat:: tools:: custom_tool:: {
6460 CustomToolConfig ,
6561 TransportType ,
6662} ;
67- use crate :: mcp_client:: get_auth_manager;
6863use crate :: os:: Os ;
69- use crate :: util:: directories:: {
70- DirectoryError ,
71- get_mcp_auth_dir,
72- } ;
64+ use crate :: util:: directories:: DirectoryError ;
7365
7466/// Fetches all pages of specified resources from a server
7567macro_rules! paginated_fetch {
@@ -152,9 +144,11 @@ pub enum McpClientError {
152144 #[ error( transparent) ]
153145 Directory ( #[ from] DirectoryError ) ,
154146 #[ error( transparent) ]
155- Oauth ( #[ from] OauthUtilError ) ,
147+ OauthUtil ( #[ from] OauthUtilError ) ,
156148 #[ error( transparent) ]
157149 Parse ( #[ from] url:: ParseError ) ,
150+ #[ error( transparent) ]
151+ Auth ( #[ from] crate :: auth:: AuthError ) ,
158152}
159153
160154pub struct RunningService {
@@ -163,6 +157,12 @@ pub struct RunningService {
163157 pub auth_dropguard : Option < AuthClientDropGuard > ,
164158}
165159
160+ impl RunningService {
161+ pub fn get_auth_client ( & self ) -> Option < AuthClient < Client > > {
162+ self . auth_dropguard . as_ref ( ) . map ( |a| a. auth_client . clone ( ) )
163+ }
164+ }
165+
166166impl Deref for RunningService {
167167 type Target = rmcp:: service:: RunningService < RoleClient , Box < dyn DynService < RoleClient > > > ;
168168
@@ -177,10 +177,6 @@ impl DerefMut for RunningService {
177177 }
178178}
179179
180- pub type HttpTransport = (
181- WorkerTransport < StreamableHttpClientWorker < AuthClient < Client > > > ,
182- AuthClientDropGuard ,
183- ) ;
184180pub type StdioTransport = ( TokioChildProcess , Option < ChildStderr > ) ;
185181
186182// TODO: add sse support (even though it's deprecated)
@@ -231,11 +227,13 @@ impl McpClientService {
231227 let os_clone = os. clone ( ) ;
232228
233229 let handle: JoinHandle < Result < RunningService , McpClientError > > = tokio:: spawn ( async move {
234- let messenger_clone = self . messenger . duplicate ( ) ;
230+ let messenger_clone = self . messenger . clone ( ) ;
235231 let server_name = self . server_name . clone ( ) ;
232+ let backup_config = self . config . clone ( ) ;
236233
237234 let result: Result < _ , McpClientError > = async {
238- let ( service, stderr, auth_client) = match self . get_transport ( & os_clone, & messenger_clone) . await ? {
235+ let messenger_dup = messenger_clone. duplicate ( ) ;
236+ let ( service, stderr, auth_client) = match self . get_transport ( & os_clone, & * messenger_dup) . await ? {
239237 Transport :: Stdio ( ( child_process, stderr) ) => {
240238 let service = self
241239 . into_dyn ( )
@@ -245,10 +243,71 @@ impl McpClientService {
245243
246244 ( service, stderr, None )
247245 } ,
248- Transport :: Http ( ( transport, auth_dg) ) => {
249- let service = self . into_dyn ( ) . serve ( transport) . await . map_err ( Box :: new) ?;
246+ Transport :: Http ( http_transport) => {
247+ match http_transport {
248+ HttpTransport :: WithAuth ( ( transport, mut auth_dg) ) => {
249+ // The crate does not automatically refresh tokens when they expire. We
250+ // would need to handle that here
251+ let url = self . config . url . clone ( ) ;
252+ let service = match self . into_dyn ( ) . serve ( transport) . await . map_err ( Box :: new) {
253+ Ok ( service) => service,
254+ Err ( e) if matches ! ( * e, ClientInitializeError :: ConnectionClosed ( _) ) => {
255+ let refresh_res =
256+ auth_dg. auth_client . auth_manager . lock ( ) . await . refresh_token ( ) . await ;
257+ let new_self = McpClientService :: new (
258+ server_name. clone ( ) ,
259+ backup_config,
260+ messenger_clone. clone ( ) ,
261+ ) ;
262+
263+ let new_transport =
264+ get_http_transport ( & os_clone, true , & url, & * messenger_dup) . await ?;
265+
266+ match new_transport {
267+ HttpTransport :: WithAuth ( ( new_transport, new_auth_dg) ) => {
268+ auth_dg. should_write = false ;
269+ auth_dg = new_auth_dg;
270+
271+ match refresh_res {
272+ Ok ( _token) => {
273+ new_self. into_dyn ( ) . serve ( new_transport) . await . map_err ( Box :: new) ?
274+ } ,
275+ Err ( e) => {
276+ info ! ( "Retry for http transport failed {e}. Possible reauth needed" ) ;
277+ // This could be because the refresh token is expired, in which
278+ // case we would need to have user go through the auth flow
279+ // again
280+ let new_transport =
281+ get_http_transport ( & os_clone, true , & url, & * messenger_dup) . await ?;
282+
283+ match new_transport {
284+ HttpTransport :: WithAuth ( ( new_transport, new_auth_dg) ) => {
285+ auth_dg = new_auth_dg;
286+ auth_dg. should_write = false ;
287+ new_self. into_dyn ( ) . serve ( new_transport) . await . map_err ( Box :: new) ?
288+ } ,
289+ HttpTransport :: WithoutAuth ( new_transport) => {
290+ new_self. into_dyn ( ) . serve ( new_transport) . await . map_err ( Box :: new) ?
291+ } ,
292+ }
293+ } ,
294+ }
295+ } ,
296+ HttpTransport :: WithoutAuth ( new_transport) =>
297+ new_self. into_dyn ( ) . serve ( new_transport) . await . map_err ( Box :: new) ?,
298+ }
299+ } ,
300+ Err ( e) => return Err ( e. into ( ) ) ,
301+ } ;
302+
303+ ( service, None , Some ( auth_dg) )
304+ } ,
305+ HttpTransport :: WithoutAuth ( transport) => {
306+ let service = self . into_dyn ( ) . serve ( transport) . await . map_err ( Box :: new) ?;
250307
251- ( service, None , Some ( auth_dg) )
308+ ( service, None , None )
309+ } ,
310+ }
252311 } ,
253312 } ;
254313
@@ -346,7 +405,7 @@ impl McpClientService {
346405 Ok ( InitializedMcpClient :: Pending ( handle) )
347406 }
348407
349- async fn get_transport ( & mut self , os : & Os , messenger : & Box < dyn Messenger > ) -> Result < Transport , McpClientError > {
408+ async fn get_transport ( & mut self , os : & Os , messenger : & dyn Messenger ) -> Result < Transport , McpClientError > {
350409 // TODO: figure out what to do with headers
351410 let CustomToolConfig {
352411 r#type : transport_type,
@@ -376,26 +435,9 @@ impl McpClientService {
376435 Ok ( Transport :: Stdio ( ( tokio_child_process, child_stderr) ) )
377436 } ,
378437 TransportType :: Http => {
379- let cred_dir = get_mcp_auth_dir ( os) ?;
380- let url = Url :: from_str ( url) ?;
381- let key = compute_key ( & url) ;
382- let cred_full_path = cred_dir. join ( format ! ( "{key}.token.json" ) ) ;
383-
384- let am = get_auth_manager ( url. clone ( ) , cred_full_path. clone ( ) , messenger) . await ?;
385- let client = AuthClient :: new ( reqwest:: Client :: default ( ) , am) ;
386- let transport =
387- StreamableHttpClientTransport :: with_client ( client. clone ( ) , StreamableHttpClientTransportConfig {
388- uri : url. as_str ( ) . into ( ) ,
389- allow_stateless : false ,
390- ..Default :: default ( )
391- } ) ;
392-
393- let auth_dg = AuthClientDropGuard {
394- path : cred_full_path,
395- auth_client : client,
396- } ;
438+ let http_transport = get_http_transport ( os, false , url, messenger) . await ?;
397439
398- Ok ( Transport :: Http ( ( transport , auth_dg ) ) )
440+ Ok ( Transport :: Http ( http_transport ) )
399441 } ,
400442 }
401443 }
0 commit comments