diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 8ffb50306..8768ae7b1 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -8,7 +8,7 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; -use mas_config::{ConfigurationSectionExt, PolicyConfig}; +use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig}; use tracing::{info, info_span}; use crate::util::policy_factory_from_config; @@ -33,8 +33,9 @@ impl Options { SC::Policy => { let _span = info_span!("cli.debug.policy").entered(); let config = PolicyConfig::extract_or_default(figment)?; + let matrix_config = MatrixConfig::extract(figment)?; info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config).await?; + let policy_factory = policy_factory_from_config(&config, &matrix_config).await?; let _instance = policy_factory.instantiate().await?; } diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 9234e45bb..2b867ca86 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -123,7 +123,7 @@ impl Options { // Load and compile the WASM policies (and fallback to the default embedded one) info!("Loading and compiling the policy module"); - let policy_factory = policy_factory_from_config(&config.policy).await?; + let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?; let policy_factory = Arc::new(policy_factory); let url_builder = UrlBuilder::new( diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index a4ab8eba6..3d3d8f676 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -101,6 +101,7 @@ pub fn mailer_from_config( pub async fn policy_factory_from_config( config: &PolicyConfig, + matrix_config: &MatrixConfig, ) -> Result { let policy_file = tokio::fs::File::open(&config.wasm_module) .await @@ -113,7 +114,10 @@ pub async fn policy_factory_from_config( email: config.email_entrypoint.clone(), }; - PolicyFactory::load(policy_file, config.data.clone(), entrypoints) + let data = + mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone()); + + PolicyFactory::load(policy_file, data, entrypoints) .await .context("failed to load the policy") } diff --git a/crates/handlers/src/graphql/tests.rs b/crates/handlers/src/graphql/tests.rs index 1d72eb66b..a830f2448 100644 --- a/crates/handlers/src/graphql/tests.rs +++ b/crates/handlers/src/graphql/tests.rs @@ -469,9 +469,12 @@ async fn test_oauth2_client_credentials(pool: PgPool) { // Now make the client admin and try again let state = { let mut state = state; - state.policy_factory = test_utils::policy_factory(serde_json::json!({ - "admin_clients": [client_id], - })) + state.policy_factory = test_utils::policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id], + }), + ) .await .unwrap(); state @@ -593,9 +596,12 @@ async fn test_add_user(pool: PgPool) { // Make the client admin let state = { let mut state = state; - state.policy_factory = test_utils::policy_factory(serde_json::json!({ - "admin_clients": [client_id], - })) + state.policy_factory = test_utils::policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id], + }), + ) .await .unwrap(); state diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index b5c3b0341..ad6a15618 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -1475,9 +1475,12 @@ mod tests { // Now, if we add the client to the admin list in the policy, it should work let state = { let mut state = state; - state.policy_factory = crate::test_utils::policy_factory(serde_json::json!({ - "admin_clients": [client_id] - })) + state.policy_factory = crate::test_utils::policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id] + }), + ) .await .unwrap(); state diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index f6a16037a..5f7240dab 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -69,6 +69,7 @@ pub(crate) fn setup() { } pub(crate) async fn policy_factory( + server_name: &str, data: serde_json::Value, ) -> Result, anyhow::Error> { let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR")) @@ -84,6 +85,8 @@ pub(crate) async fn policy_factory( email: "email/violation".to_owned(), }; + let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data); + let policy_factory = PolicyFactory::load(file, data, entrypoints).await?; let policy_factory = Arc::new(policy_factory); Ok(policy_factory) @@ -192,7 +195,8 @@ impl TestState { PasswordManager::disabled() }; - let policy_factory = policy_factory(serde_json::json!({})).await?; + let policy_factory = + policy_factory(&site_config.server_name, serde_json::json!({})).await?; let homeserver_connection = Arc::new(MockHomeserverConnection::new(&site_config.server_name)); @@ -297,9 +301,12 @@ impl TestState { // Make the client admin let state = { let mut state = self.clone(); - state.policy_factory = policy_factory(serde_json::json!({ - "admin_clients": [client_id], - })) + state.policy_factory = policy_factory( + "example.com", + serde_json::json!({ + "admin_clients": [client_id], + }), + ) .await .unwrap(); state diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index e48ff190b..b37c4fce7 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -759,10 +759,12 @@ pub(crate) async fn post( Some("username") => form_state.with_error_on_field( mas_templates::UpstreamRegisterFormField::Username, FieldError::Policy { + code: violation.code.map(|c| c.as_str()), message: violation.msg, }, ), _ => form_state.with_error_on_form(FormError::Policy { + code: violation.code.map(|c| c.as_str()), message: violation.msg, }), } diff --git a/crates/handlers/src/views/register.rs b/crates/handlers/src/views/register.rs index 0eaa99801..5c81eb164 100644 --- a/crates/handlers/src/views/register.rs +++ b/crates/handlers/src/views/register.rs @@ -154,6 +154,7 @@ pub(crate) async fn post( state.add_error_on_form(FormError::Captcha); } + let mut homeserver_denied_username = false; if form.username.is_empty() { state.add_error_on_field(RegisterFormField::Username, FieldError::Required); } else if repo.user().exists(&form.username).await? { @@ -161,13 +162,14 @@ pub(crate) async fn post( state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); } else if !homeserver.is_localpart_available(&form.username).await? { // The user already exists on the homeserver - // XXX: we may want to return different errors like "this username is reserved" tracing::warn!( username = &form.username, - "User tried to register with a reserved username" + "Homeserver denied username provided by user" ); - state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); + // We defer adding the error on the field, until we know whether we had another + // error from the policy, to avoid showing both + homeserver_denied_username = true; } if form.email.is_empty() { @@ -197,6 +199,7 @@ pub(crate) async fn post( state.add_error_on_field( RegisterFormField::Password, FieldError::Policy { + code: None, message: "Password is too weak".to_owned(), }, ); @@ -216,27 +219,41 @@ pub(crate) async fn post( Some("email") => state.add_error_on_field( RegisterFormField::Email, FieldError::Policy { + code: violation.code.map(|c| c.as_str()), message: violation.msg, }, ), - Some("username") => state.add_error_on_field( - RegisterFormField::Username, - FieldError::Policy { - message: violation.msg, - }, - ), + Some("username") => { + // If the homeserver denied the username, but we also had an error on the policy + // side, we don't want to show both, so we reset the state here + homeserver_denied_username = false; + state.add_error_on_field( + RegisterFormField::Username, + FieldError::Policy { + code: violation.code.map(|c| c.as_str()), + message: violation.msg, + }, + ); + } Some("password") => state.add_error_on_field( RegisterFormField::Password, FieldError::Policy { + code: violation.code.map(|c| c.as_str()), message: violation.msg, }, ), _ => state.add_error_on_form(FormError::Policy { + code: violation.code.map(|c| c.as_str()), message: violation.msg, }), } } + if homeserver_denied_username { + // XXX: we may want to return different errors like "this username is reserved" + state.add_error_on_field(RegisterFormField::Username, FieldError::Exists); + } + if state.is_valid() { // Check the rate limit if we are about to process the form if let Err(e) = limiter.check_registration(requester) { @@ -481,7 +498,7 @@ mod tests { } #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] - async fn test_register_username_too_short(pool: PgPool) { + async fn test_register_username_too_long(pool: PgPool) { setup(); let state = TestState::from_pool(pool).await.unwrap(); let cookies = CookieHelper::new(); @@ -507,7 +524,7 @@ mod tests { let request = Request::post(&*mas_router::Register::default().path_and_query()).form( serde_json::json!({ "csrf": csrf_token, - "username": "a", + "username": "a".repeat(256), "email": "john@example.com", "password": "hunter2", "password_confirm": "hunter2", @@ -518,7 +535,11 @@ mod tests { let response = state.request(request).await; cookies.save_cookies(&response); response.assert_status(StatusCode::OK); - assert!(response.body().contains("username too short")); + assert!( + response.body().contains("Username is too long"), + "response body: {}", + response.body() + ); } /// When the user already exists in the database, it should give an error diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index 9db450fca..9ffe2f511 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -12,11 +12,12 @@ use opa_wasm::{ wasmtime::{Config, Engine, Module, OptLevel, Store}, Runtime, }; +use serde::Serialize; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt}; use self::model::{AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput}; -pub use self::model::{EvaluationResult, Violation}; +pub use self::model::{Code as ViolationCode, EvaluationResult, Violation}; use crate::model::GrantType; #[derive(Debug, Error)] @@ -69,10 +70,34 @@ impl Entrypoints { } } +#[derive(Serialize, Debug)] +pub struct Data { + server_name: String, + + #[serde(flatten)] + rest: Option, +} + +impl Data { + #[must_use] + pub fn new(server_name: String) -> Self { + Self { + server_name, + rest: None, + } + } + + #[must_use] + pub fn with_rest(mut self, rest: serde_json::Value) -> Self { + self.rest = Some(rest); + self + } +} + pub struct PolicyFactory { engine: Engine, module: Module, - data: serde_json::Value, + data: Data, entrypoints: Entrypoints, } @@ -80,7 +105,7 @@ impl PolicyFactory { #[tracing::instrument(name = "policy.load", skip(source), err)] pub async fn load( mut source: impl AsyncRead + std::marker::Unpin, - data: serde_json::Value, + data: Data, entrypoints: Entrypoints, ) -> Result { let mut config = Config::default(); @@ -364,10 +389,10 @@ mod tests { #[tokio::test] async fn test_register() { - let data = serde_json::json!({ + let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({ "allowed_domains": ["element.io", "*.element.io"], "banned_domains": ["staging.element.io"], - }); + })); #[allow(clippy::disallowed_types)] let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")) diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index c8d46599d..aebca8928 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -13,6 +13,45 @@ use mas_data_model::{Client, User}; use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope}; use serde::{Deserialize, Serialize}; +/// A well-known policy code. +#[derive(Deserialize, Debug, Clone, Copy)] +#[serde(rename_all = "kebab-case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub enum Code { + /// The username is too short. + UsernameTooShort, + + /// The username is too long. + UsernameTooLong, + + /// The username contains invalid characters. + UsernameInvalidChars, + + /// The username contains only numeric characters. + UsernameAllNumeric, + + /// The email domain is not allowed. + EmailDomainNotAllowed, + + /// The email domain is banned. + EmailDomainBanned, +} + +impl Code { + /// Returns the code as a string + #[must_use] + pub fn as_str(&self) -> &'static str { + match self { + Self::UsernameTooShort => "username-too-short", + Self::UsernameTooLong => "username-too-long", + Self::UsernameInvalidChars => "username-invalid-chars", + Self::UsernameAllNumeric => "username-all-numeric", + Self::EmailDomainNotAllowed => "email-domain-not-allowed", + Self::EmailDomainBanned => "email-domain-banned", + } + } +} + /// A single violation of a policy. #[derive(Deserialize, Debug)] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] @@ -20,6 +59,7 @@ pub struct Violation { pub msg: String, pub redirect_uri: Option, pub field: Option, + pub code: Option, } /// The result of a policy evaluation. diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index 70fc16d0f..eb10e9592 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -462,6 +462,7 @@ impl TemplateContext for LoginContext { .with_error_on_field( LoginFormField::Password, FieldError::Policy { + code: None, message: "password too short".to_owned(), }, ), diff --git a/crates/templates/src/forms.rs b/crates/templates/src/forms.rs index a018633ba..2b769e122 100644 --- a/crates/templates/src/forms.rs +++ b/crates/templates/src/forms.rs @@ -36,6 +36,9 @@ pub enum FieldError { /// Denied by the policy Policy { + /// Well-known policy code + code: Option<&'static str>, + /// Message for this policy violation message: String, }, @@ -59,6 +62,9 @@ pub enum FormError { /// Denied by the policy Policy { + /// Well-known policy code + code: Option<&'static str>, + /// Message for this policy violation message: String, }, diff --git a/policies/email/email.rego b/policies/email/email.rego index 00103da26..24b1d94b4 100644 --- a/policies/email/email.rego +++ b/policies/email/email.rego @@ -25,12 +25,12 @@ domain_allowed if { # METADATA # entrypoint: true -violation contains {"msg": "email domain is not allowed"} if { +violation contains {"code": "email-domain-not-allowed", "msg": "email domain is not allowed"} if { not domain_allowed } # Deny emails with their domain in the domains banlist -violation contains {"msg": "email domain is banned"} if { +violation contains {"code": "email-domain-banned", "msg": "email domain is banned"} if { [_, domain] := split(input.email, "@") some banned_domain in data.banned_domains glob.match(banned_domain, ["."], domain) diff --git a/policies/register/register.rego b/policies/register/register.rego index 4507a10ca..0fb36bf37 100644 --- a/policies/register/register.rego +++ b/policies/register/register.rego @@ -13,17 +13,30 @@ allow if { count(violation) == 0 } +mxid(username, server_name) := sprintf("@%s:%s", [username, server_name]) + # METADATA # entrypoint: true -violation contains {"field": "username", "msg": "username too short"} if { - count(input.username) <= 2 +violation contains {"field": "username", "code": "username-too-short", "msg": "username too short"} if { + count(input.username) == 0 +} + +violation contains {"field": "username", "code": "username-too-long", "msg": "username too long"} if { + user_id := mxid(input.username, data.server_name) + count(user_id) > 255 } -violation contains {"field": "username", "msg": "username too long"} if { - count(input.username) > 64 +violation contains { + "field": "username", "code": "username-all-numeric", + "msg": "username must contain at least one non-numeric character", +} if { + regex.match(`^[0-9]+$`, input.username) } -violation contains {"field": "username", "msg": "username contains invalid characters"} if { +violation contains { + "field": "username", "code": "username-invalid-chars", + "msg": "username contains invalid characters", +} if { not regex.match(`^[a-z0-9.=_/-]+$`, input.username) } diff --git a/policies/register/register_test.rego b/policies/register/register_test.rego index 0d270ec26..26e119248 100644 --- a/policies/register/register_test.rego +++ b/policies/register/register_test.rego @@ -42,17 +42,35 @@ test_no_email if { register.allow with input as {"username": "hello", "registration_method": "upstream-oauth2"} } -test_short_username if { - not register.allow with input as {"username": "a", "registration_method": "upstream-oauth2"} +test_empty_username if { + not register.allow with input as {"username": "", "registration_method": "upstream-oauth2"} } test_long_username if { not register.allow with input as { - "username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "username": concat("", ["a" | some x in numbers.range(1, 249)]), "registration_method": "upstream-oauth2", } + with data.server_name as "matrix.org" + + # This makes a MXID that is exactly 255 characters long + register.allow with input as { + "username": concat("", ["a" | some x in numbers.range(1, 249)]), + "registration_method": "upstream-oauth2", + } + with data.server_name as "a.io" + + not register.allow with input as { + "username": concat("", ["a" | some x in numbers.range(1, 250)]), + "registration_method": "upstream-oauth2", + } + with data.server_name as "a.io" } test_invalid_username if { not register.allow with input as {"username": "hello world", "registration_method": "upstream-oauth2"} } + +test_numeric_username if { + not register.allow with input as {"username": "1234", "registration_method": "upstream-oauth2"} +} diff --git a/templates/components/field.html b/templates/components/field.html index 10704d90c..c698c5152 100644 --- a/templates/components/field.html +++ b/templates/components/field.html @@ -61,7 +61,21 @@ {% elif error.kind == "exists" and field.name == "username" %} {{ _("mas.errors.username_taken") }} {% elif error.kind == "policy" %} - {{ _("mas.errors.denied_policy", policy=error.message) }} + {% if error.code == "username-too-short" %} + {{ _("mas.errors.username_too_short") }} + {% elif error.code == "username-too-long" %} + {{ _("mas.errors.username_too_long") }} + {% elif error.code == "username-invalid-chars" %} + {{ _("mas.errors.username_invalid_chars") }} + {% elif error.code == "username-all-numeric" %} + {{ _("mas.errors.username_all_numeric") }} + {% elif error.code == "email-domain-not-allowed" %} + {{ _("mas.errors.email_domain_not_allowed") }} + {% elif error.code == "email-domain-banned" %} + {{ _("mas.errors.email_domain_banned") }} + {% else %} + {{ _("mas.errors.denied_policy", policy=error.message) }} + {% endif %} {% elif error.kind == "password_mismatch" %} {{ _("mas.errors.password_mismatch") }} {% else %} diff --git a/translations/en.json b/translations/en.json index 4404ca4ca..e52d4f6db 100644 --- a/translations/en.json +++ b/translations/en.json @@ -284,7 +284,15 @@ }, "denied_policy": "Denied by policy: %(policy)s", "@denied_policy": { - "context": "components/errors.html:17:7-58, components/field.html:64:17-68" + "context": "components/errors.html:17:7-58, components/field.html:77:19-70" + }, + "email_domain_banned": "Email domain is banned by the server policy", + "@email_domain_banned": { + "context": "components/field.html:75:19-54" + }, + "email_domain_not_allowed": "Email domain is not allowed by the server policy", + "@email_domain_not_allowed": { + "context": "components/field.html:73:19-59" }, "field_required": "This field is required", "@field_required": { @@ -296,15 +304,31 @@ }, "password_mismatch": "Password fields don't match", "@password_mismatch": { - "context": "components/errors.html:13:7-40, components/field.html:66:17-50" + "context": "components/errors.html:13:7-40, components/field.html:80:17-50" }, "rate_limit_exceeded": "You've made too many requests in a short period. Please wait a few minutes and try again.", "@rate_limit_exceeded": { "context": "components/errors.html:15:7-42, pages/recovery/progress.html:26:11-46" }, + "username_all_numeric": "Username cannot consist solely of numbers", + "@username_all_numeric": { + "context": "components/field.html:71:19-55" + }, + "username_invalid_chars": "Username contains invalid characters. Use lowercase letters, numbers, dashes and underscores only.", + "@username_invalid_chars": { + "context": "components/field.html:69:19-57" + }, "username_taken": "This username is already taken", "@username_taken": { "context": "components/field.html:62:17-47" + }, + "username_too_long": "Username is too long", + "@username_too_long": { + "context": "components/field.html:67:19-52" + }, + "username_too_short": "Username is too short", + "@username_too_short": { + "context": "components/field.html:65:19-53" } }, "login": { @@ -377,7 +401,7 @@ }, "or_separator": "Or", "@or_separator": { - "context": "components/field.html:85:10-31", + "context": "components/field.html:99:10-31", "description": "Separator between the login methods" }, "policy_violation": {