Skip to content

Commit e3dd047

Browse files
committed
Make the compat_router also recover from errors in human-facing routes
1 parent 8f6c854 commit e3dd047

File tree

3 files changed

+47
-29
lines changed

3 files changed

+47
-29
lines changed

crates/cli/src/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ pub fn build_router(
269269
}
270270
mas_config::HttpResource::OAuth => router.merge(mas_handlers::api_router::<AppState>()),
271271
mas_config::HttpResource::Compat => {
272-
router.merge(mas_handlers::compat_router::<AppState>())
272+
router.merge(mas_handlers::compat_router::<AppState>(templates.clone()))
273273
}
274274
mas_config::HttpResource::AdminApi => {
275275
let (_, api_router) = mas_handlers::admin_api_router::<AppState>();

crates/handlers/src/lib.rs

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ where
257257
}
258258

259259
#[allow(clippy::trait_duplication_in_bounds)]
260-
pub fn compat_router<S>() -> Router<S>
260+
pub fn compat_router<S>(templates: Templates) -> Router<S>
261261
where
262262
S: Clone + Send + Sync + 'static,
263263
UrlBuilder: FromRef<S>,
@@ -272,7 +272,28 @@ where
272272
BoxClock: FromRequestParts<S>,
273273
BoxRng: FromRequestParts<S>,
274274
{
275-
Router::new()
275+
// A sub-router for human-facing routes with error handling
276+
let human_router = Router::new()
277+
.route(
278+
mas_router::CompatLoginSsoRedirect::route(),
279+
get(self::compat::login_sso_redirect::get),
280+
)
281+
.route(
282+
mas_router::CompatLoginSsoRedirectIdp::route(),
283+
get(self::compat::login_sso_redirect::get),
284+
)
285+
.route(
286+
mas_router::CompatLoginSsoRedirectSlash::route(),
287+
get(self::compat::login_sso_redirect::get),
288+
)
289+
.layer(AndThenLayer::new(
290+
async move |response: axum::response::Response| {
291+
Ok::<_, Infallible>(recover_error(&templates, response))
292+
},
293+
));
294+
295+
// A sub-router for API-facing routes with CORS
296+
let api_router = Router::new()
276297
.route(
277298
mas_router::CompatLogin::route(),
278299
get(self::compat::login::get).post(self::compat::login::post),
@@ -289,18 +310,6 @@ where
289310
mas_router::CompatRefresh::route(),
290311
post(self::compat::refresh::post),
291312
)
292-
.route(
293-
mas_router::CompatLoginSsoRedirect::route(),
294-
get(self::compat::login_sso_redirect::get),
295-
)
296-
.route(
297-
mas_router::CompatLoginSsoRedirectIdp::route(),
298-
get(self::compat::login_sso_redirect::get),
299-
)
300-
.route(
301-
mas_router::CompatLoginSsoRedirectSlash::route(),
302-
get(self::compat::login_sso_redirect::get),
303-
)
304313
.layer(
305314
CorsLayer::new()
306315
.allow_origin(Any)
@@ -314,7 +323,9 @@ where
314323
HeaderName::from_static("x-requested-with"),
315324
])
316325
.max_age(Duration::from_secs(60 * 60)),
317-
)
326+
);
327+
328+
Router::new().merge(human_router).merge(api_router)
318329
}
319330

320331
#[allow(clippy::too_many_lines)]
@@ -454,22 +465,29 @@ where
454465
)
455466
.layer(AndThenLayer::new(
456467
async move |response: axum::response::Response| {
457-
// Error responses should have an ErrorContext attached to them
458-
let ext = response.extensions().get::<ErrorContext>();
459-
if let Some(ctx) = ext {
460-
if let Ok(res) = templates.render_error(ctx) {
461-
let (mut parts, _original_body) = response.into_parts();
462-
parts.headers.remove(CONTENT_TYPE);
463-
parts.headers.remove(CONTENT_LENGTH);
464-
return Ok((parts, Html(res)).into_response());
465-
}
466-
}
467-
468-
Ok::<_, Infallible>(response)
468+
Ok::<_, Infallible>(recover_error(&templates, response))
469469
},
470470
))
471471
}
472472

473+
fn recover_error(
474+
templates: &Templates,
475+
response: axum::response::Response,
476+
) -> axum::response::Response {
477+
// Error responses should have an ErrorContext attached to them
478+
let ext = response.extensions().get::<ErrorContext>();
479+
if let Some(ctx) = ext {
480+
if let Ok(res) = templates.render_error(ctx) {
481+
let (mut parts, _original_body) = response.into_parts();
482+
parts.headers.remove(CONTENT_TYPE);
483+
parts.headers.remove(CONTENT_LENGTH);
484+
return (parts, Html(res)).into_response();
485+
}
486+
}
487+
488+
response
489+
}
490+
473491
/// The fallback handler for all routes that don't match anything else.
474492
///
475493
/// # Errors

crates/handlers/src/test_utils.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ impl TestState {
319319
let app = crate::healthcheck_router()
320320
.merge(crate::discovery_router())
321321
.merge(crate::api_router())
322-
.merge(crate::compat_router())
322+
.merge(crate::compat_router(self.templates.clone()))
323323
.merge(crate::human_router(self.templates.clone()))
324324
// We enable undocumented_oauth2_access for the tests, as it is easier to query the API
325325
// with it

0 commit comments

Comments
 (0)