Skip to content

Commit 38fa52a

Browse files
committed
Merge the GraphQL requester and requester fingerprint into a single struct
1 parent aa6436a commit 38fa52a

File tree

5 files changed

+68
-46
lines changed

5 files changed

+68
-46
lines changed

crates/handlers/src/graphql/mod.rs

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#![allow(clippy::module_name_repetitions)]
88

9-
use std::sync::Arc;
9+
use std::{net::IpAddr, ops::Deref, sync::Arc};
1010

1111
use async_graphql::{
1212
extensions::Tracing,
@@ -240,7 +240,7 @@ async fn get_requester(
240240
session_info: SessionInfo,
241241
token: Option<&str>,
242242
) -> Result<Requester, RouteError> {
243-
let requester = if let Some(token) = token {
243+
let entity = if let Some(token) = token {
244244
// If we haven't enabled undocumented_oauth2_access on the listener, we bail out
245245
if !undocumented_oauth2_access {
246246
return Err(RouteError::InvalidToken);
@@ -285,7 +285,7 @@ async fn get_requester(
285285
return Err(RouteError::MissingScope);
286286
}
287287

288-
Requester::OAuth2Session(Box::new((session, user)))
288+
RequestingEntity::OAuth2Session(Box::new((session, user)))
289289
} else {
290290
let maybe_session = session_info.load_session(&mut repo).await?;
291291

@@ -295,8 +295,14 @@ async fn get_requester(
295295
.await;
296296
}
297297

298-
Requester::from(maybe_session)
298+
RequestingEntity::from(maybe_session)
299299
};
300+
301+
let requester = Requester {
302+
entity,
303+
ip_address: activity_tracker.ip(),
304+
};
305+
300306
repo.cancel().await?;
301307
Ok(requester)
302308
}
@@ -312,7 +318,6 @@ pub async fn post(
312318
cookie_jar: CookieJar,
313319
content_type: Option<TypedHeader<ContentType>>,
314320
authorization: Option<TypedHeader<Authorization<Bearer>>>,
315-
requester_fingerprint: RequesterFingerprint,
316321
body: Body,
317322
) -> Result<impl IntoResponse, RouteError> {
318323
let body = body.into_data_stream();
@@ -339,7 +344,6 @@ pub async fn post(
339344
MultipartOptions::default(),
340345
)
341346
.await?
342-
.data(requester_fingerprint)
343347
.data(requester); // XXX: this should probably return another error response?
344348

345349
let span = span_for_graphql_request(&request);
@@ -366,7 +370,6 @@ pub async fn get(
366370
activity_tracker: BoundActivityTracker,
367371
cookie_jar: CookieJar,
368372
authorization: Option<TypedHeader<Authorization<Bearer>>>,
369-
requester_fingerprint: RequesterFingerprint,
370373
RawQuery(query): RawQuery,
371374
) -> Result<impl IntoResponse, FancyError> {
372375
let token = authorization
@@ -383,9 +386,8 @@ pub async fn get(
383386
)
384387
.await?;
385388

386-
let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
387-
.data(requester)
388-
.data(requester_fingerprint);
389+
let request =
390+
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
389391

390392
let span = span_for_graphql_request(&request);
391393
let response = schema.execute(request).instrument(span).await;
@@ -417,9 +419,32 @@ pub fn schema_builder() -> SchemaBuilder {
417419
.register_output_type::<CreationEvent>()
418420
}
419421

422+
pub struct Requester {
423+
entity: RequestingEntity,
424+
ip_address: Option<IpAddr>,
425+
}
426+
427+
impl Requester {
428+
pub fn fingerprint(&self) -> RequesterFingerprint {
429+
if let Some(ip) = self.ip_address {
430+
RequesterFingerprint::new(ip)
431+
} else {
432+
RequesterFingerprint::EMPTY
433+
}
434+
}
435+
}
436+
437+
impl Deref for Requester {
438+
type Target = RequestingEntity;
439+
440+
fn deref(&self) -> &Self::Target {
441+
&self.entity
442+
}
443+
}
444+
420445
/// The identity of the requester.
421446
#[derive(Debug, Clone, Default, PartialEq, Eq)]
422-
pub enum Requester {
447+
pub enum RequestingEntity {
423448
/// The requester presented no authentication information.
424449
#[default]
425450
Anonymous,
@@ -480,7 +505,7 @@ impl OwnerId for UserId {
480505
}
481506
}
482507

483-
impl Requester {
508+
impl RequestingEntity {
484509
fn browser_session(&self) -> Option<&BrowserSession> {
485510
match self {
486511
Self::BrowserSession(session) => Some(session),
@@ -532,17 +557,21 @@ impl Requester {
532557
Self::BrowserSession(_) | Self::Anonymous => false,
533558
}
534559
}
560+
561+
fn is_unauthenticated(&self) -> bool {
562+
matches!(self, Self::Anonymous)
563+
}
535564
}
536565

537-
impl From<BrowserSession> for Requester {
566+
impl From<BrowserSession> for RequestingEntity {
538567
fn from(session: BrowserSession) -> Self {
539568
Self::BrowserSession(Box::new(session))
540569
}
541570
}
542571

543-
impl<T> From<Option<T>> for Requester
572+
impl<T> From<Option<T>> for RequestingEntity
544573
where
545-
T: Into<Requester>,
574+
T: Into<RequestingEntity>,
546575
{
547576
fn from(session: Option<T>) -> Self {
548577
session.map(Into::into).unwrap_or_default()

crates/handlers/src/graphql/mutations/user.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use zeroize::Zeroizing;
2121
use crate::graphql::{
2222
model::{NodeType, User},
2323
state::ContextExt,
24-
Requester, UserId,
24+
UserId,
2525
};
2626

2727
#[derive(Default)]
@@ -728,7 +728,7 @@ impl UserMutations {
728728
let state = ctx.state();
729729
let requester = ctx.requester();
730730
let clock = state.clock();
731-
if !matches!(requester, Requester::Anonymous) {
731+
if !requester.is_unauthenticated() {
732732
return Err(async_graphql::Error::new(
733733
"Account recovery is only for anonymous users.",
734734
));
@@ -830,7 +830,7 @@ impl UserMutations {
830830
input: ResendRecoveryEmailInput,
831831
) -> Result<ResendRecoveryEmailPayload, async_graphql::Error> {
832832
let state = ctx.state();
833-
let requester_fingerprint = ctx.requester_fingerprint();
833+
let requester = ctx.requester();
834834
let clock = state.clock();
835835
let mut rng = state.rng();
836836
let limiter = state.limiter();
@@ -847,7 +847,7 @@ impl UserMutations {
847847
.context("Could not load recovery session")?;
848848

849849
if let Err(e) =
850-
limiter.check_account_recovery(requester_fingerprint, &recovery_session.email)
850+
limiter.check_account_recovery(requester.fingerprint(), &recovery_session.email)
851851
{
852852
tracing::warn!(error = &e as &dyn std::error::Error);
853853
return Ok(ResendRecoveryEmailPayload::RateLimited);

crates/handlers/src/graphql/mutations/user_email.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,6 @@ impl UserEmailMutations {
398398
let state = ctx.state();
399399
let id = NodeType::User.extract_ulid(&input.user_id)?;
400400
let requester = ctx.requester();
401-
let requester_fingerprint = ctx.requester_fingerprint();
402401
let clock = state.clock();
403402
let mut rng = state.rng();
404403

@@ -428,7 +427,7 @@ impl UserEmailMutations {
428427
let res = policy
429428
.evaluate_email(mas_policy::EmailInput {
430429
email: &input.email,
431-
requester: requester_fingerprint.into(),
430+
requester: requester.fingerprint().into(),
432431
})
433432
.await?;
434433
if !res.valid() {
@@ -561,7 +560,6 @@ impl UserEmailMutations {
561560
let mut rng = state.rng();
562561
let clock = state.clock();
563562
let requester = ctx.requester();
564-
let requester_fingerprint = ctx.requester_fingerprint();
565563
let limiter = state.limiter();
566564

567565
// Only allow calling this if the requester is a browser session
@@ -591,7 +589,7 @@ impl UserEmailMutations {
591589
}
592590

593591
if let Err(e) =
594-
limiter.check_email_authentication_email(ctx.requester_fingerprint(), &input.email)
592+
limiter.check_email_authentication_email(requester.fingerprint(), &input.email)
595593
{
596594
tracing::warn!(error = &e as &dyn std::error::Error);
597595
return Ok(StartEmailAuthenticationPayload::RateLimited);
@@ -620,7 +618,7 @@ impl UserEmailMutations {
620618
let res = policy
621619
.evaluate_email(mas_policy::EmailInput {
622620
email: &input.email,
623-
requester: requester_fingerprint.into(),
621+
requester: requester.fingerprint().into(),
624622
})
625623
.await?;
626624
if !res.valid() {
@@ -660,9 +658,10 @@ impl UserEmailMutations {
660658
let mut rng = state.rng();
661659
let clock = state.clock();
662660
let limiter = state.limiter();
661+
let requester = ctx.requester();
663662

664663
let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?;
665-
let Some(browser_session) = ctx.requester().browser_session() else {
664+
let Some(browser_session) = requester.browser_session() else {
666665
return Err(async_graphql::Error::new("Unauthorized"));
667666
};
668667

@@ -692,8 +691,8 @@ impl UserEmailMutations {
692691
return Ok(ResendEmailAuthenticationCodePayload::Completed);
693692
}
694693

695-
if let Err(e) = limiter
696-
.check_email_authentication_send_code(ctx.requester_fingerprint(), &authentication)
694+
if let Err(e) =
695+
limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication)
697696
{
698697
tracing::warn!(error = &e as &dyn std::error::Error);
699698
return Ok(ResendEmailAuthenticationCodePayload::RateLimited);

crates/handlers/src/graphql/query/viewer.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use async_graphql::{Context, Object};
99
use crate::graphql::{
1010
model::{Viewer, ViewerSession},
1111
state::ContextExt,
12-
Requester,
1312
};
1413

1514
#[derive(Default)]
@@ -21,24 +20,25 @@ impl ViewerQuery {
2120
async fn viewer(&self, ctx: &Context<'_>) -> Viewer {
2221
let requester = ctx.requester();
2322

24-
match requester {
25-
Requester::BrowserSession(session) => Viewer::user(session.user.clone()),
26-
Requester::OAuth2Session(tuple) => match &tuple.1 {
27-
Some(user) => Viewer::user(user.clone()),
28-
None => Viewer::anonymous(),
29-
},
30-
Requester::Anonymous => Viewer::anonymous(),
23+
if let Some(user) = requester.user() {
24+
return Viewer::user(user.clone());
3125
}
26+
27+
Viewer::anonymous()
3228
}
3329

3430
/// Get the viewer's session
3531
async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession {
3632
let requester = ctx.requester();
3733

38-
match requester {
39-
Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()),
40-
Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()),
41-
Requester::Anonymous => ViewerSession::anonymous(),
34+
if let Some(session) = requester.browser_session() {
35+
return ViewerSession::browser_session(session.clone());
36+
}
37+
38+
if let Some(session) = requester.oauth2_session() {
39+
return ViewerSession::oauth2_session(session.clone());
4240
}
41+
42+
ViewerSession::anonymous()
4343
}
4444
}

crates/handlers/src/graphql/state.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use mas_policy::Policy;
1010
use mas_router::UrlBuilder;
1111
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
1212

13-
use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint};
13+
use crate::{graphql::Requester, passwords::PasswordManager, Limiter};
1414

1515
#[async_trait::async_trait]
1616
pub trait State {
@@ -31,8 +31,6 @@ pub trait ContextExt {
3131
fn state(&self) -> &BoxState;
3232

3333
fn requester(&self) -> &Requester;
34-
35-
fn requester_fingerprint(&self) -> RequesterFingerprint;
3634
}
3735

3836
impl ContextExt for async_graphql::Context<'_> {
@@ -43,8 +41,4 @@ impl ContextExt for async_graphql::Context<'_> {
4341
fn requester(&self) -> &Requester {
4442
self.data_unchecked()
4543
}
46-
47-
fn requester_fingerprint(&self) -> RequesterFingerprint {
48-
*self.data_unchecked()
49-
}
5044
}

0 commit comments

Comments
 (0)