Skip to content

Commit 19ea6f3

Browse files
committed
Hacky support for the /logout/all compatibility endpoint
1 parent 104e8f3 commit 19ea6f3

File tree

4 files changed

+240
-1
lines changed

4 files changed

+240
-1
lines changed
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
// Copyright 2025 New Vector Ltd.
2+
//
3+
// SPDX-License-Identifier: AGPL-3.0-only
4+
// Please see LICENSE in the repository root for full details.
5+
6+
use std::sync::LazyLock;
7+
8+
use axum::{Json, response::IntoResponse};
9+
use axum_extra::typed_header::TypedHeader;
10+
use headers::{Authorization, authorization::Bearer};
11+
use hyper::StatusCode;
12+
use mas_axum_utils::record_error;
13+
use mas_data_model::TokenType;
14+
use mas_storage::{
15+
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
16+
compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository},
17+
queue::{QueueJobRepositoryExt as _, SyncDevicesJob},
18+
};
19+
use opentelemetry::{Key, KeyValue, metrics::Counter};
20+
use serde::Deserialize;
21+
use thiserror::Error;
22+
use tracing::info;
23+
use ulid::Ulid;
24+
25+
use super::{MatrixError, MatrixJsonBody};
26+
use crate::{BoundActivityTracker, METER, impl_from_error_for_route};
27+
28+
static LOGOUT_ALL_COUNTER: LazyLock<Counter<u64>> = LazyLock::new(|| {
29+
METER
30+
.u64_counter("mas.compat.logout_all_request")
31+
.with_description(
32+
"How many request to the /logout/all compatibility endpoint have happened",
33+
)
34+
.with_unit("{request}")
35+
.build()
36+
});
37+
const RESULT: Key = Key::from_static_str("result");
38+
39+
#[derive(Error, Debug)]
40+
pub enum RouteError {
41+
#[error(transparent)]
42+
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
43+
44+
#[error("Can't load session {0}")]
45+
CantLoadSession(Ulid),
46+
47+
#[error("Can't load user {0}")]
48+
CantLoadUser(Ulid),
49+
50+
#[error("Token {0} has expired")]
51+
InvalidToken(Ulid),
52+
53+
#[error("Session {0} has been revoked")]
54+
InvalidSession(Ulid),
55+
56+
#[error("User {0} is locked or deactivated")]
57+
InvalidUser(Ulid),
58+
59+
#[error("/logout/all is not supported")]
60+
NotSupported,
61+
62+
#[error("Missing access token")]
63+
MissingAuthorization,
64+
65+
#[error("Invalid token format")]
66+
TokenFormat(#[from] mas_data_model::TokenFormatError),
67+
68+
#[error("Access token is not a compatibility access token")]
69+
NotACompatToken,
70+
}
71+
72+
impl_from_error_for_route!(mas_storage::RepositoryError);
73+
74+
impl IntoResponse for RouteError {
75+
fn into_response(self) -> axum::response::Response {
76+
let sentry_event_id = record_error!(
77+
self,
78+
Self::Internal(_) | Self::CantLoadSession(_) | Self::CantLoadUser(_)
79+
);
80+
81+
// We track separately if the endpoint was called without the custom
82+
// parameter, so that we know if clients are using this endpoint in the
83+
// wild
84+
if matches!(self, Self::NotSupported) {
85+
LOGOUT_ALL_COUNTER.add(1, &[KeyValue::new(RESULT, "not_supported")]);
86+
} else {
87+
LOGOUT_ALL_COUNTER.add(1, &[KeyValue::new(RESULT, "error")]);
88+
}
89+
90+
let response = match self {
91+
Self::Internal(_) | Self::CantLoadSession(_) | Self::CantLoadUser(_) => MatrixError {
92+
errcode: "M_UNKNOWN",
93+
error: "Internal error",
94+
status: StatusCode::INTERNAL_SERVER_ERROR,
95+
},
96+
Self::MissingAuthorization => MatrixError {
97+
errcode: "M_MISSING_TOKEN",
98+
error: "Missing access token",
99+
status: StatusCode::UNAUTHORIZED,
100+
},
101+
Self::InvalidUser(_)
102+
| Self::InvalidSession(_)
103+
| Self::InvalidToken(_)
104+
| Self::NotACompatToken
105+
| Self::TokenFormat(_) => MatrixError {
106+
errcode: "M_UNKNOWN_TOKEN",
107+
error: "Invalid access token",
108+
status: StatusCode::UNAUTHORIZED,
109+
},
110+
Self::NotSupported => MatrixError {
111+
errcode: "M_UNRECOGNIZED",
112+
error: "The /logout/all endpoint is not supported by this deployment",
113+
status: StatusCode::NOT_FOUND,
114+
},
115+
};
116+
117+
(sentry_event_id, response).into_response()
118+
}
119+
}
120+
121+
#[derive(Deserialize, Default)]
122+
pub(crate) struct RequestBody {
123+
#[serde(rename = "io.element.only_compat_is_fine", default)]
124+
only_compat_is_fine: bool,
125+
}
126+
127+
#[tracing::instrument(name = "handlers.compat.logout_all.post", skip_all)]
128+
pub(crate) async fn post(
129+
clock: BoxClock,
130+
mut rng: BoxRng,
131+
mut repo: BoxRepository,
132+
activity_tracker: BoundActivityTracker,
133+
maybe_authorization: Option<TypedHeader<Authorization<Bearer>>>,
134+
input: Option<MatrixJsonBody<RequestBody>>,
135+
) -> Result<impl IntoResponse, RouteError> {
136+
let MatrixJsonBody(input) = input.unwrap_or_default();
137+
let TypedHeader(authorization) = maybe_authorization.ok_or(RouteError::MissingAuthorization)?;
138+
139+
let token = authorization.token();
140+
let token_type = TokenType::check(token)?;
141+
142+
if token_type != TokenType::CompatAccessToken {
143+
return Err(RouteError::NotACompatToken);
144+
}
145+
146+
let token = repo
147+
.compat_access_token()
148+
.find_by_token(token)
149+
.await?
150+
.ok_or(RouteError::NotACompatToken)?;
151+
152+
if !token.is_valid(clock.now()) {
153+
return Err(RouteError::InvalidToken(token.id));
154+
}
155+
156+
let session = repo
157+
.compat_session()
158+
.lookup(token.session_id)
159+
.await?
160+
.ok_or(RouteError::CantLoadSession(token.session_id))?;
161+
162+
if !session.is_valid() {
163+
return Err(RouteError::InvalidSession(session.id));
164+
}
165+
166+
activity_tracker
167+
.record_compat_session(&clock, &session)
168+
.await;
169+
170+
let user = repo
171+
.user()
172+
.lookup(session.user_id)
173+
.await?
174+
.ok_or(RouteError::CantLoadUser(session.user_id))?;
175+
176+
if !user.is_valid() {
177+
return Err(RouteError::InvalidUser(session.user_id));
178+
}
179+
180+
if !input.only_compat_is_fine {
181+
return Err(RouteError::NotSupported);
182+
}
183+
184+
let filter = CompatSessionFilter::new().for_user(&user).active_only();
185+
let affected_sessions = repo.compat_session().finish_bulk(&clock, filter).await?;
186+
info!(
187+
"Logged out {affected_sessions} sessions for user {user_id}",
188+
user_id = user.id
189+
);
190+
191+
// Schedule a job to sync the devices of the user with the homeserver
192+
repo.queue_job()
193+
.schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
194+
.await?;
195+
196+
repo.save().await?;
197+
198+
LOGOUT_ALL_COUNTER.add(1, &[KeyValue::new(RESULT, "success")]);
199+
200+
Ok(Json(serde_json::json!({})))
201+
}

crates/handlers/src/compat/mod.rs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 New Vector Ltd.
1+
// Copyright 2024, 2025 New Vector Ltd.
22
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
33
//
44
// SPDX-License-Identifier: AGPL-3.0-only
@@ -22,6 +22,7 @@ pub(crate) mod login;
2222
pub(crate) mod login_sso_complete;
2323
pub(crate) mod login_sso_redirect;
2424
pub(crate) mod logout;
25+
pub(crate) mod logout_all;
2526
pub(crate) mod refresh;
2627

2728
#[derive(Debug, Serialize)]
@@ -140,3 +141,29 @@ where
140141
Ok(Self(value))
141142
}
142143
}
144+
145+
impl<T, S> axum::extract::OptionalFromRequest<S> for MatrixJsonBody<T>
146+
where
147+
T: DeserializeOwned,
148+
S: Send + Sync,
149+
{
150+
type Rejection = MatrixJsonBodyRejection;
151+
152+
async fn from_request(req: Request, state: &S) -> Result<Option<Self>, Self::Rejection> {
153+
if req.headers().contains_key(header::CONTENT_TYPE) {
154+
// If there is a Content-Type header, handle it as normal
155+
let result = <Self as axum::extract::FromRequest<S>>::from_request(req, state).await?;
156+
return Ok(Some(result));
157+
}
158+
159+
// Else, we poke at the body, and deserialize it only if it's JSON
160+
let bytes = <Bytes as axum::extract::FromRequest<S>>::from_request(req, state).await?;
161+
if bytes.is_empty() {
162+
return Ok(None);
163+
}
164+
165+
let value: T = serde_json::from_slice(&bytes)?;
166+
167+
Ok(Some(Self(value)))
168+
}
169+
}

crates/handlers/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ where
278278
mas_router::CompatLogout::route(),
279279
post(self::compat::logout::post),
280280
)
281+
.route(
282+
mas_router::CompatLogoutAll::route(),
283+
post(self::compat::logout_all::post),
284+
)
281285
.route(
282286
mas_router::CompatRefresh::route(),
283287
post(self::compat::refresh::post),

crates/router/src/endpoints.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,13 @@ impl SimpleRoute for CompatLogout {
548548
const PATH: &'static str = "/_matrix/client/{version}/logout";
549549
}
550550

551+
/// `POST /_matrix/client/v3/logout/all`
552+
pub struct CompatLogoutAll;
553+
554+
impl SimpleRoute for CompatLogoutAll {
555+
const PATH: &'static str = "/_matrix/client/{version}/logout/all";
556+
}
557+
551558
/// `POST /_matrix/client/v3/refresh`
552559
pub struct CompatRefresh;
553560

0 commit comments

Comments
 (0)