Skip to content

Commit 61091ff

Browse files
committed
Admin API endpoint to add upstream link
1 parent 5a1ac37 commit 61091ff

File tree

4 files changed

+485
-0
lines changed

4 files changed

+485
-0
lines changed

crates/handlers/src/admin/v1/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ mod user_emails;
2626
mod user_sessions;
2727
mod users;
2828

29+
#[allow(clippy::too_many_lines)]
2930
pub fn router<S>() -> ApiRouter<S>
3031
where
3132
S: Clone + Send + Sync + 'static,
@@ -123,6 +124,10 @@ where
123124
get_with(
124125
self::upstream_oauth_links::list,
125126
self::upstream_oauth_links::list_doc,
127+
)
128+
.post_with(
129+
self::upstream_oauth_links::add,
130+
self::upstream_oauth_links::add_doc,
126131
),
127132
)
128133
.api_route(
Lines changed: 368 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,368 @@
1+
// Copyright 2025 New Vector Ltd.
2+
//
3+
// SPDX-License-Identifier: AGPL-3.0-only
4+
// Please see LICENSE in the repository root for full details.
5+
6+
use aide::{NoApi, OperationIo, transform::TransformOperation};
7+
use axum::{Json, response::IntoResponse};
8+
use hyper::StatusCode;
9+
use mas_storage::{BoxRng, upstream_oauth2::UpstreamOAuthLinkFilter};
10+
use schemars::JsonSchema;
11+
use serde::Deserialize;
12+
use ulid::Ulid;
13+
14+
use crate::{
15+
admin::{
16+
call_context::CallContext,
17+
model::{Resource, UpstreamOAuthLink, User},
18+
response::{ErrorResponse, SingleResponse},
19+
},
20+
impl_from_error_for_route,
21+
};
22+
23+
#[derive(Debug, thiserror::Error, OperationIo)]
24+
#[aide(output_with = "Json<ErrorResponse>")]
25+
pub enum RouteError {
26+
#[error(transparent)]
27+
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
28+
29+
#[error("User ID {0} already has an upstream link for Upstream Oauth 2.0 Provider ID {1}")]
30+
LinkAlreadyExists(Ulid, Ulid),
31+
32+
#[error("User ID {0} not found")]
33+
UserNotFound(Ulid),
34+
35+
#[error("Upstream OAuth 2.0 Provider ID {0} not found")]
36+
ProviderNotFound(Ulid),
37+
}
38+
39+
impl_from_error_for_route!(mas_storage::RepositoryError);
40+
41+
impl IntoResponse for RouteError {
42+
fn into_response(self) -> axum::response::Response {
43+
let error = ErrorResponse::from_error(&self);
44+
let status = match self {
45+
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
46+
Self::LinkAlreadyExists(_, _) => StatusCode::CONFLICT,
47+
Self::UserNotFound(_) | Self::ProviderNotFound(_) => StatusCode::NOT_FOUND,
48+
};
49+
(status, Json(error)).into_response()
50+
}
51+
}
52+
53+
/// # JSON payload for the `POST /api/admin/v1/upstream-oauth-links`
54+
#[derive(Deserialize, JsonSchema)]
55+
#[serde(rename = "AddUpstreamOauthLinkRequest")]
56+
pub struct Request {
57+
/// The ID of the user to which the link should be added.
58+
#[schemars(with = "crate::admin::schema::Ulid")]
59+
user_id: Ulid,
60+
61+
/// The ID of the upstream provider to which the link is for.
62+
#[schemars(with = "crate::admin::schema::Ulid")]
63+
provider_id: Ulid,
64+
65+
/// The subject (sub) claim of the user on the provider.
66+
subject: String,
67+
68+
/// A human readable account name.
69+
human_account_name: Option<String>,
70+
}
71+
72+
pub fn doc(operation: TransformOperation) -> TransformOperation {
73+
operation
74+
.id("addUpstreamOAuthLink")
75+
.summary("Add an upstream OAuth 2.0 link")
76+
.tag("upstream-oauth-link")
77+
.response_with::<201, Json<SingleResponse<UpstreamOAuthLink>>, _>(|t| {
78+
let [sample, ..] = UpstreamOAuthLink::samples();
79+
let response = SingleResponse::new_canonical(sample);
80+
t.description("Upstream OAuth 2.0 link was created")
81+
.example(response)
82+
})
83+
.response_with::<409, RouteError, _>(|t| {
84+
let [provider_sample, ..] = UpstreamOAuthLink::samples();
85+
let [user_sample, ..] = User::samples();
86+
let response = ErrorResponse::from_error(&RouteError::LinkAlreadyExists(
87+
user_sample.id(),
88+
provider_sample.id(),
89+
));
90+
t.description("User already has an upstream link for this provider")
91+
.example(response)
92+
})
93+
.response_with::<404, RouteError, _>(|t| {
94+
let response = ErrorResponse::from_error(&RouteError::UserNotFound(Ulid::nil()));
95+
t.description("User or provider was not found")
96+
.example(response)
97+
})
98+
}
99+
100+
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.post", skip_all, err)]
101+
pub async fn handler(
102+
CallContext {
103+
mut repo, clock, ..
104+
}: CallContext,
105+
NoApi(mut rng): NoApi<BoxRng>,
106+
Json(params): Json<Request>,
107+
) -> Result<(StatusCode, Json<SingleResponse<UpstreamOAuthLink>>), RouteError> {
108+
// Find the user
109+
let user = repo
110+
.user()
111+
.lookup(params.user_id)
112+
.await?
113+
.ok_or(RouteError::UserNotFound(params.user_id))?;
114+
115+
// Find the provider
116+
let provider = repo
117+
.upstream_oauth_provider()
118+
.lookup(params.provider_id)
119+
.await?
120+
.ok_or(RouteError::ProviderNotFound(params.provider_id))?;
121+
122+
let filter = UpstreamOAuthLinkFilter::new()
123+
.for_user(&user)
124+
.for_provider(&provider);
125+
let count = repo.upstream_oauth_link().count(filter).await?;
126+
127+
if count > 0 {
128+
return Err(RouteError::LinkAlreadyExists(
129+
params.user_id,
130+
params.provider_id,
131+
));
132+
}
133+
134+
let mut link = repo
135+
.upstream_oauth_link()
136+
.add(
137+
&mut rng,
138+
&clock,
139+
&provider,
140+
params.subject,
141+
params.human_account_name,
142+
)
143+
.await?;
144+
145+
repo.upstream_oauth_link()
146+
.associate_to_user(&link, &user)
147+
.await?;
148+
link.user_id = Some(user.id);
149+
150+
repo.save().await?;
151+
152+
Ok((
153+
StatusCode::CREATED,
154+
Json(SingleResponse::new_canonical(link.into())),
155+
))
156+
}
157+
158+
#[cfg(test)]
159+
mod tests {
160+
use hyper::{Request, StatusCode};
161+
use insta::assert_json_snapshot;
162+
use sqlx::PgPool;
163+
use ulid::Ulid;
164+
165+
use super::super::test_utils;
166+
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
167+
168+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
169+
async fn test_create(pool: PgPool) {
170+
setup();
171+
let mut state = TestState::from_pool(pool).await.unwrap();
172+
let token = state.token_with_scope("urn:mas:admin").await;
173+
let mut rng = state.rng();
174+
let mut repo = state.repository().await.unwrap();
175+
176+
let alice = repo
177+
.user()
178+
.add(&mut rng, &state.clock, "alice".to_owned())
179+
.await
180+
.unwrap();
181+
182+
let provider = repo
183+
.upstream_oauth_provider()
184+
.add(
185+
&mut rng,
186+
&state.clock,
187+
test_utils::oidc_provider_params("provider1"),
188+
)
189+
.await
190+
.unwrap();
191+
192+
repo.save().await.unwrap();
193+
194+
let request = Request::post("/api/admin/v1/upstream-oauth-links")
195+
.bearer(&token)
196+
.json(serde_json::json!({
197+
"user_id": alice.id,
198+
"provider_id": provider.id,
199+
"subject": "subject1"
200+
}));
201+
let response = state.request(request).await;
202+
response.assert_status(StatusCode::CREATED);
203+
let body: serde_json::Value = response.json();
204+
assert_json_snapshot!(body, @r###"
205+
{
206+
"data": {
207+
"type": "upstream-oauth-link",
208+
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
209+
"attributes": {
210+
"created_at": "2022-01-16T14:40:00Z",
211+
"provider_id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
212+
"subject": "subject1",
213+
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
214+
"human_account_name": null
215+
},
216+
"links": {
217+
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG07HNEZXNQM2KNBNF6"
218+
}
219+
},
220+
"links": {
221+
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG07HNEZXNQM2KNBNF6"
222+
}
223+
}
224+
"###);
225+
}
226+
227+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
228+
async fn test_link_already_exists(pool: PgPool) {
229+
setup();
230+
let mut state = TestState::from_pool(pool).await.unwrap();
231+
let token = state.token_with_scope("urn:mas:admin").await;
232+
let mut rng = state.rng();
233+
let mut repo = state.repository().await.unwrap();
234+
235+
let alice = repo
236+
.user()
237+
.add(&mut rng, &state.clock, "alice".to_owned())
238+
.await
239+
.unwrap();
240+
241+
let provider = repo
242+
.upstream_oauth_provider()
243+
.add(
244+
&mut rng,
245+
&state.clock,
246+
test_utils::oidc_provider_params("provider1"),
247+
)
248+
.await
249+
.unwrap();
250+
251+
let link = repo
252+
.upstream_oauth_link()
253+
.add(
254+
&mut rng,
255+
&state.clock,
256+
&provider,
257+
String::from("subject1"),
258+
None,
259+
)
260+
.await
261+
.unwrap();
262+
263+
repo.upstream_oauth_link()
264+
.associate_to_user(&link, &alice)
265+
.await
266+
.unwrap();
267+
268+
repo.save().await.unwrap();
269+
270+
let request = Request::post("/api/admin/v1/upstream-oauth-links")
271+
.bearer(&token)
272+
.json(serde_json::json!({
273+
"user_id": alice.id,
274+
"provider_id": provider.id,
275+
"subject": "subject1"
276+
}));
277+
let response = state.request(request).await;
278+
response.assert_status(StatusCode::CONFLICT);
279+
let body: serde_json::Value = response.json();
280+
assert_json_snapshot!(body, @r###"
281+
{
282+
"errors": [
283+
{
284+
"title": "User ID 01FSHN9AG0MZAA6S4AF7CTV32E already has an upstream link for Upstream Oauth 2.0 Provider ID 01FSHN9AG0AJ6AC5HQ9X6H4RP4"
285+
}
286+
]
287+
}
288+
"###);
289+
}
290+
291+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
292+
async fn test_user_not_found(pool: PgPool) {
293+
setup();
294+
let mut state = TestState::from_pool(pool).await.unwrap();
295+
let token = state.token_with_scope("urn:mas:admin").await;
296+
let mut rng = state.rng();
297+
let mut repo = state.repository().await.unwrap();
298+
299+
let provider = repo
300+
.upstream_oauth_provider()
301+
.add(
302+
&mut rng,
303+
&state.clock,
304+
test_utils::oidc_provider_params("provider1"),
305+
)
306+
.await
307+
.unwrap();
308+
309+
repo.save().await.unwrap();
310+
311+
let request = Request::post("/api/admin/v1/upstream-oauth-links")
312+
.bearer(&token)
313+
.json(serde_json::json!({
314+
"user_id": Ulid::nil(),
315+
"provider_id": provider.id,
316+
"subject": "subject1"
317+
}));
318+
let response = state.request(request).await;
319+
response.assert_status(StatusCode::NOT_FOUND);
320+
let body: serde_json::Value = response.json();
321+
assert_json_snapshot!(body, @r###"
322+
{
323+
"errors": [
324+
{
325+
"title": "User ID 00000000000000000000000000 not found"
326+
}
327+
]
328+
}
329+
"###);
330+
}
331+
332+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
333+
async fn test_provider_not_found(pool: PgPool) {
334+
setup();
335+
let mut state = TestState::from_pool(pool).await.unwrap();
336+
let token = state.token_with_scope("urn:mas:admin").await;
337+
let mut rng = state.rng();
338+
let mut repo = state.repository().await.unwrap();
339+
340+
let alice = repo
341+
.user()
342+
.add(&mut rng, &state.clock, "alice".to_owned())
343+
.await
344+
.unwrap();
345+
346+
repo.save().await.unwrap();
347+
348+
let request = Request::post("/api/admin/v1/upstream-oauth-links")
349+
.bearer(&token)
350+
.json(serde_json::json!({
351+
"user_id": alice.id,
352+
"provider_id": Ulid::nil(),
353+
"subject": "subject1"
354+
}));
355+
let response = state.request(request).await;
356+
response.assert_status(StatusCode::NOT_FOUND);
357+
let body: serde_json::Value = response.json();
358+
assert_json_snapshot!(body, @r###"
359+
{
360+
"errors": [
361+
{
362+
"title": "Upstream OAuth 2.0 Provider ID 00000000000000000000000000 not found"
363+
}
364+
]
365+
}
366+
"###);
367+
}
368+
}

0 commit comments

Comments
 (0)