diff --git a/.github/actions/build-policies/action.yml b/.github/actions/build-policies/action.yml index dfe78917d..274aa8134 100644 --- a/.github/actions/build-policies/action.yml +++ b/.github/actions/build-policies/action.yml @@ -7,7 +7,7 @@ runs: - name: Install Open Policy Agent uses: open-policy-agent/setup-opa@v2.2.0 with: - version: 0.70.0 + version: 1.1.0 - name: Build the policies run: make diff --git a/crates/handlers/src/graphql/mod.rs b/crates/handlers/src/graphql/mod.rs index 3fbf30166..0fc5209c2 100644 --- a/crates/handlers/src/graphql/mod.rs +++ b/crates/handlers/src/graphql/mod.rs @@ -6,7 +6,7 @@ #![allow(clippy::module_name_repetitions)] -use std::sync::Arc; +use std::{net::IpAddr, ops::Deref, sync::Arc}; use async_graphql::{ extensions::Tracing, @@ -238,9 +238,10 @@ async fn get_requester( activity_tracker: &BoundActivityTracker, mut repo: BoxRepository, session_info: SessionInfo, + user_agent: Option, token: Option<&str>, ) -> Result { - let requester = if let Some(token) = token { + let entity = if let Some(token) = token { // If we haven't enabled undocumented_oauth2_access on the listener, we bail out if !undocumented_oauth2_access { return Err(RouteError::InvalidToken); @@ -285,7 +286,7 @@ async fn get_requester( return Err(RouteError::MissingScope); } - Requester::OAuth2Session(Box::new((session, user))) + RequestingEntity::OAuth2Session(Box::new((session, user))) } else { let maybe_session = session_info.load_session(&mut repo).await?; @@ -295,8 +296,15 @@ async fn get_requester( .await; } - Requester::from(maybe_session) + RequestingEntity::from(maybe_session) }; + + let requester = Requester { + entity, + ip_address: activity_tracker.ip(), + user_agent, + }; + repo.cancel().await?; Ok(requester) } @@ -312,13 +320,14 @@ pub async fn post( cookie_jar: CookieJar, content_type: Option>, authorization: Option>>, - requester_fingerprint: RequesterFingerprint, + user_agent: Option>, body: Body, ) -> Result { let body = body.into_data_stream(); let token = authorization .as_ref() .map(|TypedHeader(Authorization(bearer))| bearer.token()); + let user_agent = user_agent.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); let requester = get_requester( undocumented_oauth2_access, @@ -326,6 +335,7 @@ pub async fn post( &activity_tracker, repo, session_info, + user_agent, token, ) .await?; @@ -339,7 +349,6 @@ pub async fn post( MultipartOptions::default(), ) .await? - .data(requester_fingerprint) .data(requester); // XXX: this should probably return another error response? let span = span_for_graphql_request(&request); @@ -366,12 +375,13 @@ pub async fn get( activity_tracker: BoundActivityTracker, cookie_jar: CookieJar, authorization: Option>>, - requester_fingerprint: RequesterFingerprint, + user_agent: Option>, RawQuery(query): RawQuery, ) -> Result { let token = authorization .as_ref() .map(|TypedHeader(Authorization(bearer))| bearer.token()); + let user_agent = user_agent.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); let requester = get_requester( undocumented_oauth2_access, @@ -379,13 +389,13 @@ pub async fn get( &activity_tracker, repo, session_info, + user_agent, token, ) .await?; - let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())? - .data(requester) - .data(requester_fingerprint); + let request = + async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester); let span = span_for_graphql_request(&request); let response = schema.execute(request).instrument(span).await; @@ -417,9 +427,40 @@ pub fn schema_builder() -> SchemaBuilder { .register_output_type::() } +pub struct Requester { + entity: RequestingEntity, + ip_address: Option, + user_agent: Option, +} + +impl Requester { + pub fn fingerprint(&self) -> RequesterFingerprint { + if let Some(ip) = self.ip_address { + RequesterFingerprint::new(ip) + } else { + RequesterFingerprint::EMPTY + } + } + + pub fn for_policy(&self) -> mas_policy::Requester { + mas_policy::Requester { + ip_address: self.ip_address, + user_agent: self.user_agent.clone(), + } + } +} + +impl Deref for Requester { + type Target = RequestingEntity; + + fn deref(&self) -> &Self::Target { + &self.entity + } +} + /// The identity of the requester. #[derive(Debug, Clone, Default, PartialEq, Eq)] -pub enum Requester { +pub enum RequestingEntity { /// The requester presented no authentication information. #[default] Anonymous, @@ -480,7 +521,7 @@ impl OwnerId for UserId { } } -impl Requester { +impl RequestingEntity { fn browser_session(&self) -> Option<&BrowserSession> { match self { Self::BrowserSession(session) => Some(session), @@ -532,17 +573,21 @@ impl Requester { Self::BrowserSession(_) | Self::Anonymous => false, } } + + fn is_unauthenticated(&self) -> bool { + matches!(self, Self::Anonymous) + } } -impl From for Requester { +impl From for RequestingEntity { fn from(session: BrowserSession) -> Self { Self::BrowserSession(Box::new(session)) } } -impl From> for Requester +impl From> for RequestingEntity where - T: Into, + T: Into, { fn from(session: Option) -> Self { session.map(Into::into).unwrap_or_default() diff --git a/crates/handlers/src/graphql/mutations/user.rs b/crates/handlers/src/graphql/mutations/user.rs index 52c661b05..311a1dcb7 100644 --- a/crates/handlers/src/graphql/mutations/user.rs +++ b/crates/handlers/src/graphql/mutations/user.rs @@ -21,7 +21,7 @@ use zeroize::Zeroizing; use crate::graphql::{ model::{NodeType, User}, state::ContextExt, - Requester, UserId, + UserId, }; #[derive(Default)] @@ -728,7 +728,7 @@ impl UserMutations { let state = ctx.state(); let requester = ctx.requester(); let clock = state.clock(); - if !matches!(requester, Requester::Anonymous) { + if !requester.is_unauthenticated() { return Err(async_graphql::Error::new( "Account recovery is only for anonymous users.", )); @@ -830,7 +830,7 @@ impl UserMutations { input: ResendRecoveryEmailInput, ) -> Result { let state = ctx.state(); - let requester_fingerprint = ctx.requester_fingerprint(); + let requester = ctx.requester(); let clock = state.clock(); let mut rng = state.rng(); let limiter = state.limiter(); @@ -847,7 +847,7 @@ impl UserMutations { .context("Could not load recovery session")?; if let Err(e) = - limiter.check_account_recovery(requester_fingerprint, &recovery_session.email) + limiter.check_account_recovery(requester.fingerprint(), &recovery_session.email) { tracing::warn!(error = &e as &dyn std::error::Error); return Ok(ResendRecoveryEmailPayload::RateLimited); diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 38048d39a..ba7aef776 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ b/crates/handlers/src/graphql/mutations/user_email.rs @@ -424,7 +424,12 @@ impl UserEmailMutations { if !skip_policy_check { let mut policy = state.policy().await?; - let res = policy.evaluate_email(&input.email).await?; + let res = policy + .evaluate_email(mas_policy::EmailInput { + email: &input.email, + requester: requester.for_policy(), + }) + .await?; if !res.valid() { return Ok(AddEmailPayload::Denied { violations: res.violations, @@ -584,7 +589,7 @@ impl UserEmailMutations { } if let Err(e) = - limiter.check_email_authentication_email(ctx.requester_fingerprint(), &input.email) + limiter.check_email_authentication_email(requester.fingerprint(), &input.email) { tracing::warn!(error = &e as &dyn std::error::Error); return Ok(StartEmailAuthenticationPayload::RateLimited); @@ -610,7 +615,12 @@ impl UserEmailMutations { // Check if the email address is allowed by the policy let mut policy = state.policy().await?; - let res = policy.evaluate_email(&input.email).await?; + let res = policy + .evaluate_email(mas_policy::EmailInput { + email: &input.email, + requester: requester.for_policy(), + }) + .await?; if !res.valid() { return Ok(StartEmailAuthenticationPayload::Denied { violations: res.violations, @@ -648,9 +658,10 @@ impl UserEmailMutations { let mut rng = state.rng(); let clock = state.clock(); let limiter = state.limiter(); + let requester = ctx.requester(); let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?; - let Some(browser_session) = ctx.requester().browser_session() else { + let Some(browser_session) = requester.browser_session() else { return Err(async_graphql::Error::new("Unauthorized")); }; @@ -680,8 +691,8 @@ impl UserEmailMutations { return Ok(ResendEmailAuthenticationCodePayload::Completed); } - if let Err(e) = limiter - .check_email_authentication_send_code(ctx.requester_fingerprint(), &authentication) + if let Err(e) = + limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication) { tracing::warn!(error = &e as &dyn std::error::Error); return Ok(ResendEmailAuthenticationCodePayload::RateLimited); diff --git a/crates/handlers/src/graphql/query/viewer.rs b/crates/handlers/src/graphql/query/viewer.rs index 60f884357..6985dfd2e 100644 --- a/crates/handlers/src/graphql/query/viewer.rs +++ b/crates/handlers/src/graphql/query/viewer.rs @@ -9,7 +9,6 @@ use async_graphql::{Context, Object}; use crate::graphql::{ model::{Viewer, ViewerSession}, state::ContextExt, - Requester, }; #[derive(Default)] @@ -21,24 +20,25 @@ impl ViewerQuery { async fn viewer(&self, ctx: &Context<'_>) -> Viewer { let requester = ctx.requester(); - match requester { - Requester::BrowserSession(session) => Viewer::user(session.user.clone()), - Requester::OAuth2Session(tuple) => match &tuple.1 { - Some(user) => Viewer::user(user.clone()), - None => Viewer::anonymous(), - }, - Requester::Anonymous => Viewer::anonymous(), + if let Some(user) = requester.user() { + return Viewer::user(user.clone()); } + + Viewer::anonymous() } /// Get the viewer's session async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession { let requester = ctx.requester(); - match requester { - Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()), - Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()), - Requester::Anonymous => ViewerSession::anonymous(), + if let Some(session) = requester.browser_session() { + return ViewerSession::browser_session(session.clone()); + } + + if let Some(session) = requester.oauth2_session() { + return ViewerSession::oauth2_session(session.clone()); } + + ViewerSession::anonymous() } } diff --git a/crates/handlers/src/graphql/state.rs b/crates/handlers/src/graphql/state.rs index 874f6f7aa..95752c4fd 100644 --- a/crates/handlers/src/graphql/state.rs +++ b/crates/handlers/src/graphql/state.rs @@ -10,7 +10,7 @@ use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; -use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint}; +use crate::{graphql::Requester, passwords::PasswordManager, Limiter}; #[async_trait::async_trait] pub trait State { @@ -31,8 +31,6 @@ pub trait ContextExt { fn state(&self) -> &BoxState; fn requester(&self) -> &Requester; - - fn requester_fingerprint(&self) -> RequesterFingerprint; } impl ContextExt for async_graphql::Context<'_> { @@ -43,8 +41,4 @@ impl ContextExt for async_graphql::Context<'_> { fn requester(&self) -> &Requester { self.data_unchecked() } - - fn requester_fingerprint(&self) -> RequesterFingerprint { - *self.data_unchecked() - } } diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index a9efb2ae2..8c3faf5b7 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -8,6 +8,7 @@ use axum::{ extract::{Path, State}, response::{Html, IntoResponse, Response}, }; +use axum_extra::TypedHeader; use hyper::StatusCode; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, sentry::SentryEventID, SessionInfoExt}; use mas_data_model::{AuthorizationGrant, BrowserSession, Client, Device}; @@ -89,6 +90,7 @@ pub(crate) async fn get( State(key_store): State, policy: Policy, activity_tracker: BoundActivityTracker, + user_agent: Option>, mut repo: BoxRepository, cookie_jar: CookieJar, Path(grant_id): Path, @@ -97,6 +99,8 @@ pub(crate) async fn get( let maybe_session = session_info.load_session(&mut repo).await?; + let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string()); + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) @@ -130,6 +134,7 @@ pub(crate) async fn get( &mut rng, &clock, &activity_tracker, + user_agent, repo, key_store, policy, @@ -199,6 +204,7 @@ pub(crate) async fn complete( rng: &mut (impl rand::RngCore + rand::CryptoRng + Send), clock: &impl Clock, activity_tracker: &BoundActivityTracker, + user_agent: Option, mut repo: BoxRepository, key_store: Keystore, mut policy: Policy, @@ -226,7 +232,16 @@ pub(crate) async fn complete( // Run through the policy let res = policy - .evaluate_authorization_grant(&grant, client, &browser_session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: Some(&browser_session.user), + client, + scope: &grant.scope, + grant_type: mas_policy::GrantType::AuthorizationCode, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) .await?; if !res.valid() { diff --git a/crates/handlers/src/oauth2/authorization/mod.rs b/crates/handlers/src/oauth2/authorization/mod.rs index f56f06133..27b121e95 100644 --- a/crates/handlers/src/oauth2/authorization/mod.rs +++ b/crates/handlers/src/oauth2/authorization/mod.rs @@ -8,6 +8,7 @@ use axum::{ extract::{Form, State}, response::{Html, IntoResponse, Response}, }; +use axum_extra::TypedHeader; use hyper::StatusCode; use mas_axum_utils::{cookies::CookieJar, csrf::CsrfExt, sentry::SentryEventID, SessionInfoExt}; use mas_data_model::{AuthorizationCode, Pkce}; @@ -136,6 +137,7 @@ pub(crate) async fn get( State(key_store): State, State(url_builder): State, policy: Policy, + user_agent: Option>, activity_tracker: BoundActivityTracker, mut repo: BoxRepository, cookie_jar: CookieJar, @@ -166,6 +168,8 @@ pub(crate) async fn get( let (session_info, cookie_jar) = cookie_jar.session_info(); let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); + let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string()); + // One day, we will have try blocks let res: Result = ({ let templates = templates.clone(); @@ -349,6 +353,7 @@ pub(crate) async fn get( &mut rng, &clock, &activity_tracker, + user_agent, repo, key_store, policy, @@ -401,6 +406,7 @@ pub(crate) async fn get( &mut rng, &clock, &activity_tracker, + user_agent, repo, key_store, policy, diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index c5a479e41..f7b4b72fc 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -8,6 +8,7 @@ use axum::{ extract::{Form, Path, State}, response::{Html, IntoResponse, Response}, }; +use axum_extra::TypedHeader; use hyper::StatusCode; use mas_axum_utils::{ cookies::CookieJar, @@ -80,6 +81,7 @@ pub(crate) async fn get( mut policy: Policy, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, Path(grant_id): Path, ) -> Result { @@ -87,6 +89,8 @@ pub(crate) async fn get( let maybe_session = session_info.load_session(&mut repo).await?; + let user_agent = user_agent.map(|ua| ua.to_string()); + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) @@ -111,7 +115,16 @@ pub(crate) async fn get( let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng); let res = policy - .evaluate_authorization_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: Some(&session.user), + client: &client, + scope: &grant.scope, + grant_type: mas_policy::GrantType::AuthorizationCode, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) .await?; if res.valid() { @@ -151,6 +164,7 @@ pub(crate) async fn post( mut policy: Policy, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, State(url_builder): State, Path(grant_id): Path, @@ -162,6 +176,8 @@ pub(crate) async fn post( let maybe_session = session_info.load_session(&mut repo).await?; + let user_agent = user_agent.map(|ua| ua.to_string()); + let grant = repo .oauth2_authorization_grant() .lookup(grant_id) @@ -185,7 +201,16 @@ pub(crate) async fn post( .ok_or(RouteError::NoSuchClient)?; let res = policy - .evaluate_authorization_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: Some(&session.user), + client: &client, + scope: &grant.scope, + grant_type: mas_policy::GrantType::AuthorizationCode, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) .await?; if !res.valid() { diff --git a/crates/handlers/src/oauth2/device/consent.rs b/crates/handlers/src/oauth2/device/consent.rs index 6025f3d1e..ac8bdd63b 100644 --- a/crates/handlers/src/oauth2/device/consent.rs +++ b/crates/handlers/src/oauth2/device/consent.rs @@ -10,6 +10,7 @@ use axum::{ response::{Html, IntoResponse, Response}, Form, }; +use axum_extra::TypedHeader; use mas_axum_utils::{ cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, @@ -46,6 +47,7 @@ pub(crate) async fn get( mut repo: BoxRepository, mut policy: Policy, activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, Path(grant_id): Path, ) -> Result { @@ -54,6 +56,8 @@ pub(crate) async fn get( let maybe_session = session_info.load_session(&mut repo).await?; + let user_agent = user_agent.map(|ua| ua.to_string()); + let Some(session) = maybe_session else { let login = mas_router::Login::and_continue_device_code_grant(grant_id); return Ok((cookie_jar, url_builder.redirect(&login)).into_response()); @@ -82,7 +86,16 @@ pub(crate) async fn get( // Evaluate the policy let res = policy - .evaluate_device_code_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + grant_type: mas_policy::GrantType::DeviceCode, + client: &client, + scope: &grant.scope, + user: Some(&session.user), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) .await?; if !res.valid() { warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id); @@ -119,6 +132,7 @@ pub(crate) async fn post( mut repo: BoxRepository, mut policy: Policy, activity_tracker: BoundActivityTracker, + user_agent: Option>, cookie_jar: CookieJar, Path(grant_id): Path, Form(form): Form>, @@ -129,6 +143,8 @@ pub(crate) async fn post( let maybe_session = session_info.load_session(&mut repo).await?; + let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string()); + let Some(session) = maybe_session else { let login = mas_router::Login::and_continue_device_code_grant(grant_id); return Ok((cookie_jar, url_builder.redirect(&login)).into_response()); @@ -157,7 +173,16 @@ pub(crate) async fn post( // Evaluate the policy let res = policy - .evaluate_device_code_grant(&grant, &client, &session.user) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + grant_type: mas_policy::GrantType::DeviceCode, + client: &client, + scope: &grant.scope, + user: Some(&session.user), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) .await?; if !res.valid() { warn!(violation = ?res, "Device code grant for client {} denied by policy", client.id); diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index bd608dd69..fca11cf97 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -5,6 +5,7 @@ // Please see LICENSE in the repository root for full details. use axum::{extract::State, response::IntoResponse, Json}; +use axum_extra::TypedHeader; use hyper::StatusCode; use mas_axum_utils::sentry::SentryEventID; use mas_iana::oauth::OAuthClientAuthenticationMethod; @@ -25,7 +26,7 @@ use thiserror::Error; use tracing::info; use url::Url; -use crate::impl_from_error_for_route; +use crate::{impl_from_error_for_route, BoundActivityTracker}; #[derive(Debug, Error)] pub(crate) enum RouteError { @@ -195,6 +196,8 @@ pub(crate) async fn post( clock: BoxClock, mut repo: BoxRepository, mut policy: Policy, + activity_tracker: BoundActivityTracker, + user_agent: Option>, State(encrypter): State, body: Result, axum::extract::rejection::JsonRejection>, ) -> Result { @@ -203,6 +206,8 @@ pub(crate) async fn post( info!(?body, "Client registration"); + let user_agent = user_agent.map(|ua| ua.to_string()); + // Validate the body let metadata = body.validate()?; @@ -244,7 +249,15 @@ pub(crate) async fn post( } } - let res = policy.evaluate_client_registration(&metadata).await?; + let res = policy + .evaluate_client_registration(mas_policy::ClientRegistrationInput { + client_metadata: &metadata, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent, + }, + }) + .await?; if !res.valid() { return Err(RouteError::PolicyDenied(res.violations)); } diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index ad6a15618..2c68cb782 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -676,7 +676,16 @@ async fn client_credentials_grant( // Make the request go through the policy engine let res = policy - .evaluate_client_credentials_grant(&scope, client) + .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { + user: None, + client, + scope: &scope, + grant_type: mas_policy::GrantType::ClientCredentials, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone().map(|ua| ua.raw), + }, + }) .await?; if !res.valid() { return Err(RouteError::DeniedByPolicy(res.violations)); diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index ee694768f..1ba8cfcc8 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -43,7 +43,8 @@ use super::{ UpstreamSessionsCookie, }; use crate::{ - impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage, SiteConfig, + impl_from_error_for_route, views::shared::OptionalPostAuthAction, BoundActivityTracker, + PreferredLanguage, SiteConfig, }; const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}"; @@ -199,6 +200,7 @@ pub(crate) async fn get( State(url_builder): State, State(homeserver): State, cookie_jar: CookieJar, + activity_tracker: BoundActivityTracker, user_agent: Option>, Path(link_id): Path, ) -> Result { @@ -430,7 +432,7 @@ pub(crate) async fn get( .with_code("User exists") .with_description(format!( r"Upstream account provider returned {localpart:?} as username, - which is not linked to that upstream account" + which is not linked to that upstream account" )) .with_language(&locale); @@ -441,16 +443,32 @@ pub(crate) async fn get( } let res = policy - .evaluate_upstream_oauth_register(&localpart, None) + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, + username: &localpart, + email: None, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone().map(|ua| ua.raw), + }, + }) .await?; - if !res.valid() { + if res.valid() { + // The username passes the policy check, add it to the context + ctx.with_localpart( + localpart, + provider.claims_imports.localpart.is_forced(), + ) + } else if provider.claims_imports.localpart.is_forced() { + // If the username claim is 'forced' but doesn't pass the policy check, + // we display an error message. // TODO: translate let ctx = ErrorContext::new() .with_code("Policy error") .with_description(format!( r"Upstream account provider returned {localpart:?} as username, - which does not pass the policy check: {res}" + which does not pass the policy check: {res}" )) .with_language(&locale); @@ -458,9 +476,10 @@ pub(crate) async fn get( cookie_jar, Html(templates.render_error(&ctx)?).into_response(), )); + } else { + // Else, we just ignore it when it doesn't pass the policy check. + ctx } - - ctx.with_localpart(localpart, provider.claims_imports.localpart.is_forced()) } None => ctx, } @@ -489,6 +508,7 @@ pub(crate) async fn post( user_agent: Option>, mut policy: Policy, PreferredLanguage(locale): PreferredLanguage, + activity_tracker: BoundActivityTracker, State(templates): State, State(homeserver): State, State(url_builder): State, @@ -743,8 +763,17 @@ pub(crate) async fn post( // Policy check let res = policy - .evaluate_upstream_oauth_register(&username, email.as_deref()) + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, + username: &username, + email: email.as_deref(), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone().map(|ua| ua.raw), + }, + }) .await?; + if !res.valid() { let form_state = res.violations diff --git a/crates/handlers/src/views/register/password.rs b/crates/handlers/src/views/register/password.rs index c2177c484..184b98c42 100644 --- a/crates/handlers/src/views/register/password.rs +++ b/crates/handlers/src/views/register/password.rs @@ -233,7 +233,15 @@ pub(crate) async fn post( } let res = policy - .evaluate_register(&form.username, &form.email) + .evaluate_register(mas_policy::RegisterInput { + registration_method: mas_policy::RegistrationMethod::Password, + username: &form.username, + email: Some(&form.email), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + user_agent: user_agent.clone().map(|ua| ua.raw), + }, + }) .await?; for violation in res.violations { diff --git a/crates/policy/src/bin/schema.rs b/crates/policy/src/bin/schema.rs index 2b49b7955..5a33a907d 100644 --- a/crates/policy/src/bin/schema.rs +++ b/crates/policy/src/bin/schema.rs @@ -7,7 +7,7 @@ use std::path::{Path, PathBuf}; use mas_policy::model::{ - AuthorizationGrantInput, ClientRegistrationInput, EmailInput, PasswordInput, RegisterInput, + AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput, }; use schemars::{gen::SchemaSettings, JsonSchema}; @@ -45,5 +45,4 @@ fn main() { write_schema::(output_root, "client_registration_input.json"); write_schema::(output_root, "authorization_grant_input.json"); write_schema::(output_root, "email_input.json"); - write_schema::(output_root, "password_input.json"); } diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 9ffe2f511..4bc319391 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -6,8 +6,6 @@ pub mod model; -use mas_data_model::{AuthorizationGrant, Client, DeviceCodeGrant, User}; -use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope}; use opa_wasm::{ wasmtime::{Config, Engine, Module, OptLevel, Store}, Runtime, @@ -16,9 +14,10 @@ use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; -use self::model::{AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput}; -pub use self::model::{Code as ViolationCode, EvaluationResult, Violation}; -use crate::model::GrantType; +pub use self::model::{ + AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, EmailInput, + EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester, Violation, +}; #[derive(Debug, Error)] pub enum LoadError { @@ -190,16 +189,14 @@ impl Policy { name = "policy.evaluate_email", skip_all, fields( - input.email = email, + %input.email, ), err, )] pub async fn evaluate_email( &mut self, - email: &str, + input: EmailInput<'_>, ) -> Result { - let input = EmailInput { email }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate(&mut self.store, &self.entrypoints.email, &input) @@ -212,44 +209,16 @@ impl Policy { name = "policy.evaluate.register", skip_all, fields( - input.registration_method = "password", - input.user.username = username, - input.user.email = email, + ?input.registration_method, + input.username = input.username, + input.email = input.email, ), err, )] pub async fn evaluate_register( &mut self, - username: &str, - email: &str, - ) -> Result { - let input = RegisterInput::Password { username, email }; - - let [res]: [EvaluationResult; 1] = self - .instance - .evaluate(&mut self.store, &self.entrypoints.register, &input) - .await?; - - Ok(res) - } - - #[tracing::instrument( - name = "policy.evaluate.upstream_oauth_register", - skip_all, - fields( - input.registration_method = "password", - input.user.username = username, - input.user.email = email, - ), - err, - )] - pub async fn evaluate_upstream_oauth_register( - &mut self, - username: &str, - email: Option<&str>, + input: RegisterInput<'_>, ) -> Result { - let input = RegisterInput::UpstreamOAuth2 { username, email }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate(&mut self.store, &self.entrypoints.register, &input) @@ -261,10 +230,8 @@ impl Policy { #[tracing::instrument(skip(self))] pub async fn evaluate_client_registration( &mut self, - client_metadata: &VerifiedClientMetadata, + input: ClientRegistrationInput<'_>, ) -> Result { - let input = ClientRegistrationInput { client_metadata }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate( @@ -281,95 +248,15 @@ impl Policy { name = "policy.evaluate.authorization_grant", skip_all, fields( - input.authorization_grant.id = %authorization_grant.id, - input.scope = %authorization_grant.scope, - input.client.id = %client.id, - input.user.id = %user.id, + %input.scope, + %input.client.id, ), err, )] pub async fn evaluate_authorization_grant( &mut self, - authorization_grant: &AuthorizationGrant, - client: &Client, - user: &User, - ) -> Result { - let input = AuthorizationGrantInput { - user: Some(user), - client, - scope: &authorization_grant.scope, - grant_type: GrantType::AuthorizationCode, - }; - - let [res]: [EvaluationResult; 1] = self - .instance - .evaluate( - &mut self.store, - &self.entrypoints.authorization_grant, - &input, - ) - .await?; - - Ok(res) - } - - #[tracing::instrument( - name = "policy.evaluate.client_credentials_grant", - skip_all, - fields( - input.scope = %scope, - input.client.id = %client.id, - ), - err, - )] - pub async fn evaluate_client_credentials_grant( - &mut self, - scope: &Scope, - client: &Client, - ) -> Result { - let input = AuthorizationGrantInput { - user: None, - client, - scope, - grant_type: GrantType::ClientCredentials, - }; - - let [res]: [EvaluationResult; 1] = self - .instance - .evaluate( - &mut self.store, - &self.entrypoints.authorization_grant, - &input, - ) - .await?; - - Ok(res) - } - - #[tracing::instrument( - name = "policy.evaluate.device_code_grant", - skip_all, - fields( - input.device_code_grant.id = %device_code_grant.id, - input.scope = %device_code_grant.scope, - input.client.id = %client.id, - input.user.id = %user.id, - ), - err, - )] - pub async fn evaluate_device_code_grant( - &mut self, - device_code_grant: &DeviceCodeGrant, - client: &Client, - user: &User, + input: AuthorizationGrantInput<'_>, ) -> Result { - let input = AuthorizationGrantInput { - user: Some(user), - client, - scope: &device_code_grant.scope, - grant_type: GrantType::DeviceCode, - }; - let [res]: [EvaluationResult; 1] = self .instance .evaluate( @@ -385,6 +272,7 @@ impl Policy { #[cfg(test)] mod tests { + use super::*; #[tokio::test] @@ -415,19 +303,43 @@ mod tests { let mut policy = factory.instantiate().await.unwrap(); let res = policy - .evaluate_register("hello", "hello@example.com") + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@example.com"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) .await .unwrap(); assert!(!res.valid()); let res = policy - .evaluate_register("hello", "hello@foo.element.io") + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@foo.element.io"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) .await .unwrap(); assert!(res.valid()); let res = policy - .evaluate_register("hello", "hello@staging.element.io") + .evaluate_register(RegisterInput { + registration_method: RegistrationMethod::Password, + username: "hello", + email: Some("hello@staging.element.io"), + requester: Requester { + ip_address: None, + user_agent: None, + }, + }) .await .unwrap(); assert!(!res.valid()); diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index aebca8928..aef4f3c08 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -9,6 +9,8 @@ //! This is useful to generate JSON schemas for each input type, which can then //! be type-checked by Open Policy Agent. +use std::net::IpAddr; + use mas_data_model::{Client, User}; use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope}; use serde::{Deserialize, Serialize}; @@ -35,6 +37,12 @@ pub enum Code { /// The email domain is banned. EmailDomainBanned, + + /// The email address is not allowed. + EmailNotAllowed, + + /// The email address is banned. + EmailBanned, } impl Code { @@ -48,6 +56,8 @@ impl Code { Self::UsernameAllNumeric => "username-all-numeric", Self::EmailDomainNotAllowed => "email-domain-not-allowed", Self::EmailDomainBanned => "email-domain-banned", + Self::EmailNotAllowed => "email-not-allowed", + Self::EmailBanned => "email-banned", } } } @@ -92,21 +102,41 @@ impl EvaluationResult { } } -/// Input for the user registration policy. +/// Identity of the requester +#[derive(Serialize, Debug, Default)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct Requester { + /// IP address of the entity making the request + pub ip_address: Option, + + /// User agent of the entity making the request + pub user_agent: Option, +} + #[derive(Serialize, Debug)] -#[serde(tag = "registration_method")] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] -pub enum RegisterInput<'a> { +pub enum RegistrationMethod { #[serde(rename = "password")] - Password { username: &'a str, email: &'a str }, + Password, #[serde(rename = "upstream-oauth2")] - UpstreamOAuth2 { - username: &'a str, + UpstreamOAuth2, +} - #[serde(skip_serializing_if = "Option::is_none")] - email: Option<&'a str>, - }, +/// Input for the user registration policy. +#[derive(Serialize, Debug)] +#[serde(tag = "registration_method")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct RegisterInput<'a> { + pub registration_method: RegistrationMethod, + + pub username: &'a str, + + #[serde(skip_serializing_if = "Option::is_none")] + pub email: Option<&'a str>, + + pub requester: Requester, } /// Input for the client registration policy. @@ -119,6 +149,7 @@ pub struct ClientRegistrationInput<'a> { schemars(with = "std::collections::HashMap") )] pub client_metadata: &'a VerifiedClientMetadata, + pub requester: Requester, } #[derive(Serialize, Debug)] @@ -152,6 +183,8 @@ pub struct AuthorizationGrantInput<'a> { pub scope: &'a Scope, pub grant_type: GrantType, + + pub requester: Requester, } /// Input for the email add policy. @@ -160,12 +193,6 @@ pub struct AuthorizationGrantInput<'a> { #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] pub struct EmailInput<'a> { pub email: &'a str, -} -/// Input for the password set policy. -#[derive(Serialize, Debug)] -#[serde(rename_all = "snake_case")] -#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] -pub struct PasswordInput<'a> { - pub password: &'a str, + pub requester: Requester, } diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 2671fc807..e69d14804 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -382,13 +382,46 @@ policy: # don't require clients to provide a client_uri. default: false allow_missing_client_uri: false - # Restrict emails on registration to a specific domain - # Items in this array are evaluated as a glob - allowed_domains: - - *.example.com - # Ban specific domains from registration - banned_domains: - - *.banned.example.com + # Restrict what email addresses can be added to a user + emails: + # If specified, the email address *must* match one of the allowed addresses. + # If unspecified, all email addresses are allowed. + allowed_addresses: + # Exact emails that are allowed + literals: ["alice@example.com", "bob@example.com"] + # Regular expressions that match allowed emails + regexes: ["@example\\.com$"] + # Suffixes that match allowed emails + suffixes: ["@example.com"] + + # If specified, the email address *must not* match one of the banned addresses. + # If unspecified, all email addresses are allowed. + banned_addresses: + # Exact emails that are banned + literals: ["alice@evil.corp", "bob@evil.corp"] + # Emails that contains those substrings are banned + substrings: ["evil"] + # Regular expressions that match banned emails + regexes: ["@evil\\.corp$"] + # Suffixes that match banned emails + suffixes: ["@evil.corp"] + # Prefixes that match banned emails + prefixes: ["alice@"] + + requester: + # List of IP addresses and CIDRs that are not allowed to register + banned_ips: + - 192.168.0.1 + - 192.168.1.0/24 + - fe80::/64 + + # User agent patterns that are not allowed to register + banned_user_agents: + literals: ["Pretend this is Real;"] + substrings: ["Chrome"] + regexes: ["Chrome 1.*;"] + prefixes: ["Mozilla/"] + suffixes: ["Safari/605.1.15"] ``` ## `rate_limiting` diff --git a/policies/Makefile b/policies/Makefile index 60f037e48..18cb2dbfc 100644 --- a/policies/Makefile +++ b/policies/Makefile @@ -1,10 +1,11 @@ # Set to 1 to run OPA through Docker DOCKER := 0 PODMAN := 0 -OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:0.70.0-debug -REGAL_DOCKER_IMAGE := ghcr.io/styrainc/regal:0.29.2 +OPA_DOCKER_IMAGE := docker.io/openpolicyagent/opa:1.1.0-debug +REGAL_DOCKER_IMAGE := ghcr.io/styrainc/regal:0.31.0 INPUTS := \ + common/common.rego \ client_registration/client_registration.rego \ register/register.rego \ authorization_grant/authorization_grant.rego \ diff --git a/policies/authorization_grant/authorization_grant.rego b/policies/authorization_grant/authorization_grant.rego index 362ba080e..72fa7ee8b 100644 --- a/policies/authorization_grant/authorization_grant.rego +++ b/policies/authorization_grant/authorization_grant.rego @@ -5,6 +5,8 @@ package authorization_grant import rego.v1 +import data.common + default allow := false allow if { @@ -82,3 +84,10 @@ violation contains {"msg": "only one device scope is allowed at a time"} if { scope_list := split(input.scope, " ") count({scope | some scope in scope_list; startswith(scope, "urn:matrix:org.matrix.msc2967.client:device:")}) > 1 } + +violation contains {"msg": sprintf( + "Requester [%s] isn't allowed to do this action", + [common.format_requester(input.requester)], +)} if { + common.requester_banned(input.requester, data.requester) +} diff --git a/policies/client_registration/client_registration.rego b/policies/client_registration/client_registration.rego index 4f3e28516..ad1fa9e0b 100644 --- a/policies/client_registration/client_registration.rego +++ b/policies/client_registration/client_registration.rego @@ -18,8 +18,7 @@ parse_uri(url) := obj if { obj := {"scheme": matches[1], "authority": matches[2], "host": matches[3], "port": matches[4], "path": matches[5]} } -secure_url(x) if { - x +secure_url(_) if { data.client_registration.allow_insecure_uris } @@ -37,16 +36,12 @@ secure_url(x) if { url.port == "" } -host_matches_client_uri(x) if { - x - +host_matches_client_uri(_) if { # Do not check we allow host mismatch data.client_registration.allow_host_mismatch } -host_matches_client_uri(x) if { - x - +host_matches_client_uri(_) if { # Do not check if the client_uri is missing and we allow that data.client_registration.allow_missing_client_uri not data.client_metadata.client_uri diff --git a/policies/common/common.rego b/policies/common/common.rego new file mode 100644 index 000000000..8386555c6 --- /dev/null +++ b/policies/common/common.rego @@ -0,0 +1,84 @@ +package common + +import rego.v1 + +matches_string_constraints(str, constraints) if matches_regexes(str, constraints.regexes) + +matches_string_constraints(str, constraints) if matches_substrings(str, constraints.substrings) + +matches_string_constraints(str, constraints) if matches_literals(str, constraints.literals) + +matches_string_constraints(str, constraints) if matches_suffixes(str, constraints.suffixes) + +matches_string_constraints(str, constraints) if matches_prefixes(str, constraints.prefixes) + +matches_regexes(str, regexes) if { + some pattern in regexes + regex.match(pattern, str) +} + +matches_substrings(str, substrings) if { + some pattern in substrings + contains(str, pattern) +} + +matches_literals(str, literals) if { + some literal in literals + str == literal +} + +matches_suffixes(str, suffixes) if { + some suffix in suffixes + endswith(str, suffix) +} + +matches_prefixes(str, prefixes) if { + some prefix in prefixes + startswith(str, prefix) +} + +# Normalize an IP address or CIDR to a CIDR +normalize_cidr(ip) := ip if contains(ip, "/") + +# If it's an IPv4, append /32 +normalize_cidr(ip) := sprintf("%s/32", [ip]) if { + not contains(ip, "/") + not contains(ip, ":") +} + +# If it's an IPv6, append /128 +normalize_cidr(ip) := sprintf("%s/128", [ip]) if { + not contains(ip, "/") + contains(ip, ":") +} + +ip_in_list(ip, list) if { + some cidr in list + net.cidr_contains(normalize_cidr(cidr), ip) +} + +mxid(username, server_name) := sprintf("@%s:%s", [username, server_name]) + +requester_banned(requester, policy) if ip_in_list(requester.ip_address, policy.banned_ips) + +requester_banned(requester, policy) if matches_string_constraints(requester.user_agent, policy.banned_user_agents) + +format_requester(requester) := "unknown" if { + not requester.ip_address + not requester.user_agent +} + +format_requester(requester) := sprintf("%s / %s", [requester.ip_address, requester.user_agent]) if { + requester.ip_address + requester.user_agent +} + +format_requester(requester) := sprintf("%s", [requester.ip_address]) if { + requester.ip_address + not requester.user_agent +} + +format_requester(requester) := sprintf("%s", [requester.user_agent]) if { + not requester.ip_address + requester.user_agent +} diff --git a/policies/common/common_test.rego b/policies/common/common_test.rego new file mode 100644 index 000000000..52950a591 --- /dev/null +++ b/policies/common/common_test.rego @@ -0,0 +1,44 @@ +package common_test + +import data.common +import rego.v1 + +test_match_literals if { + common.matches_string_constraints("literal", {"literals": ["literal"]}) + not common.matches_string_constraints("literal", {"literals": ["lit"]}) +} + +test_match_substring if { + common.matches_string_constraints("some string", {"substrings": ["str"]}) + not common.matches_string_constraints("some string", {"substrings": ["something"]}) +} + +test_match_regex if { + common.matches_string_constraints("some string", {"regexes": ["^some"]}) + not common.matches_string_constraints("some string", {"regexes": ["^string"]}) +} + +test_match_prefix if { + common.matches_string_constraints("some string", {"prefixes": ["some"]}) + not common.matches_string_constraints("some string", {"prefixes": ["string"]}) +} + +test_match_suffix if { + common.matches_string_constraints("some string", {"suffixes": ["string"]}) + not common.matches_string_constraints("some string", {"suffixes": ["some"]}) +} + +test_ip_in_list if { + common.ip_in_list("192.168.1.1", ["192.168.1.1"]) + common.ip_in_list("192.168.1.1", ["192.168.1.0/24"]) + common.ip_in_list("::1", ["::1"]) + common.ip_in_list("::1", ["::/64"]) + not common.ip_in_list("192.168.1.1", ["192.168.1.2/32"]) +} + +test_requester_banned if { + common.requester_banned( + {"ip_address": "192.168.1.1", "user_agent": "Mozilla/5.0"}, + {"banned_ips": ["192.168.1.1"]}, + ) +} diff --git a/policies/email/email.rego b/policies/email/email.rego index 24b1d94b4..d9c5eb778 100644 --- a/policies/email/email.rego +++ b/policies/email/email.rego @@ -5,6 +5,8 @@ package email import rego.v1 +import data.common + default allow := false allow if { @@ -23,6 +25,16 @@ domain_allowed if { glob.match(allowed_domain, ["."], domain) } +# Allow any emails if the data.emails.allowed_addresses is not set +address_allowed if { + not data.emails.allowed_addresses +} + +# Allow an email only if its address is in the list of allowed addresses +address_allowed if { + common.matches_string_constraints(input.email, data.emails.allowed_addresses) +} + # METADATA # entrypoint: true violation contains {"code": "email-domain-not-allowed", "msg": "email domain is not allowed"} if { @@ -35,3 +47,13 @@ violation contains {"code": "email-domain-banned", "msg": "email domain is banne some banned_domain in data.banned_domains glob.match(banned_domain, ["."], domain) } + +# Deny emails if it's not allowed +violation contains {"code": "email-not-allowed", "msg": "email is not allowed"} if { + not address_allowed +} + +# Deny emails which match the email ban list constraint +violation contains {"code": "email-banned", "msg": "email is not allowed"} if { + common.matches_string_constraints(input.email, data.emails.banned_addresses) +} diff --git a/policies/email/email_test.rego b/policies/email/email_test.rego index 0adcdfad2..9d3750b56 100644 --- a/policies/email/email_test.rego +++ b/policies/email/email_test.rego @@ -27,3 +27,27 @@ test_banned_subdomain if { with data.allowed_domains as ["*.element.io"] with data.banned_domains as ["staging.element.io"] } + +test_regex_banned if { + not email.allow with input.email as "hello@staging.element.io" + with data.emails.banned_addresses.regexes as ["hello@.*"] +} + +test_literal_banned if { + not email.allow with input.email as "hello@staging.element.io" + with data.emails.banned_addresses.literals as ["hello@staging.element.io"] +} + +test_regex_allowed if { + email.allow with input.email as "hello@staging.element.io" + with data.emails.allowed_addresses.regexes as ["hello@.*"] + not email.allow with input.email as "hello@staging.element.io" + with data.emails.allowed_addresses.regexes as ["hola@.*"] +} + +test_literal_allowed if { + email.allow with input.email as "hello@staging.element.io" + with data.emails.allowed_addresses.literals as ["hello@staging.element.io"] + not email.allow with input.email as "hello@staging.element.io" + with data.emails.allowed_addresses.literals as ["hola@staging.element.io"] +} diff --git a/policies/register/register.rego b/policies/register/register.rego index 0fb36bf37..6189c3926 100644 --- a/policies/register/register.rego +++ b/policies/register/register.rego @@ -5,6 +5,7 @@ package register import rego.v1 +import data.common import data.email as email_policy default allow := false @@ -13,8 +14,6 @@ allow if { count(violation) == 0 } -mxid(username, server_name) := sprintf("@%s:%s", [username, server_name]) - # METADATA # entrypoint: true violation contains {"field": "username", "code": "username-too-short", "msg": "username too short"} if { @@ -22,7 +21,7 @@ violation contains {"field": "username", "code": "username-too-short", "msg": "u } violation contains {"field": "username", "code": "username-too-long", "msg": "username too long"} if { - user_id := mxid(input.username, data.server_name) + user_id := common.mxid(input.username, data.server_name) count(user_id) > 255 } @@ -48,6 +47,13 @@ violation contains {"msg": "unknown registration method"} if { not input.registration_method in ["password", "upstream-oauth2"] } +violation contains {"msg": sprintf( + "Requester [%s] isn't allowed to do this action", + [common.format_requester(input.requester)], +)} if { + common.requester_banned(input.requester, data.requester) +} + # Check that we supplied an email for password registration violation contains {"field": "email", "msg": "email required for password-based registration"} if { input.registration_method == "password" diff --git a/policies/register/register_test.rego b/policies/register/register_test.rego index 26e119248..51105ea39 100644 --- a/policies/register/register_test.rego +++ b/policies/register/register_test.rego @@ -74,3 +74,26 @@ test_invalid_username if { test_numeric_username if { not register.allow with input as {"username": "1234", "registration_method": "upstream-oauth2"} } + +test_ip_ban if { + not register.allow with input as { + "username": "hello", + "registration_method": "upstream-oauth2", + "requester": {"ip_address": "1.1.1.1"}, + } + with data.requester.banned_ips as ["1.1.1.1"] + + not register.allow with input as { + "username": "hello", + "registration_method": "upstream-oauth2", + "requester": {"ip_address": "1.1.1.1"}, + } + with data.requester.banned_ips as ["1.0.0.0/8"] + + not register.allow with input as { + "username": "hello", + "registration_method": "upstream-oauth2", + "requester": {"user_agent": "Evil Client"}, + } + with data.requester.banned_user_agents.substrings as ["Evil"] +} diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json index 9b2f77403..f23bf7a73 100644 --- a/policies/schema/authorization_grant_input.json +++ b/policies/schema/authorization_grant_input.json @@ -6,6 +6,7 @@ "required": [ "client", "grant_type", + "requester", "scope" ], "properties": { @@ -22,6 +23,9 @@ }, "grant_type": { "$ref": "#/definitions/GrantType" + }, + "requester": { + "$ref": "#/definitions/Requester" } }, "definitions": { @@ -32,6 +36,21 @@ "client_credentials", "urn:ietf:params:oauth:grant-type:device_code" ] + }, + "Requester": { + "description": "Identity of the requester", + "type": "object", + "properties": { + "ip_address": { + "description": "IP address of the entity making the request", + "type": "string", + "format": "ip" + }, + "user_agent": { + "description": "User agent of the entity making the request", + "type": "string" + } + } } } } \ No newline at end of file diff --git a/policies/schema/client_registration_input.json b/policies/schema/client_registration_input.json index cc9957a85..461645126 100644 --- a/policies/schema/client_registration_input.json +++ b/policies/schema/client_registration_input.json @@ -4,12 +4,33 @@ "description": "Input for the client registration policy.", "type": "object", "required": [ - "client_metadata" + "client_metadata", + "requester" ], "properties": { "client_metadata": { "type": "object", "additionalProperties": true + }, + "requester": { + "$ref": "#/definitions/Requester" + } + }, + "definitions": { + "Requester": { + "description": "Identity of the requester", + "type": "object", + "properties": { + "ip_address": { + "description": "IP address of the entity making the request", + "type": "string", + "format": "ip" + }, + "user_agent": { + "description": "User agent of the entity making the request", + "type": "string" + } + } } } } \ No newline at end of file diff --git a/policies/schema/email_input.json b/policies/schema/email_input.json index 19f4af523..d97f291be 100644 --- a/policies/schema/email_input.json +++ b/policies/schema/email_input.json @@ -4,11 +4,32 @@ "description": "Input for the email add policy.", "type": "object", "required": [ - "email" + "email", + "requester" ], "properties": { "email": { "type": "string" + }, + "requester": { + "$ref": "#/definitions/Requester" + } + }, + "definitions": { + "Requester": { + "description": "Identity of the requester", + "type": "object", + "properties": { + "ip_address": { + "description": "IP address of the entity making the request", + "type": "string", + "format": "ip" + }, + "user_agent": { + "description": "User agent of the entity making the request", + "type": "string" + } + } } } } \ No newline at end of file diff --git a/policies/schema/password_input.json b/policies/schema/password_input.json deleted file mode 100644 index c3cbf92d8..000000000 --- a/policies/schema/password_input.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "PasswordInput", - "description": "Input for the password set policy.", - "type": "object", - "required": [ - "password" - ], - "properties": { - "password": { - "type": "string" - } - } -} \ No newline at end of file diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json index e1d796ea1..cd8868cd4 100644 --- a/policies/schema/register_input.json +++ b/policies/schema/register_input.json @@ -2,49 +2,48 @@ "$schema": "http://json-schema.org/draft-07/schema#", "title": "RegisterInput", "description": "Input for the user registration policy.", - "oneOf": [ - { - "type": "object", - "required": [ - "email", - "registration_method", - "username" - ], - "properties": { - "registration_method": { - "type": "string", - "enum": [ - "password" - ] - }, - "username": { - "type": "string" - }, - "email": { - "type": "string" - } - } + "type": "object", + "required": [ + "registration_method", + "requester", + "username" + ], + "properties": { + "registration_method": { + "$ref": "#/definitions/RegistrationMethod" + }, + "username": { + "type": "string" }, - { + "email": { + "type": "string" + }, + "requester": { + "$ref": "#/definitions/Requester" + } + }, + "definitions": { + "RegistrationMethod": { + "type": "string", + "enum": [ + "password", + "upstream-oauth2" + ] + }, + "Requester": { + "description": "Identity of the requester", "type": "object", - "required": [ - "registration_method", - "username" - ], "properties": { - "registration_method": { + "ip_address": { + "description": "IP address of the entity making the request", "type": "string", - "enum": [ - "upstream-oauth2" - ] - }, - "username": { - "type": "string" + "format": "ip" }, - "email": { + "user_agent": { + "description": "User agent of the entity making the request", "type": "string" } } } - ] + } } \ No newline at end of file diff --git a/templates/components/field.html b/templates/components/field.html index c698c5152..d26fc3b66 100644 --- a/templates/components/field.html +++ b/templates/components/field.html @@ -73,6 +73,10 @@ {{ _("mas.errors.email_domain_not_allowed") }} {% elif error.code == "email-domain-banned" %} {{ _("mas.errors.email_domain_banned") }} + {% elif error.code == "email-not-allowed" %} + {{ _("mas.errors.email_not_allowed") }} + {% elif error.code == "email-banned" %} + {{ _("mas.errors.email_banned") }} {% else %} {{ _("mas.errors.denied_policy", policy=error.message) }} {% endif %} diff --git a/translations/en.json b/translations/en.json index 0da0ccf85..c27b1977c 100644 --- a/translations/en.json +++ b/translations/en.json @@ -299,7 +299,11 @@ }, "denied_policy": "Denied by policy: %(policy)s", "@denied_policy": { - "context": "components/errors.html:17:7-58, components/field.html:77:19-70" + "context": "components/errors.html:17:7-58, components/field.html:81:19-70" + }, + "email_banned": "Email is banned by the server policy", + "@email_banned": { + "context": "components/field.html:79:19-47" }, "email_domain_banned": "Email domain is banned by the server policy", "@email_domain_banned": { @@ -309,6 +313,10 @@ "@email_domain_not_allowed": { "context": "components/field.html:73:19-59" }, + "email_not_allowed": "Email is not allowed by the server policy", + "@email_not_allowed": { + "context": "components/field.html:77:19-52" + }, "field_required": "This field is required", "@field_required": { "context": "components/field.html:60:17-47" @@ -319,7 +327,7 @@ }, "password_mismatch": "Password fields don't match", "@password_mismatch": { - "context": "components/errors.html:13:7-40, components/field.html:80:17-50" + "context": "components/errors.html:13:7-40, components/field.html:84:17-50" }, "rate_limit_exceeded": "You've made too many requests in a short period. Please wait a few minutes and try again.", "@rate_limit_exceeded": { @@ -416,7 +424,7 @@ }, "or_separator": "Or", "@or_separator": { - "context": "components/field.html:99:10-31", + "context": "components/field.html:103:10-31", "description": "Separator between the login methods" }, "policy_violation": {