Skip to content

Commit 5cbb576

Browse files
committed
Make the rate limiter available to the GraphQL API handlers
1 parent cb08854 commit 5cbb576

File tree

5 files changed

+38
-10
lines changed

5 files changed

+38
-10
lines changed

crates/cli/src/commands/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ impl Options {
209209
site_config.clone(),
210210
password_manager.clone(),
211211
url_builder.clone(),
212+
limiter.clone(),
212213
);
213214

214215
let state = {

crates/cli/src/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 New Vector Ltd.
1+
// Copyright 2024, 2025 New Vector Ltd.
22
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
33
//
44
// SPDX-License-Identifier: AGPL-3.0-only

crates/handlers/src/graphql/mod.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 New Vector Ltd.
1+
// Copyright 2024, 2025 New Vector Ltd.
22
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
33
//
44
// SPDX-License-Identifier: AGPL-3.0-only
@@ -53,7 +53,10 @@ use self::{
5353
mutations::Mutation,
5454
query::Query,
5555
};
56-
use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker};
56+
use crate::{
57+
impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker, Limiter,
58+
RequesterFingerprint,
59+
};
5760

5861
#[cfg(test)]
5962
mod tests;
@@ -72,6 +75,7 @@ struct GraphQLState {
7275
site_config: SiteConfig,
7376
password_manager: PasswordManager,
7477
url_builder: UrlBuilder,
78+
limiter: Limiter,
7579
}
7680

7781
#[async_trait]
@@ -104,6 +108,10 @@ impl state::State for GraphQLState {
104108
&self.url_builder
105109
}
106110

111+
fn limiter(&self) -> &Limiter {
112+
&self.limiter
113+
}
114+
107115
fn clock(&self) -> BoxClock {
108116
let clock = SystemClock::default();
109117
Box::new(clock)
@@ -126,6 +134,7 @@ pub fn schema(
126134
site_config: SiteConfig,
127135
password_manager: PasswordManager,
128136
url_builder: UrlBuilder,
137+
limiter: Limiter,
129138
) -> Schema {
130139
let state = GraphQLState {
131140
pool: pool.clone(),
@@ -134,6 +143,7 @@ pub fn schema(
134143
site_config,
135144
password_manager,
136145
url_builder,
146+
limiter,
137147
};
138148
let state: BoxState = Box::new(state);
139149

@@ -303,6 +313,7 @@ pub async fn post(
303313
cookie_jar: CookieJar,
304314
content_type: Option<TypedHeader<ContentType>>,
305315
authorization: Option<TypedHeader<Authorization<Bearer>>>,
316+
requester_fingerprint: RequesterFingerprint,
306317
body: Body,
307318
) -> Result<impl IntoResponse, RouteError> {
308319
let body = body.into_data_stream();
@@ -329,6 +340,7 @@ pub async fn post(
329340
MultipartOptions::default(),
330341
)
331342
.await?
343+
.data(requester_fingerprint)
332344
.data(requester); // XXX: this should probably return another error response?
333345

334346
let span = span_for_graphql_request(&request);
@@ -355,6 +367,7 @@ pub async fn get(
355367
activity_tracker: BoundActivityTracker,
356368
cookie_jar: CookieJar,
357369
authorization: Option<TypedHeader<Authorization<Bearer>>>,
370+
requester_fingerprint: RequesterFingerprint,
358371
RawQuery(query): RawQuery,
359372
) -> Result<impl IntoResponse, FancyError> {
360373
let token = authorization
@@ -371,8 +384,9 @@ pub async fn get(
371384
)
372385
.await?;
373386

374-
let request =
375-
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
387+
let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
388+
.data(requester)
389+
.data(requester_fingerprint);
376390

377391
let span = span_for_graphql_request(&request);
378392
let response = schema.execute(request).instrument(span).await;

crates/handlers/src/graphql/state.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 New Vector Ltd.
1+
// Copyright 2024, 2025 New Vector Ltd.
22
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
33
//
44
// SPDX-License-Identifier: AGPL-3.0-only
@@ -10,7 +10,7 @@ use mas_policy::Policy;
1010
use mas_router::UrlBuilder;
1111
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
1212

13-
use crate::{graphql::Requester, passwords::PasswordManager};
13+
use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint};
1414

1515
#[async_trait::async_trait]
1616
pub trait State {
@@ -22,6 +22,7 @@ pub trait State {
2222
fn rng(&self) -> BoxRng;
2323
fn site_config(&self) -> &SiteConfig;
2424
fn url_builder(&self) -> &UrlBuilder;
25+
fn limiter(&self) -> &Limiter;
2526
}
2627

2728
pub type BoxState = Box<dyn State + Send + Sync + 'static>;
@@ -30,6 +31,8 @@ pub trait ContextExt {
3031
fn state(&self) -> &BoxState;
3132

3233
fn requester(&self) -> &Requester;
34+
35+
fn requester_fingerprint(&self) -> RequesterFingerprint;
3336
}
3437

3538
impl ContextExt for async_graphql::Context<'_> {
@@ -40,4 +43,8 @@ impl ContextExt for async_graphql::Context<'_> {
4043
fn requester(&self) -> &Requester {
4144
self.data_unchecked()
4245
}
46+
47+
fn requester_fingerprint(&self) -> RequesterFingerprint {
48+
*self.data_unchecked()
49+
}
4350
}

crates/handlers/src/test_utils.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2024 New Vector Ltd.
1+
// Copyright 2024, 2025 New Vector Ltd.
22
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
33
//
44
// SPDX-License-Identifier: AGPL-3.0-only
@@ -204,6 +204,8 @@ impl TestState {
204204
let clock = Arc::new(MockClock::default());
205205
let rng = Arc::new(Mutex::new(ChaChaRng::seed_from_u64(42)));
206206

207+
let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();
208+
207209
let graphql_state = TestGraphQLState {
208210
pool: pool.clone(),
209211
policy_factory: Arc::clone(&policy_factory),
@@ -213,6 +215,7 @@ impl TestState {
213215
clock: Arc::clone(&clock),
214216
password_manager: password_manager.clone(),
215217
url_builder: url_builder.clone(),
218+
limiter: limiter.clone(),
216219
};
217220
let state: crate::graphql::BoxState = Box::new(graphql_state);
218221

@@ -225,8 +228,6 @@ impl TestState {
225228
shutdown_token.child_token(),
226229
);
227230

228-
let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();
229-
230231
Ok(Self {
231232
pool,
232233
templates,
@@ -379,6 +380,7 @@ struct TestGraphQLState {
379380
rng: Arc<Mutex<ChaChaRng>>,
380381
password_manager: PasswordManager,
381382
url_builder: UrlBuilder,
383+
limiter: Limiter,
382384
}
383385

384386
#[async_trait]
@@ -415,6 +417,10 @@ impl graphql::State for TestGraphQLState {
415417
&self.site_config
416418
}
417419

420+
fn limiter(&self) -> &Limiter {
421+
&self.limiter
422+
}
423+
418424
fn rng(&self) -> BoxRng {
419425
let mut parent_rng = self.rng.lock().expect("Failed to lock RNG");
420426
let rng = ChaChaRng::from_rng(&mut *parent_rng).expect("Failed to seed RNG");

0 commit comments

Comments
 (0)