Skip to content

Commit 0629915

Browse files
authored
Allow the homeserver to perform introspection using a shared secret (#4808)
2 parents cfefd24 + 85be5e1 commit 0629915

File tree

7 files changed

+207
-47
lines changed

7 files changed

+207
-47
lines changed

crates/axum-utils/src/client_authorization.rs

Lines changed: 86 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@ use std::collections::HashMap;
99
use axum::{
1010
BoxError, Json,
1111
extract::{
12-
Form, FromRequest, FromRequestParts,
12+
Form, FromRequest,
1313
rejection::{FailedToDeserializeForm, FormRejection},
1414
},
1515
response::IntoResponse,
1616
};
17-
use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
18-
use headers::{Authorization, authorization::Basic};
17+
use headers::authorization::{Basic, Bearer, Credentials as _};
1918
use http::{Request, StatusCode};
2019
use mas_data_model::{Client, JwksOrJwksUri};
2120
use mas_http::RequestBuilderExt;
@@ -60,17 +59,30 @@ pub enum Credentials {
6059
client_id: String,
6160
jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
6261
},
62+
BearerToken {
63+
token: String,
64+
},
6365
}
6466

6567
impl Credentials {
6668
/// Get the `client_id` of the credentials
6769
#[must_use]
68-
pub fn client_id(&self) -> &str {
70+
pub fn client_id(&self) -> Option<&str> {
6971
match self {
7072
Credentials::None { client_id }
7173
| Credentials::ClientSecretBasic { client_id, .. }
7274
| Credentials::ClientSecretPost { client_id, .. }
73-
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
75+
| Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
76+
Credentials::BearerToken { .. } => None,
77+
}
78+
}
79+
80+
/// Get the bearer token from the credentials.
81+
#[must_use]
82+
pub fn bearer_token(&self) -> Option<&str> {
83+
match self {
84+
Credentials::BearerToken { token } => Some(token),
85+
_ => None,
7486
}
7587
}
7688

@@ -89,6 +101,7 @@ impl Credentials {
89101
| Credentials::ClientSecretBasic { client_id, .. }
90102
| Credentials::ClientSecretPost { client_id, .. }
91103
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
104+
Credentials::BearerToken { .. } => return Ok(None),
92105
};
93106

94107
repo.oauth2_client().find_by_client_id(client_id).await
@@ -239,7 +252,7 @@ pub struct ClientAuthorization<F = ()> {
239252
impl<F> ClientAuthorization<F> {
240253
/// Get the `client_id` from the credentials.
241254
#[must_use]
242-
pub fn client_id(&self) -> &str {
255+
pub fn client_id(&self) -> Option<&str> {
243256
self.credentials.client_id()
244257
}
245258
}
@@ -360,25 +373,36 @@ where
360373
req: Request<axum::body::Body>,
361374
state: &S,
362375
) -> Result<Self, Self::Rejection> {
363-
// Split the request into parts so we can extract some headers
364-
let (mut parts, body) = req.into_parts();
365-
366-
let header =
367-
TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
368-
369-
// Take the Authorization header
370-
let credentials_from_header = match header {
371-
Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
372-
Err(err) => match err.reason() {
373-
// If it's missing it is fine
374-
TypedHeaderRejectionReason::Missing => None,
375-
// If the header could not be parsed, return the error
376-
_ => return Err(ClientAuthorizationError::InvalidHeader),
377-
},
378-
};
376+
enum Authorization {
377+
Basic(String, String),
378+
Bearer(String),
379+
}
380+
381+
// Sadly, the typed-header 'Authorization' doesn't let us check for both
382+
// Basic and Bearer at the same time, so we need to parse them manually
383+
let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
384+
let bytes = header.as_bytes();
385+
if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
386+
let Some(decoded) = Basic::decode(header) else {
387+
return Err(ClientAuthorizationError::InvalidHeader);
388+
};
379389

380-
// Reconstruct the request from the parts
381-
let req = Request::from_parts(parts, body);
390+
Some(Authorization::Basic(
391+
decoded.username().to_owned(),
392+
decoded.password().to_owned(),
393+
))
394+
} else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
395+
let Some(decoded) = Bearer::decode(header) else {
396+
return Err(ClientAuthorizationError::InvalidHeader);
397+
};
398+
399+
Some(Authorization::Bearer(decoded.token().to_owned()))
400+
} else {
401+
return Err(ClientAuthorizationError::InvalidHeader);
402+
}
403+
} else {
404+
None
405+
};
382406

383407
// Take the form value
384408
let (
@@ -407,13 +431,19 @@ where
407431

408432
// And now, figure out the actual auth method
409433
let credentials = match (
410-
credentials_from_header,
434+
authorization,
411435
client_id_from_form,
412436
client_secret_from_form,
413437
client_assertion_type,
414438
client_assertion,
415439
) {
416-
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
440+
(
441+
Some(Authorization::Basic(client_id, client_secret)),
442+
client_id_from_form,
443+
None,
444+
None,
445+
None,
446+
) => {
417447
if let Some(client_id_from_form) = client_id_from_form {
418448
// If the client_id was in the body, verify it matches with the header
419449
if client_id != client_id_from_form {
@@ -483,6 +513,11 @@ where
483513
});
484514
}
485515

516+
(Some(Authorization::Bearer(token)), None, None, None, None) => {
517+
// Got a bearer token
518+
Credentials::BearerToken { token }
519+
}
520+
486521
(None, None, None, None, None) => {
487522
// Special case when there are no credentials anywhere
488523
return Err(ClientAuthorizationError::MissingCredentials);
@@ -677,4 +712,29 @@ mod tests {
677712
jwt.verify_with_shared_secret(b"client-secret".to_vec())
678713
.unwrap();
679714
}
715+
716+
#[tokio::test]
717+
async fn bearer_token_test() {
718+
let req = Request::builder()
719+
.method(Method::POST)
720+
.header(
721+
http::header::CONTENT_TYPE,
722+
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
723+
)
724+
.header(http::header::AUTHORIZATION, "Bearer token")
725+
.body(Body::new("foo=bar".to_owned()))
726+
.unwrap();
727+
728+
assert_eq!(
729+
ClientAuthorization::<serde_json::Value>::from_request(req, &())
730+
.await
731+
.unwrap(),
732+
ClientAuthorization {
733+
credentials: Credentials::BearerToken {
734+
token: "token".to_owned(),
735+
},
736+
form: Some(serde_json::json!({"foo": "bar"})),
737+
}
738+
);
739+
}
680740
}

crates/handlers/src/oauth2/introspection.rs

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
55
// Please see LICENSE files in the repository root for full details.
66

7-
use std::sync::LazyLock;
7+
use std::sync::{Arc, LazyLock};
88

99
use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
1010
use hyper::{HeaderMap, StatusCode};
@@ -15,6 +15,7 @@ use mas_axum_utils::{
1515
use mas_data_model::{Device, TokenFormatError, TokenType};
1616
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
1717
use mas_keystore::Encrypter;
18+
use mas_matrix::HomeserverConnection;
1819
use mas_storage::{
1920
BoxClock, BoxRepository, Clock,
2021
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
@@ -102,8 +103,14 @@ pub enum RouteError {
102103
#[error("bad request")]
103104
BadRequest,
104105

106+
#[error("failed to verify token")]
107+
FailedToVerifyToken(#[source] anyhow::Error),
108+
105109
#[error(transparent)]
106110
ClientCredentialsVerification(#[from] CredentialsVerificationError),
111+
112+
#[error("bearer token presented is invalid")]
113+
InvalidBearerToken,
107114
}
108115

109116
impl IntoResponse for RouteError {
@@ -114,13 +121,15 @@ impl IntoResponse for RouteError {
114121
| Self::CantLoadCompatSession(_)
115122
| Self::CantLoadOAuthSession(_)
116123
| Self::CantLoadUser(_)
124+
| Self::FailedToVerifyToken(_)
117125
);
118126

119127
let response = match self {
120128
e @ (Self::Internal(_)
121129
| Self::CantLoadCompatSession(_)
122130
| Self::CantLoadOAuthSession(_)
123-
| Self::CantLoadUser(_)) => (
131+
| Self::CantLoadUser(_)
132+
| Self::FailedToVerifyToken(_)) => (
124133
StatusCode::INTERNAL_SERVER_ERROR,
125134
Json(
126135
ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()),
@@ -140,6 +149,14 @@ impl IntoResponse for RouteError {
140149
),
141150
)
142151
.into_response(),
152+
e @ Self::InvalidBearerToken => (
153+
StatusCode::UNAUTHORIZED,
154+
Json(
155+
ClientError::from(ClientErrorCode::AccessDenied)
156+
.with_description(e.to_string()),
157+
),
158+
)
159+
.into_response(),
143160

144161
Self::UnknownToken(_)
145162
| Self::UnexpectedTokenType
@@ -195,7 +212,7 @@ const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:adm
195212

196213
#[tracing::instrument(
197214
name = "handlers.oauth2.introspection.post",
198-
fields(client.id = client_authorization.client_id()),
215+
fields(client.id = credentials.client_id()),
199216
skip_all,
200217
)]
201218
#[allow(clippy::too_many_lines)]
@@ -205,28 +222,41 @@ pub(crate) async fn post(
205222
mut repo: BoxRepository,
206223
activity_tracker: ActivityTracker,
207224
State(encrypter): State<Encrypter>,
225+
State(homeserver): State<Arc<dyn HomeserverConnection>>,
208226
headers: HeaderMap,
209-
client_authorization: ClientAuthorization<IntrospectionRequest>,
227+
ClientAuthorization { credentials, form }: ClientAuthorization<IntrospectionRequest>,
210228
) -> Result<impl IntoResponse, RouteError> {
211-
let client = client_authorization
212-
.credentials
213-
.fetch(&mut repo)
214-
.await?
215-
.ok_or(RouteError::ClientNotFound)?;
216-
217-
let method = match &client.token_endpoint_auth_method {
218-
None | Some(OAuthClientAuthenticationMethod::None) => {
219-
return Err(RouteError::NotAllowed(client.id));
229+
if let Some(token) = credentials.bearer_token() {
230+
// If the client presented a bearer token, we check with the homeserver
231+
// configuration if it is allowed to use the introspection endpoint
232+
if !homeserver
233+
.verify_token(token)
234+
.await
235+
.map_err(RouteError::FailedToVerifyToken)?
236+
{
237+
return Err(RouteError::InvalidBearerToken);
220238
}
221-
Some(c) => c,
222-
};
239+
} else {
240+
// Otherwise, it presented regular client credentials, so we verify them
241+
let client = credentials
242+
.fetch(&mut repo)
243+
.await?
244+
.ok_or(RouteError::ClientNotFound)?;
245+
246+
// Only confidential clients are allowed to introspect
247+
let method = match &client.token_endpoint_auth_method {
248+
None | Some(OAuthClientAuthenticationMethod::None) => {
249+
return Err(RouteError::NotAllowed(client.id));
250+
}
251+
Some(c) => c,
252+
};
223253

224-
client_authorization
225-
.credentials
226-
.verify(&http_client, &encrypter, method, &client)
227-
.await?;
254+
credentials
255+
.verify(&http_client, &encrypter, method, &client)
256+
.await?;
257+
}
228258

229-
let Some(form) = client_authorization.form else {
259+
let Some(form) = form else {
230260
return Err(RouteError::BadRequest);
231261
};
232262

@@ -578,10 +608,11 @@ mod tests {
578608
use hyper::{Request, StatusCode};
579609
use mas_data_model::{AccessToken, RefreshToken};
580610
use mas_iana::oauth::OAuthTokenTypeHint;
581-
use mas_matrix::{HomeserverConnection, ProvisionRequest};
611+
use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest};
582612
use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute};
583613
use mas_storage::Clock;
584614
use oauth2_types::{
615+
errors::{ClientError, ClientErrorCode},
585616
registration::ClientRegistrationResponse,
586617
requests::IntrospectionResponse,
587618
scope::{OPENID, Scope},
@@ -984,4 +1015,29 @@ mod tests {
9841015
let response: IntrospectionResponse = response.json();
9851016
assert!(response.active);
9861017
}
1018+
1019+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1020+
async fn test_introspect_with_bearer_token(pool: PgPool) {
1021+
setup();
1022+
let state = TestState::from_pool(pool).await.unwrap();
1023+
1024+
// Check that talking to the introspection endpoint with the bearer token from
1025+
// the MockHomeserverConnection doens't error out
1026+
let request = Request::post(OAuth2Introspection::PATH)
1027+
.bearer(MockHomeserverConnection::VALID_BEARER_TOKEN)
1028+
.form(json!({ "token": "some_token" }));
1029+
let response = state.request(request).await;
1030+
response.assert_status(StatusCode::OK);
1031+
let response: IntrospectionResponse = response.json();
1032+
assert!(!response.active);
1033+
1034+
// Check with another token, we should get a 401
1035+
let request = Request::post(OAuth2Introspection::PATH)
1036+
.bearer("another_token")
1037+
.form(json!({ "token": "some_token" }));
1038+
let response = state.request(request).await;
1039+
response.assert_status(StatusCode::UNAUTHORIZED);
1040+
let response: ClientError = response.json();
1041+
assert_eq!(response.error, ClientErrorCode::AccessDenied);
1042+
}
9871043
}

crates/matrix-synapse/src/legacy.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,11 @@ impl HomeserverConnection for SynapseConnection {
160160
&self.homeserver
161161
}
162162

163+
#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
164+
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
165+
Ok(self.access_token == token)
166+
}
167+
163168
#[tracing::instrument(
164169
name = "homeserver.query_user",
165170
skip_all,

crates/matrix-synapse/src/modern.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ impl HomeserverConnection for SynapseConnection {
6666
&self.homeserver
6767
}
6868

69+
#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
70+
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
71+
Ok(self.access_token == token)
72+
}
73+
6974
#[tracing::instrument(
7075
name = "homeserver.query_user",
7176
skip_all,

0 commit comments

Comments
 (0)