@@ -257,7 +257,7 @@ where
257
257
}
258
258
259
259
#[ allow( clippy:: trait_duplication_in_bounds) ]
260
- pub fn compat_router < S > ( ) -> Router < S >
260
+ pub fn compat_router < S > ( templates : Templates ) -> Router < S >
261
261
where
262
262
S : Clone + Send + Sync + ' static ,
263
263
UrlBuilder : FromRef < S > ,
@@ -272,7 +272,28 @@ where
272
272
BoxClock : FromRequestParts < S > ,
273
273
BoxRng : FromRequestParts < S > ,
274
274
{
275
- Router :: new ( )
275
+ // A sub-router for human-facing routes with error handling
276
+ let human_router = Router :: new ( )
277
+ . route (
278
+ mas_router:: CompatLoginSsoRedirect :: route ( ) ,
279
+ get ( self :: compat:: login_sso_redirect:: get) ,
280
+ )
281
+ . route (
282
+ mas_router:: CompatLoginSsoRedirectIdp :: route ( ) ,
283
+ get ( self :: compat:: login_sso_redirect:: get) ,
284
+ )
285
+ . route (
286
+ mas_router:: CompatLoginSsoRedirectSlash :: route ( ) ,
287
+ get ( self :: compat:: login_sso_redirect:: get) ,
288
+ )
289
+ . layer ( AndThenLayer :: new (
290
+ async move |response : axum:: response:: Response | {
291
+ Ok :: < _ , Infallible > ( recover_error ( & templates, response) )
292
+ } ,
293
+ ) ) ;
294
+
295
+ // A sub-router for API-facing routes with CORS
296
+ let api_router = Router :: new ( )
276
297
. route (
277
298
mas_router:: CompatLogin :: route ( ) ,
278
299
get ( self :: compat:: login:: get) . post ( self :: compat:: login:: post) ,
@@ -289,18 +310,6 @@ where
289
310
mas_router:: CompatRefresh :: route ( ) ,
290
311
post ( self :: compat:: refresh:: post) ,
291
312
)
292
- . route (
293
- mas_router:: CompatLoginSsoRedirect :: route ( ) ,
294
- get ( self :: compat:: login_sso_redirect:: get) ,
295
- )
296
- . route (
297
- mas_router:: CompatLoginSsoRedirectIdp :: route ( ) ,
298
- get ( self :: compat:: login_sso_redirect:: get) ,
299
- )
300
- . route (
301
- mas_router:: CompatLoginSsoRedirectSlash :: route ( ) ,
302
- get ( self :: compat:: login_sso_redirect:: get) ,
303
- )
304
313
. layer (
305
314
CorsLayer :: new ( )
306
315
. allow_origin ( Any )
@@ -314,7 +323,9 @@ where
314
323
HeaderName :: from_static ( "x-requested-with" ) ,
315
324
] )
316
325
. max_age ( Duration :: from_secs ( 60 * 60 ) ) ,
317
- )
326
+ ) ;
327
+
328
+ Router :: new ( ) . merge ( human_router) . merge ( api_router)
318
329
}
319
330
320
331
#[ allow( clippy:: too_many_lines) ]
@@ -454,22 +465,29 @@ where
454
465
)
455
466
. layer ( AndThenLayer :: new (
456
467
async move |response : axum:: response:: Response | {
457
- // Error responses should have an ErrorContext attached to them
458
- let ext = response. extensions ( ) . get :: < ErrorContext > ( ) ;
459
- if let Some ( ctx) = ext {
460
- if let Ok ( res) = templates. render_error ( ctx) {
461
- let ( mut parts, _original_body) = response. into_parts ( ) ;
462
- parts. headers . remove ( CONTENT_TYPE ) ;
463
- parts. headers . remove ( CONTENT_LENGTH ) ;
464
- return Ok ( ( parts, Html ( res) ) . into_response ( ) ) ;
465
- }
466
- }
467
-
468
- Ok :: < _ , Infallible > ( response)
468
+ Ok :: < _ , Infallible > ( recover_error ( & templates, response) )
469
469
} ,
470
470
) )
471
471
}
472
472
473
+ fn recover_error (
474
+ templates : & Templates ,
475
+ response : axum:: response:: Response ,
476
+ ) -> axum:: response:: Response {
477
+ // Error responses should have an ErrorContext attached to them
478
+ let ext = response. extensions ( ) . get :: < ErrorContext > ( ) ;
479
+ if let Some ( ctx) = ext {
480
+ if let Ok ( res) = templates. render_error ( ctx) {
481
+ let ( mut parts, _original_body) = response. into_parts ( ) ;
482
+ parts. headers . remove ( CONTENT_TYPE ) ;
483
+ parts. headers . remove ( CONTENT_LENGTH ) ;
484
+ return ( parts, Html ( res) ) . into_response ( ) ;
485
+ }
486
+ }
487
+
488
+ response
489
+ }
490
+
473
491
/// The fallback handler for all routes that don't match anything else.
474
492
///
475
493
/// # Errors
0 commit comments