Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/axum-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ publish = false
workspace = true

[dependencies]
anyhow.workspace = true
axum.workspace = true
axum-extra.workspace = true
base64ct.workspace = true
Expand Down
12 changes: 2 additions & 10 deletions crates/axum-utils/src/error_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
// Please see LICENSE in the repository root for full details.

use axum::response::{IntoResponse, Response};
use http::StatusCode;

use crate::record_error;
use crate::InternalError;

/// A simple wrapper around an error that implements [`IntoResponse`].
#[derive(Debug, thiserror::Error)]
Expand All @@ -19,13 +18,6 @@ where
T: std::error::Error + 'static,
{
fn into_response(self) -> Response {
// TODO: make this a bit more user friendly
let sentry_event_id = record_error!(self.0);
(
StatusCode::INTERNAL_SERVER_ERROR,
sentry_event_id,
self.0.to_string(),
)
.into_response()
InternalError::from(self.0).into_response()
}
}
98 changes: 67 additions & 31 deletions crates/axum-utils/src/fancy_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,91 @@ use mas_templates::ErrorContext;

use crate::sentry::SentryEventID;

pub struct FancyError {
context: ErrorContext,
fn build_context(mut err: &dyn std::error::Error) -> ErrorContext {
let description = err.to_string();
let mut details = Vec::new();
while let Some(source) = err.source() {
err = source;
details.push(err.to_string());
}

ErrorContext::new()
.with_description(description)
.with_details(details.join("\n"))
}

impl FancyError {
#[must_use]
pub fn new(context: ErrorContext) -> Self {
Self { context }
pub struct GenericError {
error: Box<dyn std::error::Error + 'static>,
code: StatusCode,
}

impl IntoResponse for GenericError {
fn into_response(self) -> Response {
tracing::warn!(message = &*self.error);
let context = build_context(&*self.error);
let context_text = format!("{context}");

(
self.code,
TypedHeader(ContentType::text()),
Extension(context),
context_text,
)
.into_response()
}
}

impl std::fmt::Display for FancyError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let code = self.context.code().unwrap_or("Internal error");
match (self.context.description(), self.context.details()) {
(Some(description), Some(details)) => {
write!(f, "{code}: {description} ({details})")
}
(Some(message), None) | (None, Some(message)) => {
write!(f, "{code}: {message}")
}
(None, None) => {
write!(f, "{code}")
}
impl GenericError {
pub fn new(code: StatusCode, err: impl std::error::Error + 'static) -> Self {
Self {
error: Box::new(err),
code,
}
}
}

impl<E: std::fmt::Debug + std::fmt::Display> From<E> for FancyError {
fn from(err: E) -> Self {
let context = ErrorContext::new()
.with_description(format!("{err}"))
.with_details(format!("{err:?}"));
FancyError { context }
}
pub struct InternalError {
error: Box<dyn std::error::Error + 'static>,
}

impl IntoResponse for FancyError {
impl IntoResponse for InternalError {
fn into_response(self) -> Response {
tracing::error!(message = %self.context);
let error = format!("{}", self.context);
tracing::error!(message = &*self.error);
let event_id = SentryEventID::for_last_event();
let context = build_context(&*self.error);
let context_text = format!("{context}");

(
StatusCode::INTERNAL_SERVER_ERROR,
TypedHeader(ContentType::text()),
event_id,
Extension(self.context),
error,
Extension(context),
context_text,
)
.into_response()
}
}

impl<E: std::error::Error + 'static> From<E> for InternalError {
fn from(err: E) -> Self {
Self {
error: Box::new(err),
}
}
}

impl InternalError {
/// Create a new error from a boxed error
#[must_use]
pub fn new(error: Box<dyn std::error::Error + 'static>) -> Self {
Self { error }
}

/// Create a new error from an [`anyhow::Error`]
#[must_use]
pub fn from_anyhow(err: anyhow::Error) -> Self {
Self {
error: err.into_boxed_dyn_error(),
}
}
}
2 changes: 1 addition & 1 deletion crates/axum-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ pub use axum;

pub use self::{
error_wrapper::ErrorWrapper,
fancy_error::FancyError,
fancy_error::{GenericError, InternalError},
session::{SessionInfo, SessionInfoExt},
};
6 changes: 3 additions & 3 deletions crates/handlers/src/admin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use axum::{
};
use hyper::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use indexmap::IndexMap;
use mas_axum_utils::FancyError;
use mas_axum_utils::InternalError;
use mas_http::CorsLayerExt;
use mas_matrix::HomeserverConnection;
use mas_policy::PolicyFactory;
Expand Down Expand Up @@ -180,7 +180,7 @@ where
async fn swagger(
State(url_builder): State<UrlBuilder>,
State(templates): State<Templates>,
) -> Result<Html<String>, FancyError> {
) -> Result<Html<String>, InternalError> {
let ctx = ApiDocContext::from_url_builder(&url_builder);
let res = templates.render_swagger(&ctx)?;
Ok(Html(res))
Expand All @@ -189,7 +189,7 @@ async fn swagger(
async fn swagger_callback(
State(url_builder): State<UrlBuilder>,
State(templates): State<Templates>,
) -> Result<Html<String>, FancyError> {
) -> Result<Html<String>, InternalError> {
let ctx = ApiDocContext::from_url_builder(&url_builder);
let res = templates.render_swagger_callback(&ctx)?;
Ok(Html(res))
Expand Down
12 changes: 7 additions & 5 deletions crates/handlers/src/compat/login_sso_complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use axum::{
};
use chrono::Duration;
use mas_axum_utils::{
FancyError,
InternalError,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
Expand Down Expand Up @@ -59,7 +59,7 @@ pub async fn get(
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
) -> Result<Response, FancyError> {
) -> Result<Response, InternalError> {
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
Expand Down Expand Up @@ -93,7 +93,8 @@ pub async fn get(
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
.context("Could not find compat SSO login")
.map_err(InternalError::from_anyhow)?;

// Bail out if that login session is more than 30min old
if clock.now() > login.created_at + Duration::microseconds(30 * 60 * 1000 * 1000) {
Expand Down Expand Up @@ -132,7 +133,7 @@ pub async fn post(
Path(id): Path<Ulid>,
Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> {
) -> Result<Response, InternalError> {
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
Expand Down Expand Up @@ -166,7 +167,8 @@ pub async fn post(
.compat_sso_login()
.lookup(id)
.await?
.context("Could not find compat SSO login")?;
.context("Could not find compat SSO login")
.map_err(InternalError::from_anyhow)?;

// Bail out if that login session isn't pending, or is more than 30min old
if !login.is_pending()
Expand Down
4 changes: 2 additions & 2 deletions crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use futures_util::TryStreamExt;
use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{
FancyError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
InternalError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
};
use mas_data_model::{BrowserSession, Session, SiteConfig, User};
use mas_matrix::HomeserverConnection;
Expand Down Expand Up @@ -383,7 +383,7 @@ pub async fn get(
authorization: Option<TypedHeader<Authorization<Bearer>>>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
) -> Result<impl IntoResponse, InternalError> {
let token = authorization
.as_ref()
.map(|TypedHeader(Authorization(bearer))| bearer.token());
Expand Down
4 changes: 2 additions & 2 deletions crates/handlers/src/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
// Please see LICENSE in the repository root for full details.

use axum::{extract::State, response::IntoResponse};
use mas_axum_utils::FancyError;
use mas_axum_utils::InternalError;
use sqlx::PgPool;
use tracing::{Instrument, info_span};

pub async fn get(State(pool): State<PgPool>) -> Result<impl IntoResponse, FancyError> {
pub async fn get(State(pool): State<PgPool>) -> Result<impl IntoResponse, InternalError> {
let mut conn = pool.acquire().await?;

sqlx::query("SELECT $1")
Expand Down
22 changes: 10 additions & 12 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use hyper::{
ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
},
};
use mas_axum_utils::{FancyError, cookies::CookieJar};
use mas_axum_utils::{InternalError, cookies::CookieJar};
use mas_data_model::SiteConfig;
use mas_http::CorsLayerExt;
use mas_keystore::{Encrypter, Keystore};
Expand Down Expand Up @@ -437,16 +437,14 @@ where
)
.layer(AndThenLayer::new(
async move |response: axum::response::Response| {
if response.status().is_server_error() {
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx) {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
parts.headers.remove(CONTENT_LENGTH);
return Ok((parts, Html(res)).into_response());
}
// Error responses should have an ErrorContext attached to them
let ext = response.extensions().get::<ErrorContext>();
if let Some(ctx) = ext {
if let Ok(res) = templates.render_error(ctx) {
let (mut parts, _original_body) = response.into_parts();
parts.headers.remove(CONTENT_TYPE);
parts.headers.remove(CONTENT_LENGTH);
return Ok((parts, Html(res)).into_response());
}
}

Expand All @@ -466,7 +464,7 @@ pub async fn fallback(
method: Method,
version: Version,
PreferredLanguage(locale): PreferredLanguage,
) -> Result<impl IntoResponse, FancyError> {
) -> Result<impl IntoResponse, InternalError> {
let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
// XXX: this should look at the Accept header and return JSON if requested

Expand Down
Loading
Loading