Skip to content
This repository was archived by the owner on Sep 10, 2024. It is now read-only.

Commit 27ca7ec

Browse files
committed
Add an extractor to check for credentails in the admin API
1 parent 43ff6dc commit 27ca7ec

File tree

3 files changed

+286
-1
lines changed

3 files changed

+286
-1
lines changed
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
// Copyright 2024 The Matrix.org Foundation C.I.C.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::convert::Infallible;
16+
17+
use aide::OperationIo;
18+
use axum::{
19+
extract::FromRequestParts,
20+
response::{IntoResponse, Response},
21+
Json,
22+
};
23+
use axum_extra::TypedHeader;
24+
use headers::{authorization::Bearer, Authorization};
25+
use hyper::StatusCode;
26+
use mas_data_model::{Session, User};
27+
use mas_storage::{BoxClock, BoxRepository, RepositoryError};
28+
use ulid::Ulid;
29+
30+
use super::response::ErrorResponse;
31+
use crate::BoundActivityTracker;
32+
33+
#[derive(Debug, thiserror::Error)]
34+
pub enum Rejection {
35+
/// The authorization header is missing
36+
#[error("Missing authorization header")]
37+
MissingAuthorizationHeader,
38+
39+
/// The authorization header is invalid
40+
#[error("Invalid authorization header")]
41+
InvalidAuthorizationHeader,
42+
43+
/// Couldn't load the database repository
44+
#[error("Couldn't load the database repository")]
45+
RepositorySetup(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
46+
47+
/// A database operation failed
48+
#[error("Invalid repository operation")]
49+
Repository(#[from] RepositoryError),
50+
51+
/// The access token could not be found in the database
52+
#[error("Unknown access token")]
53+
UnknownAccessToken,
54+
55+
/// The access token provided expired
56+
#[error("Access token expired")]
57+
TokenExpired,
58+
59+
/// The session associated with the access token was revoked
60+
#[error("Access token revoked")]
61+
SessionRevoked,
62+
63+
/// The user associated with the session is locked
64+
#[error("User locked")]
65+
UserLocked,
66+
67+
/// Failed to load the session
68+
#[error("Failed to load session {0}")]
69+
LoadSession(Ulid),
70+
71+
/// Failed to load the user
72+
#[error("Failed to load user {0}")]
73+
LoadUser(Ulid),
74+
75+
/// The session does not have the `urn:mas:admin` scope
76+
#[error("Missing urn:mas:admin scope")]
77+
MissingScope,
78+
}
79+
80+
impl Rejection {
81+
fn status_code(&self) -> StatusCode {
82+
match self {
83+
Self::InvalidAuthorizationHeader | Self::MissingAuthorizationHeader => {
84+
StatusCode::BAD_REQUEST
85+
}
86+
Self::UnknownAccessToken
87+
| Self::TokenExpired
88+
| Self::SessionRevoked
89+
| Self::UserLocked
90+
| Self::MissingScope => StatusCode::UNAUTHORIZED,
91+
_ => StatusCode::INTERNAL_SERVER_ERROR,
92+
}
93+
}
94+
}
95+
96+
impl IntoResponse for Rejection {
97+
fn into_response(self) -> Response {
98+
let response = ErrorResponse::from_error(&self);
99+
let status = self.status_code();
100+
(status, Json(response)).into_response()
101+
}
102+
}
103+
104+
/// An extractor which authorizes the request
105+
///
106+
/// Because we need to load the database repository and the clock, we keep them
107+
/// in the context to avoid creating two instances for each request.
108+
#[non_exhaustive]
109+
#[derive(OperationIo)]
110+
#[aide(input)]
111+
pub struct CallContext {
112+
pub repo: BoxRepository,
113+
pub clock: BoxClock,
114+
pub user: Option<User>,
115+
pub session: Session,
116+
}
117+
118+
#[async_trait::async_trait]
119+
impl<S> FromRequestParts<S> for CallContext
120+
where
121+
S: Send + Sync,
122+
BoundActivityTracker: FromRequestParts<S, Rejection = Infallible>,
123+
BoxRepository: FromRequestParts<S>,
124+
BoxClock: FromRequestParts<S, Rejection = Infallible>,
125+
<BoxRepository as FromRequestParts<S>>::Rejection:
126+
Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
127+
{
128+
type Rejection = Rejection;
129+
130+
async fn from_request_parts(
131+
parts: &mut axum::http::request::Parts,
132+
state: &S,
133+
) -> Result<Self, Self::Rejection> {
134+
let activity_tracker = BoundActivityTracker::from_request_parts(parts, state).await;
135+
let activity_tracker = match activity_tracker {
136+
Ok(t) => t,
137+
Err(e) => match e {},
138+
};
139+
140+
let clock = BoxClock::from_request_parts(parts, state).await;
141+
let clock = match clock {
142+
Ok(c) => c,
143+
Err(e) => match e {},
144+
};
145+
146+
// Load the database repository
147+
let mut repo = BoxRepository::from_request_parts(parts, state)
148+
.await
149+
.map_err(Into::into)
150+
.map_err(Rejection::RepositorySetup)?;
151+
152+
// Extract the access token from the authorization header
153+
let token = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
154+
.await
155+
.map_err(|e| {
156+
// We map to two differentsson of errors depending on whether the header is
157+
// missing or invalid
158+
if e.is_missing() {
159+
Rejection::MissingAuthorizationHeader
160+
} else {
161+
Rejection::InvalidAuthorizationHeader
162+
}
163+
})?;
164+
165+
let token = token.token();
166+
167+
// Look for the access token in the database
168+
let token = repo
169+
.oauth2_access_token()
170+
.find_by_token(token)
171+
.await?
172+
.ok_or(Rejection::UnknownAccessToken)?;
173+
174+
// Look for the associated session in the database
175+
let session = repo
176+
.oauth2_session()
177+
.lookup(token.session_id)
178+
.await?
179+
.ok_or_else(|| Rejection::LoadSession(token.session_id))?;
180+
181+
// Record the activity on the session
182+
activity_tracker
183+
.record_oauth2_session(&clock, &session)
184+
.await;
185+
186+
// Load the user if there is one
187+
let user = if let Some(user_id) = session.user_id {
188+
let user = repo
189+
.user()
190+
.lookup(user_id)
191+
.await?
192+
.ok_or_else(|| Rejection::LoadUser(user_id))?;
193+
Some(user)
194+
} else {
195+
None
196+
};
197+
198+
// If there is a user for this session, check that it is not locked
199+
if let Some(user) = &user {
200+
if !user.is_valid() {
201+
return Err(Rejection::UserLocked);
202+
}
203+
}
204+
205+
if !session.is_valid() {
206+
return Err(Rejection::SessionRevoked);
207+
}
208+
209+
if !token.is_valid(clock.now()) {
210+
return Err(Rejection::TokenExpired);
211+
}
212+
213+
// For now, we only check that the session has the admin scope
214+
// Later we might want to check other route-specific scopes
215+
if !session.scope.contains("urn:mas:admin") {
216+
return Err(Rejection::MissingScope);
217+
}
218+
219+
Ok(Self {
220+
repo,
221+
clock,
222+
user,
223+
session,
224+
})
225+
}
226+
}

crates/handlers/src/admin/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,22 @@ use aide::{
1616
axum::ApiRouter,
1717
openapi::{OAuth2Flow, OAuth2Flows, OpenApi, SecurityScheme, Server, ServerVariable},
1818
};
19-
use axum::{Json, Router};
19+
use axum::{extract::FromRequestParts, Json, Router};
2020
use hyper::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
2121
use indexmap::IndexMap;
2222
use mas_http::CorsLayerExt;
2323
use mas_router::{OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, SimpleRoute};
2424
use tower_http::cors::{Any, CorsLayer};
2525

26+
mod call_context;
27+
mod response;
28+
29+
use self::call_context::CallContext;
30+
2631
pub fn router<S>() -> (OpenApi, Router<S>)
2732
where
2833
S: Clone + Send + Sync + 'static,
34+
CallContext: FromRequestParts<S>,
2935
{
3036
let mut api = OpenApi::default();
3137
let router = ApiRouter::<S>::new()
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2024 The Matrix.org Foundation C.I.C.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#![allow(clippy::module_name_repetitions)]
16+
17+
use schemars::JsonSchema;
18+
use serde::Serialize;
19+
20+
/// A single error
21+
#[derive(Serialize, JsonSchema)]
22+
struct Error {
23+
/// A human-readable title for the error
24+
title: String,
25+
}
26+
27+
impl Error {
28+
fn from_error(error: &(dyn std::error::Error + 'static)) -> Self {
29+
Self {
30+
title: error.to_string(),
31+
}
32+
}
33+
}
34+
35+
/// A top-level response with a list of errors
36+
#[derive(Serialize, JsonSchema)]
37+
pub struct ErrorResponse {
38+
/// The list of errors
39+
errors: Vec<Error>,
40+
}
41+
42+
impl ErrorResponse {
43+
/// Create a new error response from any Rust error
44+
pub fn from_error(error: &(dyn std::error::Error + 'static)) -> Self {
45+
let mut errors = Vec::new();
46+
let mut head = Some(error);
47+
while let Some(error) = head {
48+
errors.push(Error::from_error(error));
49+
head = error.source();
50+
}
51+
Self { errors }
52+
}
53+
}

0 commit comments

Comments
 (0)