Skip to content
This repository was archived by the owner on Sep 10, 2024. It is now read-only.

Commit c0c88a6

Browse files
committed
Merge branch 'quenting/admin-api/create-user' into HEAD
2 parents 53061ff + dacb59b commit c0c88a6

File tree

7 files changed

+495
-10
lines changed

7 files changed

+495
-10
lines changed

crates/handlers/src/admin/mod.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@ use aide::{
1616
axum::ApiRouter,
1717
openapi::{OAuth2Flow, OAuth2Flows, OpenApi, SecurityScheme, Server, ServerVariable},
1818
};
19-
use axum::{extract::FromRequestParts, Json, Router};
19+
use axum::{
20+
extract::{FromRef, FromRequestParts},
21+
Json, Router,
22+
};
2023
use hyper::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
2124
use indexmap::IndexMap;
2225
use mas_http::CorsLayerExt;
26+
use mas_matrix::BoxHomeserverConnection;
2327
use mas_router::{OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, SimpleRoute};
28+
use mas_storage::BoxRng;
2429
use tower_http::cors::{Any, CorsLayer};
2530

2631
mod call_context;
@@ -34,6 +39,8 @@ use self::call_context::CallContext;
3439
pub fn router<S>() -> (OpenApi, Router<S>)
3540
where
3641
S: Clone + Send + Sync + 'static,
42+
BoxHomeserverConnection: FromRef<S>,
43+
BoxRng: FromRequestParts<S>,
3744
CallContext: FromRequestParts<S>,
3845
{
3946
let mut api = OpenApi::default();

crates/handlers/src/admin/v1/mod.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
// limitations under the License.
1414

1515
use aide::axum::{routing::get_with, ApiRouter};
16-
use axum::extract::FromRequestParts;
16+
use axum::extract::{FromRef, FromRequestParts};
17+
use mas_matrix::BoxHomeserverConnection;
18+
use mas_storage::BoxRng;
1719

1820
use super::call_context::CallContext;
1921

@@ -22,10 +24,16 @@ mod users;
2224
pub fn router<S>() -> ApiRouter<S>
2325
where
2426
S: Clone + Send + Sync + 'static,
27+
BoxHomeserverConnection: FromRef<S>,
28+
BoxRng: FromRequestParts<S>,
2529
CallContext: FromRequestParts<S>,
2630
{
2731
ApiRouter::<S>::new()
28-
.api_route("/users", get_with(self::users::list, self::users::list_doc))
32+
.api_route(
33+
"/users",
34+
get_with(self::users::list, self::users::list_doc)
35+
.post_with(self::users::add, self::users::add_doc),
36+
)
2937
.api_route(
3038
"/users/:id",
3139
get_with(self::users::get, self::users::get_doc),
Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
// Copyright 2024 The Matrix.org Foundation C.I.C.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use aide::{transform::TransformOperation, NoApi, OperationIo};
16+
use axum::{extract::State, response::IntoResponse, Json};
17+
use hyper::StatusCode;
18+
use mas_matrix::BoxHomeserverConnection;
19+
use mas_storage::{
20+
job::{JobRepositoryExt, ProvisionUserJob},
21+
BoxRng,
22+
};
23+
use schemars::JsonSchema;
24+
use serde::Deserialize;
25+
use tracing::warn;
26+
27+
use crate::{
28+
admin::{
29+
call_context::CallContext,
30+
model::User,
31+
response::{ErrorResponse, SingleResponse},
32+
},
33+
impl_from_error_for_route,
34+
};
35+
36+
fn valid_username_character(c: char) -> bool {
37+
c.is_ascii_lowercase()
38+
|| c.is_ascii_digit()
39+
|| c == '='
40+
|| c == '_'
41+
|| c == '-'
42+
|| c == '.'
43+
|| c == '/'
44+
|| c == '+'
45+
}
46+
47+
// XXX: this should be shared with the graphql handler
48+
fn username_valid(username: &str) -> bool {
49+
if username.is_empty() || username.len() > 255 {
50+
return false;
51+
}
52+
53+
// Should not start with an underscore
54+
if username.get(0..1) == Some("_") {
55+
return false;
56+
}
57+
58+
// Should only contain valid characters
59+
if !username.chars().all(valid_username_character) {
60+
return false;
61+
}
62+
63+
true
64+
}
65+
66+
#[derive(Debug, thiserror::Error, OperationIo)]
67+
#[aide(output_with = "Json<ErrorResponse>")]
68+
pub enum RouteError {
69+
#[error(transparent)]
70+
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
71+
72+
#[error(transparent)]
73+
Homeserver(anyhow::Error),
74+
75+
#[error("Username is not valid")]
76+
UsernameNotValid,
77+
78+
#[error("User already exists")]
79+
UserAlreadyExists,
80+
81+
#[error("Username is reserved by the homeserver")]
82+
UsernameReserved,
83+
}
84+
85+
impl_from_error_for_route!(mas_storage::RepositoryError);
86+
87+
impl IntoResponse for RouteError {
88+
fn into_response(self) -> axum::response::Response {
89+
let error = ErrorResponse::from_error(&self);
90+
let status = match self {
91+
Self::Internal(_) | Self::Homeserver(_) => StatusCode::INTERNAL_SERVER_ERROR,
92+
Self::UsernameNotValid => StatusCode::BAD_REQUEST,
93+
Self::UserAlreadyExists | Self::UsernameReserved => StatusCode::CONFLICT,
94+
};
95+
(status, Json(error)).into_response()
96+
}
97+
}
98+
99+
/// # JSON payload for the `POST /api/admin/v1/users` endpoint
100+
#[derive(Deserialize, JsonSchema)]
101+
#[schemars(rename = "AddUserPayload")]
102+
pub struct Payload {
103+
/// The username of the user to add.
104+
username: String,
105+
106+
/// Skip checking with the homeserver whether the username is valid.
107+
///
108+
/// Use this with caution! The main reason to use this, is when a user used
109+
/// by an application service needs to exist in MAS to craft special
110+
/// tokens (like with admin access) for them
111+
#[serde(default)]
112+
skip_homeserver_check: bool,
113+
}
114+
115+
pub fn doc(operation: TransformOperation) -> TransformOperation {
116+
operation
117+
.summary("Create a new user")
118+
.tag("user")
119+
.response_with::<200, Json<SingleResponse<User>>, _>(|t| {
120+
let [sample, ..] = User::samples();
121+
let response = SingleResponse::new_canonical(sample);
122+
t.description("User was created").example(response)
123+
})
124+
.response_with::<400, RouteError, _>(|t| {
125+
let response = ErrorResponse::from_error(&RouteError::UsernameNotValid);
126+
t.description("Username is not valid").example(response)
127+
})
128+
.response_with::<409, RouteError, _>(|t| {
129+
let response = ErrorResponse::from_error(&RouteError::UserAlreadyExists);
130+
t.description("User already exists").example(response)
131+
})
132+
.response_with::<409, RouteError, _>(|t| {
133+
let response = ErrorResponse::from_error(&RouteError::UsernameReserved);
134+
t.description("Username is reserved by the homeserver")
135+
.example(response)
136+
})
137+
}
138+
139+
#[tracing::instrument(name = "handler.admin.v1.users.add", skip_all, err)]
140+
pub async fn handler(
141+
CallContext {
142+
mut repo, clock, ..
143+
}: CallContext,
144+
NoApi(mut rng): NoApi<BoxRng>,
145+
State(homeserver): State<BoxHomeserverConnection>,
146+
Json(params): Json<Payload>,
147+
) -> Result<Json<SingleResponse<User>>, RouteError> {
148+
if repo.user().exists(&params.username).await? {
149+
return Err(RouteError::UserAlreadyExists);
150+
}
151+
152+
// Do some basic check on the username
153+
if !username_valid(&params.username) {
154+
return Err(RouteError::UsernameNotValid);
155+
}
156+
157+
// Ask the homeserver if the username is available
158+
let homeserver_available = homeserver
159+
.is_localpart_available(&params.username)
160+
.await
161+
.map_err(RouteError::Homeserver)?;
162+
163+
if !homeserver_available {
164+
if !params.skip_homeserver_check {
165+
return Err(RouteError::UsernameReserved);
166+
}
167+
168+
// If we skipped the check, we still want to shout about it
169+
warn!("Skipped homeserver check for username {}", params.username);
170+
}
171+
172+
let user = repo.user().add(&mut rng, &clock, params.username).await?;
173+
174+
repo.job()
175+
.schedule_job(ProvisionUserJob::new(&user))
176+
.await?;
177+
178+
repo.save().await?;
179+
180+
Ok(Json(SingleResponse::new_canonical(User::from(user))))
181+
}
182+
183+
#[cfg(test)]
184+
mod tests {
185+
use hyper::{Request, StatusCode};
186+
use mas_storage::{user::UserRepository, RepositoryAccess};
187+
use sqlx::PgPool;
188+
189+
use crate::test_utils::{setup, RequestBuilderExt, ResponseExt, TestState};
190+
191+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
192+
async fn test_add_user(pool: PgPool) {
193+
setup();
194+
let mut state = TestState::from_pool(pool).await.unwrap();
195+
let token = state.token_with_scope("urn:mas:admin").await;
196+
197+
let request = Request::post("/api/admin/v1/users")
198+
.bearer(&token)
199+
.json(serde_json::json!({
200+
"username": "alice",
201+
}));
202+
203+
let response = state.request(request).await;
204+
response.assert_status(StatusCode::OK);
205+
206+
let body: serde_json::Value = response.json();
207+
assert_eq!(body["data"]["type"], "user");
208+
let id = body["data"]["id"].as_str().unwrap();
209+
assert_eq!(body["data"]["attributes"]["username"], "alice");
210+
211+
// Check that the user was created in the database
212+
let mut repo = state.repository().await.unwrap();
213+
let user = repo
214+
.user()
215+
.lookup(id.parse().unwrap())
216+
.await
217+
.unwrap()
218+
.unwrap();
219+
220+
assert_eq!(user.username, "alice");
221+
}
222+
223+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
224+
async fn test_add_user_invalid_username(pool: PgPool) {
225+
setup();
226+
let mut state = TestState::from_pool(pool).await.unwrap();
227+
let token = state.token_with_scope("urn:mas:admin").await;
228+
229+
let request = Request::post("/api/admin/v1/users")
230+
.bearer(&token)
231+
.json(serde_json::json!({
232+
"username": "this is invalid",
233+
}));
234+
235+
let response = state.request(request).await;
236+
response.assert_status(StatusCode::BAD_REQUEST);
237+
238+
let body: serde_json::Value = response.json();
239+
assert_eq!(body["errors"][0]["title"], "Username is not valid");
240+
}
241+
242+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
243+
async fn test_add_user_exists(pool: PgPool) {
244+
setup();
245+
let mut state = TestState::from_pool(pool).await.unwrap();
246+
let token = state.token_with_scope("urn:mas:admin").await;
247+
248+
let request = Request::post("/api/admin/v1/users")
249+
.bearer(&token)
250+
.json(serde_json::json!({
251+
"username": "alice",
252+
}));
253+
254+
let response = state.request(request).await;
255+
response.assert_status(StatusCode::OK);
256+
257+
let body: serde_json::Value = response.json();
258+
assert_eq!(body["data"]["type"], "user");
259+
assert_eq!(body["data"]["attributes"]["username"], "alice");
260+
261+
let request = Request::post("/api/admin/v1/users")
262+
.bearer(&token)
263+
.json(serde_json::json!({
264+
"username": "alice",
265+
}));
266+
267+
let response = state.request(request).await;
268+
response.assert_status(StatusCode::CONFLICT);
269+
270+
let body: serde_json::Value = response.json();
271+
assert_eq!(body["errors"][0]["title"], "User already exists");
272+
}
273+
274+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
275+
async fn test_add_user_reserved(pool: PgPool) {
276+
setup();
277+
let mut state = TestState::from_pool(pool).await.unwrap();
278+
let token = state.token_with_scope("urn:mas:admin").await;
279+
280+
// Reserve a username on the homeserver and try to add it
281+
state.homeserver_connection.reserve_localpart("bob").await;
282+
283+
let request = Request::post("/api/admin/v1/users")
284+
.bearer(&token)
285+
.json(serde_json::json!({
286+
"username": "bob",
287+
}));
288+
289+
let response = state.request(request).await;
290+
291+
let body: serde_json::Value = response.json();
292+
assert_eq!(
293+
body["errors"][0]["title"],
294+
"Username is reserved by the homeserver"
295+
);
296+
297+
// But we can force it with the skip_homeserver_check flag
298+
let request = Request::post("/api/admin/v1/users")
299+
.bearer(&token)
300+
.json(serde_json::json!({
301+
"username": "bob",
302+
"skip_homeserver_check": true,
303+
}));
304+
305+
let response = state.request(request).await;
306+
response.assert_status(StatusCode::OK);
307+
308+
let body: serde_json::Value = response.json();
309+
let id = body["data"]["id"].as_str().unwrap();
310+
assert_eq!(body["data"]["attributes"]["username"], "bob");
311+
312+
// Check that the user was created in the database
313+
let mut repo = state.repository().await.unwrap();
314+
let user = repo
315+
.user()
316+
.lookup(id.parse().unwrap())
317+
.await
318+
.unwrap()
319+
.unwrap();
320+
321+
assert_eq!(user.username, "bob");
322+
}
323+
}

crates/handlers/src/admin/v1/users/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
mod add;
1516
mod by_username;
1617
mod get;
1718
mod list;
1819

1920
pub use self::{
21+
add::{doc as add_doc, handler as add},
2022
by_username::{doc as by_username_doc, handler as by_username},
2123
get::{doc as get_doc, handler as get},
2224
list::{doc as list_doc, handler as list},

0 commit comments

Comments
 (0)