1+ use std:: collections:: HashMap ;
12use std:: net:: SocketAddr ;
23use std:: path:: PathBuf ;
34use std:: pin:: Pin ;
45use std:: str:: FromStr ;
56use std:: sync:: Arc ;
67
7- use http:: StatusCode ;
8+ use http:: {
9+ HeaderMap ,
10+ StatusCode ,
11+ } ;
812use http_body_util:: Full ;
913use hyper:: Response ;
1014use hyper:: body:: Bytes ;
@@ -69,6 +73,8 @@ pub enum OauthUtilError {
6973 Directory ( #[ from] DirectoryError ) ,
7074 #[ error( transparent) ]
7175 Reqwest ( #[ from] reqwest:: Error ) ,
76+ #[ error( "{0}" ) ]
77+ Http ( String ) ,
7278 #[ error( "Malformed directory" ) ]
7379 MalformDirectory ,
7480 #[ error( "Missing credential" ) ]
@@ -162,13 +168,16 @@ pub enum HttpTransport {
162168 WithoutAuth ( WorkerTransport < StreamableHttpClientWorker < Client > > ) ,
163169}
164170
165- fn get_scopes ( ) -> & ' static [ & ' static str ] {
166- & [ "openid" , "mcp " , "email " , "profile " ]
171+ pub fn get_default_scopes ( ) -> & ' static [ & ' static str ] {
172+ & [ "openid" , "email " , "profile " , "offline_access " ]
167173}
168174
169175pub async fn get_http_transport (
170176 os : & Os ,
171177 url : & str ,
178+ timeout : u64 ,
179+ scopes : & [ String ] ,
180+ headers : & HashMap < String , String > ,
172181 auth_client : Option < AuthClient < Client > > ,
173182 messenger : & dyn Messenger ,
174183) -> Result < HttpTransport , OauthUtilError > {
@@ -178,16 +187,28 @@ pub async fn get_http_transport(
178187 let cred_full_path = cred_dir. join ( format ! ( "{key}.token.json" ) ) ;
179188 let reg_full_path = cred_dir. join ( format ! ( "{key}.registration.json" ) ) ;
180189
181- let reqwest_client = reqwest:: Client :: default ( ) ;
190+ let mut client_builder = reqwest:: ClientBuilder :: new ( ) . timeout ( std:: time:: Duration :: from_millis ( timeout) ) ;
191+ if !headers. is_empty ( ) {
192+ let headers = HeaderMap :: try_from ( headers) . map_err ( |e| OauthUtilError :: Http ( e. to_string ( ) ) ) ?;
193+ client_builder = client_builder. default_headers ( headers) ;
194+ } ;
195+ let reqwest_client = client_builder. build ( ) ?;
196+
182197 let probe_resp = reqwest_client. get ( url. clone ( ) ) . send ( ) . await ?;
183198 match probe_resp. status ( ) {
184199 StatusCode :: UNAUTHORIZED | StatusCode :: FORBIDDEN => {
185200 debug ! ( "## mcp: requires auth, auth client passed in is {:?}" , auth_client) ;
186201 let auth_client = match auth_client {
187202 Some ( auth_client) => auth_client,
188203 None => {
189- let am =
190- get_auth_manager ( url. clone ( ) , cred_full_path. clone ( ) , reg_full_path. clone ( ) , messenger) . await ?;
204+ let am = get_auth_manager (
205+ url. clone ( ) ,
206+ cred_full_path. clone ( ) ,
207+ reg_full_path. clone ( ) ,
208+ scopes,
209+ messenger,
210+ )
211+ . await ?;
191212 AuthClient :: new ( reqwest_client, am)
192213 } ,
193214 } ;
@@ -204,7 +225,12 @@ pub async fn get_http_transport(
204225 Ok ( HttpTransport :: WithAuth ( ( transport, auth_dg) ) )
205226 } ,
206227 _ => {
207- let transport = StreamableHttpClientTransport :: from_uri ( url. as_str ( ) ) ;
228+ let transport =
229+ StreamableHttpClientTransport :: with_client ( reqwest_client, StreamableHttpClientTransportConfig {
230+ uri : url. as_str ( ) . into ( ) ,
231+ allow_stateless : false ,
232+ ..Default :: default ( )
233+ } ) ;
208234
209235 Ok ( HttpTransport :: WithoutAuth ( transport) )
210236 } ,
@@ -215,6 +241,7 @@ async fn get_auth_manager(
215241 url : Url ,
216242 cred_full_path : PathBuf ,
217243 reg_full_path : PathBuf ,
244+ scopes : & [ String ] ,
218245 messenger : & dyn Messenger ,
219246) -> Result < AuthorizationManager , OauthUtilError > {
220247 let cred_as_bytes = tokio:: fs:: read ( & cred_full_path) . await ;
@@ -237,7 +264,7 @@ async fn get_auth_manager(
237264 _ => {
238265 info ! ( "Error reading cached credentials" ) ;
239266 debug ! ( "## mcp: cache read failed. constructing auth manager from scratch" ) ;
240- let ( am, redirect_uri) = get_auth_manager_impl ( oauth_state, messenger) . await ?;
267+ let ( am, redirect_uri) = get_auth_manager_impl ( oauth_state, scopes , messenger) . await ?;
241268
242269 // Client registration is done in [start_authorization]
243270 // If we have gotten past that point that means we have the info to persist the
@@ -246,7 +273,10 @@ async fn get_auth_manager(
246273 let reg = Registration {
247274 client_id,
248275 client_secret : None ,
249- scopes : get_scopes ( ) . iter ( ) . map ( |s| ( * s) . to_string ( ) ) . collect :: < Vec < _ > > ( ) ,
276+ scopes : get_default_scopes ( )
277+ . iter ( )
278+ . map ( |s| ( * s) . to_string ( ) )
279+ . collect :: < Vec < _ > > ( ) ,
250280 redirect_uri,
251281 } ;
252282 let reg_as_str = serde_json:: to_string_pretty ( & reg) ?;
@@ -268,6 +298,7 @@ async fn get_auth_manager(
268298
269299async fn get_auth_manager_impl (
270300 mut oauth_state : OAuthState ,
301+ scopes : & [ String ] ,
271302 messenger : & dyn Messenger ,
272303) -> Result < ( AuthorizationManager , String ) , OauthUtilError > {
273304 let socket_addr = SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , 0 ) ) ;
@@ -278,7 +309,9 @@ async fn get_auth_manager_impl(
278309 info ! ( "Listening on local host port {:?} for oauth" , actual_addr) ;
279310
280311 let redirect_uri = format ! ( "http://{}" , actual_addr) ;
281- oauth_state. start_authorization ( get_scopes ( ) , & redirect_uri) . await ?;
312+ let scopes_as_str = scopes. iter ( ) . map ( String :: as_str) . collect :: < Vec < _ > > ( ) ;
313+ let scopes_as_slice = scopes_as_str. as_slice ( ) ;
314+ oauth_state. start_authorization ( scopes_as_slice, & redirect_uri) . await ?;
282315
283316 let auth_url = oauth_state. get_authorization_url ( ) . await ?;
284317 _ = messenger. send_oauth_link ( auth_url) . await ;
@@ -333,9 +366,19 @@ async fn make_svc(
333366 let query = uri. query ( ) . unwrap_or ( "" ) ;
334367 let params: std:: collections:: HashMap < String , String > =
335368 url:: form_urlencoded:: parse ( query. as_bytes ( ) ) . into_owned ( ) . collect ( ) ;
369+ debug ! ( "## mcp: uri: {}, query: {}, params: {:?}" , uri, query, params) ;
336370
337371 let self_clone = self . clone ( ) ;
338372 Box :: pin ( async move {
373+ let error = params. get ( "error" ) ;
374+ let resp = if let Some ( err) = error {
375+ mk_response ( format ! (
376+ "Oauth failed. Check url for precise reasons. Possible reasons: {err}.\n If this is scope related. You can try configuring the server scopes to be an empty array via adding oauth_scopes: []"
377+ ) )
378+ } else {
379+ mk_response ( "You can close this page now" . to_string ( ) )
380+ } ;
381+
339382 let code = params. get ( "code" ) . cloned ( ) . unwrap_or_default ( ) ;
340383 if let Some ( sender) = self_clone
341384 . one_shot_sender
@@ -345,7 +388,8 @@ async fn make_svc(
345388 {
346389 sender. send ( code) . map_err ( LoopBackError :: Send ) ?;
347390 }
348- mk_response ( "You can close this page now" . to_string ( ) )
391+
392+ resp
349393 } )
350394 }
351395 }
0 commit comments