Skip to content

Commit 88f6171

Browse files
authored
Allow banning IPs and user agents through the policy (#4048)
2 parents ffb6e2e + 7c09b45 commit 88f6171

File tree

34 files changed

+702
-285
lines changed

34 files changed

+702
-285
lines changed

.github/actions/build-policies/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ runs:
77
- name: Install Open Policy Agent
88
uses: open-policy-agent/[email protected]
99
with:
10-
version: 0.70.0
10+
version: 1.1.0
1111

1212
- name: Build the policies
1313
run: make

crates/handlers/src/graphql/mod.rs

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#![allow(clippy::module_name_repetitions)]
88

9-
use std::sync::Arc;
9+
use std::{net::IpAddr, ops::Deref, sync::Arc};
1010

1111
use async_graphql::{
1212
extensions::Tracing,
@@ -238,9 +238,10 @@ async fn get_requester(
238238
activity_tracker: &BoundActivityTracker,
239239
mut repo: BoxRepository,
240240
session_info: SessionInfo,
241+
user_agent: Option<String>,
241242
token: Option<&str>,
242243
) -> Result<Requester, RouteError> {
243-
let requester = if let Some(token) = token {
244+
let entity = if let Some(token) = token {
244245
// If we haven't enabled undocumented_oauth2_access on the listener, we bail out
245246
if !undocumented_oauth2_access {
246247
return Err(RouteError::InvalidToken);
@@ -285,7 +286,7 @@ async fn get_requester(
285286
return Err(RouteError::MissingScope);
286287
}
287288

288-
Requester::OAuth2Session(Box::new((session, user)))
289+
RequestingEntity::OAuth2Session(Box::new((session, user)))
289290
} else {
290291
let maybe_session = session_info.load_session(&mut repo).await?;
291292

@@ -295,8 +296,15 @@ async fn get_requester(
295296
.await;
296297
}
297298

298-
Requester::from(maybe_session)
299+
RequestingEntity::from(maybe_session)
299300
};
301+
302+
let requester = Requester {
303+
entity,
304+
ip_address: activity_tracker.ip(),
305+
user_agent,
306+
};
307+
300308
repo.cancel().await?;
301309
Ok(requester)
302310
}
@@ -312,20 +320,22 @@ pub async fn post(
312320
cookie_jar: CookieJar,
313321
content_type: Option<TypedHeader<ContentType>>,
314322
authorization: Option<TypedHeader<Authorization<Bearer>>>,
315-
requester_fingerprint: RequesterFingerprint,
323+
user_agent: Option<TypedHeader<headers::UserAgent>>,
316324
body: Body,
317325
) -> Result<impl IntoResponse, RouteError> {
318326
let body = body.into_data_stream();
319327
let token = authorization
320328
.as_ref()
321329
.map(|TypedHeader(Authorization(bearer))| bearer.token());
330+
let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
322331
let (session_info, _cookie_jar) = cookie_jar.session_info();
323332
let requester = get_requester(
324333
undocumented_oauth2_access,
325334
&clock,
326335
&activity_tracker,
327336
repo,
328337
session_info,
338+
user_agent,
329339
token,
330340
)
331341
.await?;
@@ -339,7 +349,6 @@ pub async fn post(
339349
MultipartOptions::default(),
340350
)
341351
.await?
342-
.data(requester_fingerprint)
343352
.data(requester); // XXX: this should probably return another error response?
344353

345354
let span = span_for_graphql_request(&request);
@@ -366,26 +375,27 @@ pub async fn get(
366375
activity_tracker: BoundActivityTracker,
367376
cookie_jar: CookieJar,
368377
authorization: Option<TypedHeader<Authorization<Bearer>>>,
369-
requester_fingerprint: RequesterFingerprint,
378+
user_agent: Option<TypedHeader<headers::UserAgent>>,
370379
RawQuery(query): RawQuery,
371380
) -> Result<impl IntoResponse, FancyError> {
372381
let token = authorization
373382
.as_ref()
374383
.map(|TypedHeader(Authorization(bearer))| bearer.token());
384+
let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
375385
let (session_info, _cookie_jar) = cookie_jar.session_info();
376386
let requester = get_requester(
377387
undocumented_oauth2_access,
378388
&clock,
379389
&activity_tracker,
380390
repo,
381391
session_info,
392+
user_agent,
382393
token,
383394
)
384395
.await?;
385396

386-
let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
387-
.data(requester)
388-
.data(requester_fingerprint);
397+
let request =
398+
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
389399

390400
let span = span_for_graphql_request(&request);
391401
let response = schema.execute(request).instrument(span).await;
@@ -417,9 +427,40 @@ pub fn schema_builder() -> SchemaBuilder {
417427
.register_output_type::<CreationEvent>()
418428
}
419429

430+
pub struct Requester {
431+
entity: RequestingEntity,
432+
ip_address: Option<IpAddr>,
433+
user_agent: Option<String>,
434+
}
435+
436+
impl Requester {
437+
pub fn fingerprint(&self) -> RequesterFingerprint {
438+
if let Some(ip) = self.ip_address {
439+
RequesterFingerprint::new(ip)
440+
} else {
441+
RequesterFingerprint::EMPTY
442+
}
443+
}
444+
445+
pub fn for_policy(&self) -> mas_policy::Requester {
446+
mas_policy::Requester {
447+
ip_address: self.ip_address,
448+
user_agent: self.user_agent.clone(),
449+
}
450+
}
451+
}
452+
453+
impl Deref for Requester {
454+
type Target = RequestingEntity;
455+
456+
fn deref(&self) -> &Self::Target {
457+
&self.entity
458+
}
459+
}
460+
420461
/// The identity of the requester.
421462
#[derive(Debug, Clone, Default, PartialEq, Eq)]
422-
pub enum Requester {
463+
pub enum RequestingEntity {
423464
/// The requester presented no authentication information.
424465
#[default]
425466
Anonymous,
@@ -480,7 +521,7 @@ impl OwnerId for UserId {
480521
}
481522
}
482523

483-
impl Requester {
524+
impl RequestingEntity {
484525
fn browser_session(&self) -> Option<&BrowserSession> {
485526
match self {
486527
Self::BrowserSession(session) => Some(session),
@@ -532,17 +573,21 @@ impl Requester {
532573
Self::BrowserSession(_) | Self::Anonymous => false,
533574
}
534575
}
576+
577+
fn is_unauthenticated(&self) -> bool {
578+
matches!(self, Self::Anonymous)
579+
}
535580
}
536581

537-
impl From<BrowserSession> for Requester {
582+
impl From<BrowserSession> for RequestingEntity {
538583
fn from(session: BrowserSession) -> Self {
539584
Self::BrowserSession(Box::new(session))
540585
}
541586
}
542587

543-
impl<T> From<Option<T>> for Requester
588+
impl<T> From<Option<T>> for RequestingEntity
544589
where
545-
T: Into<Requester>,
590+
T: Into<RequestingEntity>,
546591
{
547592
fn from(session: Option<T>) -> Self {
548593
session.map(Into::into).unwrap_or_default()

crates/handlers/src/graphql/mutations/user.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use zeroize::Zeroizing;
2121
use crate::graphql::{
2222
model::{NodeType, User},
2323
state::ContextExt,
24-
Requester, UserId,
24+
UserId,
2525
};
2626

2727
#[derive(Default)]
@@ -728,7 +728,7 @@ impl UserMutations {
728728
let state = ctx.state();
729729
let requester = ctx.requester();
730730
let clock = state.clock();
731-
if !matches!(requester, Requester::Anonymous) {
731+
if !requester.is_unauthenticated() {
732732
return Err(async_graphql::Error::new(
733733
"Account recovery is only for anonymous users.",
734734
));
@@ -830,7 +830,7 @@ impl UserMutations {
830830
input: ResendRecoveryEmailInput,
831831
) -> Result<ResendRecoveryEmailPayload, async_graphql::Error> {
832832
let state = ctx.state();
833-
let requester_fingerprint = ctx.requester_fingerprint();
833+
let requester = ctx.requester();
834834
let clock = state.clock();
835835
let mut rng = state.rng();
836836
let limiter = state.limiter();
@@ -847,7 +847,7 @@ impl UserMutations {
847847
.context("Could not load recovery session")?;
848848

849849
if let Err(e) =
850-
limiter.check_account_recovery(requester_fingerprint, &recovery_session.email)
850+
limiter.check_account_recovery(requester.fingerprint(), &recovery_session.email)
851851
{
852852
tracing::warn!(error = &e as &dyn std::error::Error);
853853
return Ok(ResendRecoveryEmailPayload::RateLimited);

crates/handlers/src/graphql/mutations/user_email.rs

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,12 @@ impl UserEmailMutations {
424424

425425
if !skip_policy_check {
426426
let mut policy = state.policy().await?;
427-
let res = policy.evaluate_email(&input.email).await?;
427+
let res = policy
428+
.evaluate_email(mas_policy::EmailInput {
429+
email: &input.email,
430+
requester: requester.for_policy(),
431+
})
432+
.await?;
428433
if !res.valid() {
429434
return Ok(AddEmailPayload::Denied {
430435
violations: res.violations,
@@ -584,7 +589,7 @@ impl UserEmailMutations {
584589
}
585590

586591
if let Err(e) =
587-
limiter.check_email_authentication_email(ctx.requester_fingerprint(), &input.email)
592+
limiter.check_email_authentication_email(requester.fingerprint(), &input.email)
588593
{
589594
tracing::warn!(error = &e as &dyn std::error::Error);
590595
return Ok(StartEmailAuthenticationPayload::RateLimited);
@@ -610,7 +615,12 @@ impl UserEmailMutations {
610615

611616
// Check if the email address is allowed by the policy
612617
let mut policy = state.policy().await?;
613-
let res = policy.evaluate_email(&input.email).await?;
618+
let res = policy
619+
.evaluate_email(mas_policy::EmailInput {
620+
email: &input.email,
621+
requester: requester.for_policy(),
622+
})
623+
.await?;
614624
if !res.valid() {
615625
return Ok(StartEmailAuthenticationPayload::Denied {
616626
violations: res.violations,
@@ -648,9 +658,10 @@ impl UserEmailMutations {
648658
let mut rng = state.rng();
649659
let clock = state.clock();
650660
let limiter = state.limiter();
661+
let requester = ctx.requester();
651662

652663
let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?;
653-
let Some(browser_session) = ctx.requester().browser_session() else {
664+
let Some(browser_session) = requester.browser_session() else {
654665
return Err(async_graphql::Error::new("Unauthorized"));
655666
};
656667

@@ -680,8 +691,8 @@ impl UserEmailMutations {
680691
return Ok(ResendEmailAuthenticationCodePayload::Completed);
681692
}
682693

683-
if let Err(e) = limiter
684-
.check_email_authentication_send_code(ctx.requester_fingerprint(), &authentication)
694+
if let Err(e) =
695+
limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication)
685696
{
686697
tracing::warn!(error = &e as &dyn std::error::Error);
687698
return Ok(ResendEmailAuthenticationCodePayload::RateLimited);

crates/handlers/src/graphql/query/viewer.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use async_graphql::{Context, Object};
99
use crate::graphql::{
1010
model::{Viewer, ViewerSession},
1111
state::ContextExt,
12-
Requester,
1312
};
1413

1514
#[derive(Default)]
@@ -21,24 +20,25 @@ impl ViewerQuery {
2120
async fn viewer(&self, ctx: &Context<'_>) -> Viewer {
2221
let requester = ctx.requester();
2322

24-
match requester {
25-
Requester::BrowserSession(session) => Viewer::user(session.user.clone()),
26-
Requester::OAuth2Session(tuple) => match &tuple.1 {
27-
Some(user) => Viewer::user(user.clone()),
28-
None => Viewer::anonymous(),
29-
},
30-
Requester::Anonymous => Viewer::anonymous(),
23+
if let Some(user) = requester.user() {
24+
return Viewer::user(user.clone());
3125
}
26+
27+
Viewer::anonymous()
3228
}
3329

3430
/// Get the viewer's session
3531
async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession {
3632
let requester = ctx.requester();
3733

38-
match requester {
39-
Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()),
40-
Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()),
41-
Requester::Anonymous => ViewerSession::anonymous(),
34+
if let Some(session) = requester.browser_session() {
35+
return ViewerSession::browser_session(session.clone());
36+
}
37+
38+
if let Some(session) = requester.oauth2_session() {
39+
return ViewerSession::oauth2_session(session.clone());
4240
}
41+
42+
ViewerSession::anonymous()
4343
}
4444
}

crates/handlers/src/graphql/state.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use mas_policy::Policy;
1010
use mas_router::UrlBuilder;
1111
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
1212

13-
use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint};
13+
use crate::{graphql::Requester, passwords::PasswordManager, Limiter};
1414

1515
#[async_trait::async_trait]
1616
pub trait State {
@@ -31,8 +31,6 @@ pub trait ContextExt {
3131
fn state(&self) -> &BoxState;
3232

3333
fn requester(&self) -> &Requester;
34-
35-
fn requester_fingerprint(&self) -> RequesterFingerprint;
3634
}
3735

3836
impl ContextExt for async_graphql::Context<'_> {
@@ -43,8 +41,4 @@ impl ContextExt for async_graphql::Context<'_> {
4341
fn requester(&self) -> &Requester {
4442
self.data_unchecked()
4543
}
46-
47-
fn requester_fingerprint(&self) -> RequesterFingerprint {
48-
*self.data_unchecked()
49-
}
5044
}

0 commit comments

Comments
 (0)