Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/build-policies/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ runs:
- name: Install Open Policy Agent
uses: open-policy-agent/[email protected]
with:
version: 0.70.0
version: 1.1.0

- name: Build the policies
run: make
Expand Down
75 changes: 60 additions & 15 deletions crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

#![allow(clippy::module_name_repetitions)]

use std::sync::Arc;
use std::{net::IpAddr, ops::Deref, sync::Arc};

use async_graphql::{
extensions::Tracing,
Expand Down Expand Up @@ -238,9 +238,10 @@ async fn get_requester(
activity_tracker: &BoundActivityTracker,
mut repo: BoxRepository,
session_info: SessionInfo,
user_agent: Option<String>,
token: Option<&str>,
) -> Result<Requester, RouteError> {
let requester = if let Some(token) = token {
let entity = if let Some(token) = token {
// If we haven't enabled undocumented_oauth2_access on the listener, we bail out
if !undocumented_oauth2_access {
return Err(RouteError::InvalidToken);
Expand Down Expand Up @@ -285,7 +286,7 @@ async fn get_requester(
return Err(RouteError::MissingScope);
}

Requester::OAuth2Session(Box::new((session, user)))
RequestingEntity::OAuth2Session(Box::new((session, user)))
} else {
let maybe_session = session_info.load_session(&mut repo).await?;

Expand All @@ -295,8 +296,15 @@ async fn get_requester(
.await;
}

Requester::from(maybe_session)
RequestingEntity::from(maybe_session)
};

let requester = Requester {
entity,
ip_address: activity_tracker.ip(),
user_agent,
};

repo.cancel().await?;
Ok(requester)
}
Expand All @@ -312,20 +320,22 @@ pub async fn post(
cookie_jar: CookieJar,
content_type: Option<TypedHeader<ContentType>>,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
requester_fingerprint: RequesterFingerprint,
user_agent: Option<TypedHeader<headers::UserAgent>>,
body: Body,
) -> Result<impl IntoResponse, RouteError> {
let body = body.into_data_stream();
let token = authorization
.as_ref()
.map(|TypedHeader(Authorization(bearer))| bearer.token());
let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let requester = get_requester(
undocumented_oauth2_access,
&clock,
&activity_tracker,
repo,
session_info,
user_agent,
token,
)
.await?;
Expand All @@ -339,7 +349,6 @@ pub async fn post(
MultipartOptions::default(),
)
.await?
.data(requester_fingerprint)
.data(requester); // XXX: this should probably return another error response?

let span = span_for_graphql_request(&request);
Expand All @@ -366,26 +375,27 @@ pub async fn get(
activity_tracker: BoundActivityTracker,
cookie_jar: CookieJar,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
requester_fingerprint: RequesterFingerprint,
user_agent: Option<TypedHeader<headers::UserAgent>>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let token = authorization
.as_ref()
.map(|TypedHeader(Authorization(bearer))| bearer.token());
let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let requester = get_requester(
undocumented_oauth2_access,
&clock,
&activity_tracker,
repo,
session_info,
user_agent,
token,
)
.await?;

let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
.data(requester)
.data(requester_fingerprint);
let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);

let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
Expand Down Expand Up @@ -417,9 +427,40 @@ pub fn schema_builder() -> SchemaBuilder {
.register_output_type::<CreationEvent>()
}

pub struct Requester {
entity: RequestingEntity,
ip_address: Option<IpAddr>,
user_agent: Option<String>,
}

impl Requester {
pub fn fingerprint(&self) -> RequesterFingerprint {
if let Some(ip) = self.ip_address {
RequesterFingerprint::new(ip)
} else {
RequesterFingerprint::EMPTY
}
}

pub fn for_policy(&self) -> mas_policy::Requester {
mas_policy::Requester {
ip_address: self.ip_address,
user_agent: self.user_agent.clone(),
}
}
}

impl Deref for Requester {
type Target = RequestingEntity;

fn deref(&self) -> &Self::Target {
&self.entity
}
}

/// The identity of the requester.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Requester {
pub enum RequestingEntity {
/// The requester presented no authentication information.
#[default]
Anonymous,
Expand Down Expand Up @@ -480,7 +521,7 @@ impl OwnerId for UserId {
}
}

impl Requester {
impl RequestingEntity {
fn browser_session(&self) -> Option<&BrowserSession> {
match self {
Self::BrowserSession(session) => Some(session),
Expand Down Expand Up @@ -532,17 +573,21 @@ impl Requester {
Self::BrowserSession(_) | Self::Anonymous => false,
}
}

fn is_unauthenticated(&self) -> bool {
matches!(self, Self::Anonymous)
}
}

impl From<BrowserSession> for Requester {
impl From<BrowserSession> for RequestingEntity {
fn from(session: BrowserSession) -> Self {
Self::BrowserSession(Box::new(session))
}
}

impl<T> From<Option<T>> for Requester
impl<T> From<Option<T>> for RequestingEntity
where
T: Into<Requester>,
T: Into<RequestingEntity>,
{
fn from(session: Option<T>) -> Self {
session.map(Into::into).unwrap_or_default()
Expand Down
8 changes: 4 additions & 4 deletions crates/handlers/src/graphql/mutations/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use zeroize::Zeroizing;
use crate::graphql::{
model::{NodeType, User},
state::ContextExt,
Requester, UserId,
UserId,
};

#[derive(Default)]
Expand Down Expand Up @@ -728,7 +728,7 @@ impl UserMutations {
let state = ctx.state();
let requester = ctx.requester();
let clock = state.clock();
if !matches!(requester, Requester::Anonymous) {
if !requester.is_unauthenticated() {
return Err(async_graphql::Error::new(
"Account recovery is only for anonymous users.",
));
Expand Down Expand Up @@ -830,7 +830,7 @@ impl UserMutations {
input: ResendRecoveryEmailInput,
) -> Result<ResendRecoveryEmailPayload, async_graphql::Error> {
let state = ctx.state();
let requester_fingerprint = ctx.requester_fingerprint();
let requester = ctx.requester();
let clock = state.clock();
let mut rng = state.rng();
let limiter = state.limiter();
Expand All @@ -847,7 +847,7 @@ impl UserMutations {
.context("Could not load recovery session")?;

if let Err(e) =
limiter.check_account_recovery(requester_fingerprint, &recovery_session.email)
limiter.check_account_recovery(requester.fingerprint(), &recovery_session.email)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(ResendRecoveryEmailPayload::RateLimited);
Expand Down
23 changes: 17 additions & 6 deletions crates/handlers/src/graphql/mutations/user_email.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,12 @@ impl UserEmailMutations {

if !skip_policy_check {
let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester.for_policy(),
})
.await?;
if !res.valid() {
return Ok(AddEmailPayload::Denied {
violations: res.violations,
Expand Down Expand Up @@ -584,7 +589,7 @@ impl UserEmailMutations {
}

if let Err(e) =
limiter.check_email_authentication_email(ctx.requester_fingerprint(), &input.email)
limiter.check_email_authentication_email(requester.fingerprint(), &input.email)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(StartEmailAuthenticationPayload::RateLimited);
Expand All @@ -610,7 +615,12 @@ impl UserEmailMutations {

// Check if the email address is allowed by the policy
let mut policy = state.policy().await?;
let res = policy.evaluate_email(&input.email).await?;
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester.for_policy(),
})
.await?;
if !res.valid() {
return Ok(StartEmailAuthenticationPayload::Denied {
violations: res.violations,
Expand Down Expand Up @@ -648,9 +658,10 @@ impl UserEmailMutations {
let mut rng = state.rng();
let clock = state.clock();
let limiter = state.limiter();
let requester = ctx.requester();

let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?;
let Some(browser_session) = ctx.requester().browser_session() else {
let Some(browser_session) = requester.browser_session() else {
return Err(async_graphql::Error::new("Unauthorized"));
};

Expand Down Expand Up @@ -680,8 +691,8 @@ impl UserEmailMutations {
return Ok(ResendEmailAuthenticationCodePayload::Completed);
}

if let Err(e) = limiter
.check_email_authentication_send_code(ctx.requester_fingerprint(), &authentication)
if let Err(e) =
limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(ResendEmailAuthenticationCodePayload::RateLimited);
Expand Down
24 changes: 12 additions & 12 deletions crates/handlers/src/graphql/query/viewer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use async_graphql::{Context, Object};
use crate::graphql::{
model::{Viewer, ViewerSession},
state::ContextExt,
Requester,
};

#[derive(Default)]
Expand All @@ -21,24 +20,25 @@ impl ViewerQuery {
async fn viewer(&self, ctx: &Context<'_>) -> Viewer {
let requester = ctx.requester();

match requester {
Requester::BrowserSession(session) => Viewer::user(session.user.clone()),
Requester::OAuth2Session(tuple) => match &tuple.1 {
Some(user) => Viewer::user(user.clone()),
None => Viewer::anonymous(),
},
Requester::Anonymous => Viewer::anonymous(),
if let Some(user) = requester.user() {
return Viewer::user(user.clone());
}

Viewer::anonymous()
}

/// Get the viewer's session
async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession {
let requester = ctx.requester();

match requester {
Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()),
Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()),
Requester::Anonymous => ViewerSession::anonymous(),
if let Some(session) = requester.browser_session() {
return ViewerSession::browser_session(session.clone());
}

if let Some(session) = requester.oauth2_session() {
return ViewerSession::oauth2_session(session.clone());
}

ViewerSession::anonymous()
}
}
8 changes: 1 addition & 7 deletions crates/handlers/src/graphql/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};

use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint};
use crate::{graphql::Requester, passwords::PasswordManager, Limiter};

#[async_trait::async_trait]
pub trait State {
Expand All @@ -31,8 +31,6 @@ pub trait ContextExt {
fn state(&self) -> &BoxState;

fn requester(&self) -> &Requester;

fn requester_fingerprint(&self) -> RequesterFingerprint;
}

impl ContextExt for async_graphql::Context<'_> {
Expand All @@ -43,8 +41,4 @@ impl ContextExt for async_graphql::Context<'_> {
fn requester(&self) -> &Requester {
self.data_unchecked()
}

fn requester_fingerprint(&self) -> RequesterFingerprint {
*self.data_unchecked()
}
}
Loading
Loading