Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions crates/handlers/src/compat/logout_all.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::sync::LazyLock;

use axum::{Json, response::IntoResponse};
use axum_extra::typed_header::TypedHeader;
use headers::{Authorization, authorization::Bearer};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_data_model::TokenType;
use mas_storage::{
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository},
queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
};
use opentelemetry::{Key, KeyValue, metrics::Counter};
use serde::Deserialize;
use thiserror::Error;
use tracing::info;
use ulid::Ulid;

use super::{MatrixError, MatrixJsonBody};
use crate::{BoundActivityTracker, METER, impl_from_error_for_route};

static LOGOUT_ALL_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
METER
.u64_counter("mas.compat.logout_all_request")
.with_description(
"How many request to the /logout/all compatibility endpoint have happened",
)
.with_unit("{request}")
.build()
});
const RESULT: Key = Key::from_static_str("result");

#[derive(Error, Debug)]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),

#[error("Can't load session {0}")]
CantLoadSession(Ulid),

#[error("Can't load user {0}")]
CantLoadUser(Ulid),

#[error("Token {0} has expired")]
InvalidToken(Ulid),

#[error("Session {0} has been revoked")]
InvalidSession(Ulid),

#[error("User {0} is locked or deactivated")]
InvalidUser(Ulid),

#[error("/logout/all is not supported")]
NotSupported,

#[error("Missing access token")]
MissingAuthorization,

#[error("Invalid token format")]
TokenFormat(#[from] mas_data_model::TokenFormatError),

#[error("Access token is not a compatibility access token")]
NotACompatToken,
}

impl_from_error_for_route!(mas_storage::RepositoryError);

impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let sentry_event_id = record_error!(
self,
Self::Internal(_) | Self::CantLoadSession(_) | Self::CantLoadUser(_)
);

// We track separately if the endpoint was called without the custom
// parameter, so that we know if clients are using this endpoint in the
// wild
if matches!(self, Self::NotSupported) {
LOGOUT_ALL_COUNTER.add(1, &[KeyValue::new(RESULT, "not_supported")]);
} else {
LOGOUT_ALL_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
}

let response = match self {
Self::Internal(_) | Self::CantLoadSession(_) | Self::CantLoadUser(_) => MatrixError {
errcode: "M_UNKNOWN",
error: "Internal error",
status: StatusCode::INTERNAL_SERVER_ERROR,
},
Self::MissingAuthorization => MatrixError {
errcode: "M_MISSING_TOKEN",
error: "Missing access token",
status: StatusCode::UNAUTHORIZED,
},
Self::InvalidUser(_)
| Self::InvalidSession(_)
| Self::InvalidToken(_)
| Self::NotACompatToken
| Self::TokenFormat(_) => MatrixError {
errcode: "M_UNKNOWN_TOKEN",
error: "Invalid access token",
status: StatusCode::UNAUTHORIZED,
},
Self::NotSupported => MatrixError {
errcode: "M_UNRECOGNIZED",
error: "The /logout/all endpoint is not supported by this deployment",
status: StatusCode::NOT_FOUND,
},
};

(sentry_event_id, response).into_response()
}
}

#[derive(Deserialize, Default)]
pub(crate) struct RequestBody {
#[serde(rename = "io.element.only_compat_is_fine", default)]
only_compat_is_fine: bool,
}

#[tracing::instrument(name = "handlers.compat.logout_all.post", skip_all)]
pub(crate) async fn post(
clock: BoxClock,
mut rng: BoxRng,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
input: Option<MatrixJsonBody<RequestBody>>,
) -> Result<impl IntoResponse, RouteError> {
let MatrixJsonBody(input) = input.unwrap_or_default();
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;

let token = authorization.token();
let token_type = TokenType::check(token)?;

if token_type != TokenType::CompatAccessToken {
return Err(RouteError::NotACompatToken);
}

let token = repo
.compat_access_token()
.find_by_token(token)
.await?
.ok_or(RouteError::NotACompatToken)?;

if !token.is_valid(clock.now()) {
return Err(RouteError::InvalidToken(token.id));
}

let session = repo
.compat_session()
.lookup(token.session_id)
.await?
.ok_or(RouteError::CantLoadSession(token.session_id))?;

if !session.is_valid() {
return Err(RouteError::InvalidSession(session.id));
}

activity_tracker
.record_compat_session(&clock, &session)
.await;

let user = repo
.user()
.lookup(session.user_id)
.await?
.ok_or(RouteError::CantLoadUser(session.user_id))?;

if !user.is_valid() {
return Err(RouteError::InvalidUser(session.user_id));
}

if !input.only_compat_is_fine {
return Err(RouteError::NotSupported);
}

let filter = CompatSessionFilter::new().for_user(&user).active_only();
let affected_sessions = repo.compat_session().finish_bulk(&clock, filter).await?;
info!(
"Logged out {affected_sessions} sessions for user {user_id}",
user_id = user.id
);

// Schedule a job to sync the devices of the user with the homeserver
repo.queue_job()
.schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
.await?;

repo.save().await?;

LOGOUT_ALL_COUNTER.add(1, &[KeyValue::new(RESULT, "success")]);

Ok(Json(serde_json::json!({})))
}
29 changes: 28 additions & 1 deletion crates/handlers/src/compat/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
Expand All @@ -22,6 +22,7 @@ pub(crate) mod login;
pub(crate) mod login_sso_complete;
pub(crate) mod login_sso_redirect;
pub(crate) mod logout;
pub(crate) mod logout_all;
pub(crate) mod refresh;

#[derive(Debug, Serialize)]
Expand Down Expand Up @@ -140,3 +141,29 @@ where
Ok(Self(value))
}
}

impl<T, S> axum::extract::OptionalFromRequest<S> for MatrixJsonBody<T>
where
T: DeserializeOwned,
S: Send + Sync,
{
type Rejection = MatrixJsonBodyRejection;

async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
if req.headers().contains_key(header::CONTENT_TYPE) {
// If there is a Content-Type header, handle it as normal
let result = <Self as axum::extract::FromRequest<S>>::from_request(req, state).await?;
return Ok(Some(result));
}

// Else, we poke at the body, and deserialize it only if it's JSON
let bytes = <Bytes as axum::extract::FromRequest<S>>::from_request(req, state).await?;
if bytes.is_empty() {
return Ok(None);
}

let value: T = serde_json::from_slice(&bytes)?;

Ok(Some(Self(value)))
}
}
4 changes: 4 additions & 0 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ where
mas_router::CompatLogout::route(),
post(self::compat::logout::post),
)
.route(
mas_router::CompatLogoutAll::route(),
post(self::compat::logout_all::post),
)
.route(
mas_router::CompatRefresh::route(),
post(self::compat::refresh::post),
Expand Down
7 changes: 7 additions & 0 deletions crates/router/src/endpoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,13 @@ impl SimpleRoute for CompatLogout {
const PATH: &'static str = "/_matrix/client/{version}/logout";
}

/// `POST /_matrix/client/v3/logout/all`
pub struct CompatLogoutAll;

impl SimpleRoute for CompatLogoutAll {
const PATH: &'static str = "/_matrix/client/{version}/logout/all";
}

/// `POST /_matrix/client/v3/refresh`
pub struct CompatRefresh;

Expand Down
Loading