Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions crates/axum-utils/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// Please see LICENSE in the repository root for full details.

use mas_data_model::BrowserSession;
use mas_storage::{RepositoryAccess, user::BrowserSessionRepository};
use mas_storage::RepositoryAccess;
use serde::{Deserialize, Serialize};
use ulid::Ulid;

Expand Down Expand Up @@ -33,13 +33,12 @@ impl SessionInfo {
self
}

/// Load the [`BrowserSession`] from database
/// Load the active [`BrowserSession`] from database
///
/// # Errors
///
/// Returns an error if the session is not found or if the session is not
/// active anymore
pub async fn load_session<E>(
/// Returns an error if the underlying repository fails to load the session.
pub async fn load_active_session<E>(
&self,
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Option<BrowserSession>, E> {
Expand All @@ -56,6 +55,12 @@ impl SessionInfo {

Ok(maybe_session)
}

/// Get the current session ID, if any
#[must_use]
pub fn current_session_id(&self) -> Option<Ulid> {
self.current
}
}

pub trait SessionInfoExt {
Expand Down
6 changes: 4 additions & 2 deletions crates/data-model/src/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@ pub struct User {
pub sub: String,
pub created_at: DateTime<Utc>,
pub locked_at: Option<DateTime<Utc>>,
pub deactivated_at: Option<DateTime<Utc>>,
pub can_request_admin: bool,
}

impl User {
/// Returns `true` unless the user is locked.
/// Returns `true` unless the user is locked or deactivated.
#[must_use]
pub fn is_valid(&self) -> bool {
self.locked_at.is_none()
self.locked_at.is_none() && self.deactivated_at.is_none()
}
}

Expand All @@ -42,6 +43,7 @@ impl User {
sub: "123-456".to_owned(),
created_at: now,
locked_at: None,
deactivated_at: None,
can_request_admin: false,
}]
}
Expand Down
39 changes: 31 additions & 8 deletions crates/handlers/src/compat/login_sso_complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use axum::{
};
use chrono::Duration;
use mas_axum_utils::{
FancyError, SessionInfoExt,
FancyError,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
Expand All @@ -28,7 +28,10 @@ use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize};
use ulid::Ulid;

use crate::PreferredLanguage;
use crate::{
PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
};

#[derive(Serialize)]
struct AllParams<'s> {
Expand Down Expand Up @@ -61,10 +64,20 @@ pub async fn get(
Path(id): Path<Ulid>,
Query(params): Query<Params>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};

let maybe_session = session_info.load_session(&mut repo).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);

let Some(session) = maybe_session else {
// If there is no session, redirect to the login or register screen
Expand Down Expand Up @@ -126,10 +139,20 @@ pub async fn post(
Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
cookie_jar.verify_form(&clock, form)?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};

let maybe_session = session_info.load_session(&mut repo).await?;
cookie_jar.verify_form(&clock, form)?;

let Some(session) = maybe_session else {
// If there is no session, redirect to the login or register screen
Expand Down
2 changes: 1 addition & 1 deletion crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ async fn get_requester(

RequestingEntity::OAuth2Session(Box::new((session, user)))
} else {
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;

if let Some(session) = maybe_session.as_ref() {
activity_tracker
Expand Down
1 change: 1 addition & 0 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ mod activity_tracker;
mod captcha;
mod preferred_language;
mod rate_limit;
mod session;
#[cfg(test)]
mod test_utils;

Expand Down
2 changes: 1 addition & 1 deletion crates/handlers/src/oauth2/authorization/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pub(crate) async fn get(
) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info();

let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;

let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());

Expand Down
2 changes: 1 addition & 1 deletion crates/handlers/src/oauth2/authorization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ pub(crate) async fn get(
let callback_destination = callback_destination.clone();
let locale = locale.clone();
async move {
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default();

// Check if the request/request_uri/registration params are used. If so, reply
Expand Down
123 changes: 73 additions & 50 deletions crates/handlers/src/oauth2/consent.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
Expand All @@ -11,7 +11,6 @@ use axum::{
use axum_extra::TypedHeader;
use hyper::StatusCode;
use mas_axum_utils::{
SessionInfoExt,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
sentry::SentryEventID,
Expand All @@ -27,7 +26,10 @@ use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Tem
use thiserror::Error;
use ulid::Ulid;

use crate::{BoundActivityTracker, PreferredLanguage, impl_from_error_for_route};
use crate::{
BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
session::{SessionOrFallback, load_session_or_fallback},
};

#[derive(Debug, Error)]
pub enum RouteError {
Expand All @@ -54,6 +56,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(crate::session::SessionLoadError);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
Expand Down Expand Up @@ -85,9 +88,18 @@ pub(crate) async fn get(
cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info();

let maybe_session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};

let user_agent = user_agent.map(|ua| ua.to_string());

Expand All @@ -107,48 +119,48 @@ pub(crate) async fn get(
return Err(RouteError::GrantNotPending);
}

if let Some(session) = maybe_session {
activity_tracker
.record_browser_session(&clock, &session)
.await;

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);

let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;

if res.valid() {
let ctx = ConsentContext::new(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);

let content = templates.render_consent(&ctx)?;

Ok((cookie_jar, Html(content)).into_response())
} else {
let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);

let content = templates.render_policy_violation(&ctx)?;

Ok((cookie_jar, Html(content)).into_response())
}
} else {
let Some(session) = maybe_session else {
let login = mas_router::Login::and_continue_grant(grant_id);
Ok((cookie_jar, url_builder.redirect(&login)).into_response())
return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
};

activity_tracker
.record_browser_session(&clock, &session)
.await;

let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);

let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;

if res.valid() {
let ctx = ConsentContext::new(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);

let content = templates.render_consent(&ctx)?;

Ok((cookie_jar, Html(content)).into_response())
} else {
let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);

let content = templates.render_policy_violation(&ctx)?;

Ok((cookie_jar, Html(content)).into_response())
}
}

Expand All @@ -161,6 +173,8 @@ pub(crate) async fn get(
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
Expand All @@ -172,9 +186,18 @@ pub(crate) async fn post(
) -> Result<Response, RouteError> {
cookie_jar.verify_form(&clock, form)?;

let (session_info, cookie_jar) = cookie_jar.session_info();

let maybe_session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};

let user_agent = user_agent.map(|ua| ua.to_string());

Expand Down
Loading
Loading