Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions crates/cli/src/commands/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::process::ExitCode;

use clap::Parser;
use figment::Figment;
use mas_config::{ConfigurationSectionExt, PolicyConfig};
use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig};
use tracing::{info, info_span};

use crate::util::policy_factory_from_config;
Expand All @@ -33,8 +33,9 @@ impl Options {
SC::Policy => {
let _span = info_span!("cli.debug.policy").entered();
let config = PolicyConfig::extract_or_default(figment)?;
let matrix_config = MatrixConfig::extract(figment)?;
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config).await?;
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;

let _instance = policy_factory.instantiate().await?;
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl Options {

// Load and compile the WASM policies (and fallback to the default embedded one)
info!("Loading and compiling the policy module");
let policy_factory = policy_factory_from_config(&config.policy).await?;
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
let policy_factory = Arc::new(policy_factory);

let url_builder = UrlBuilder::new(
Expand Down
6 changes: 5 additions & 1 deletion crates/cli/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ pub fn mailer_from_config(

pub async fn policy_factory_from_config(
config: &PolicyConfig,
matrix_config: &MatrixConfig,
) -> Result<PolicyFactory, anyhow::Error> {
let policy_file = tokio::fs::File::open(&config.wasm_module)
.await
Expand All @@ -113,7 +114,10 @@ pub async fn policy_factory_from_config(
email: config.email_entrypoint.clone(),
};

PolicyFactory::load(policy_file, config.data.clone(), entrypoints)
let data =
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());

PolicyFactory::load(policy_file, data, entrypoints)
.await
.context("failed to load the policy")
}
Expand Down
18 changes: 12 additions & 6 deletions crates/handlers/src/graphql/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,12 @@ async fn test_oauth2_client_credentials(pool: PgPool) {
// Now make the client admin and try again
let state = {
let mut state = state;
state.policy_factory = test_utils::policy_factory(serde_json::json!({
"admin_clients": [client_id],
}))
state.policy_factory = test_utils::policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id],
}),
)
.await
.unwrap();
state
Expand Down Expand Up @@ -593,9 +596,12 @@ async fn test_add_user(pool: PgPool) {
// Make the client admin
let state = {
let mut state = state;
state.policy_factory = test_utils::policy_factory(serde_json::json!({
"admin_clients": [client_id],
}))
state.policy_factory = test_utils::policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id],
}),
)
.await
.unwrap();
state
Expand Down
9 changes: 6 additions & 3 deletions crates/handlers/src/oauth2/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1475,9 +1475,12 @@ mod tests {
// Now, if we add the client to the admin list in the policy, it should work
let state = {
let mut state = state;
state.policy_factory = crate::test_utils::policy_factory(serde_json::json!({
"admin_clients": [client_id]
}))
state.policy_factory = crate::test_utils::policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id]
}),
)
.await
.unwrap();
state
Expand Down
15 changes: 11 additions & 4 deletions crates/handlers/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ pub(crate) fn setup() {
}

pub(crate) async fn policy_factory(
server_name: &str,
data: serde_json::Value,
) -> Result<Arc<PolicyFactory>, anyhow::Error> {
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
Expand All @@ -84,6 +85,8 @@ pub(crate) async fn policy_factory(
email: "email/violation".to_owned(),
};

let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);

let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
let policy_factory = Arc::new(policy_factory);
Ok(policy_factory)
Expand Down Expand Up @@ -192,7 +195,8 @@ impl TestState {
PasswordManager::disabled()
};

let policy_factory = policy_factory(serde_json::json!({})).await?;
let policy_factory =
policy_factory(&site_config.server_name, serde_json::json!({})).await?;

let homeserver_connection =
Arc::new(MockHomeserverConnection::new(&site_config.server_name));
Expand Down Expand Up @@ -297,9 +301,12 @@ impl TestState {
// Make the client admin
let state = {
let mut state = self.clone();
state.policy_factory = policy_factory(serde_json::json!({
"admin_clients": [client_id],
}))
state.policy_factory = policy_factory(
"example.com",
serde_json::json!({
"admin_clients": [client_id],
}),
)
.await
.unwrap();
state
Expand Down
2 changes: 2 additions & 0 deletions crates/handlers/src/upstream_oauth2/link.rs
Original file line number Diff line number Diff line change
Expand Up @@ -759,10 +759,12 @@ pub(crate) async fn post(
Some("username") => form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
},
),
_ => form_state.with_error_on_form(FormError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
}),
}
Expand Down
45 changes: 33 additions & 12 deletions crates/handlers/src/views/register.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,22 @@ pub(crate) async fn post(
state.add_error_on_form(FormError::Captcha);
}

let mut homeserver_denied_username = false;
if form.username.is_empty() {
state.add_error_on_field(RegisterFormField::Username, FieldError::Required);
} else if repo.user().exists(&form.username).await? {
// The user already exists in the database
state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
} else if !homeserver.is_localpart_available(&form.username).await? {
// The user already exists on the homeserver
// XXX: we may want to return different errors like "this username is reserved"
tracing::warn!(
username = &form.username,
"User tried to register with a reserved username"
"Homeserver denied username provided by user"
);

state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
// We defer adding the error on the field, until we know whether we had another
// error from the policy, to avoid showing both
homeserver_denied_username = true;
}

if form.email.is_empty() {
Expand Down Expand Up @@ -197,6 +199,7 @@ pub(crate) async fn post(
state.add_error_on_field(
RegisterFormField::Password,
FieldError::Policy {
code: None,
message: "Password is too weak".to_owned(),
},
);
Expand All @@ -216,27 +219,41 @@ pub(crate) async fn post(
Some("email") => state.add_error_on_field(
RegisterFormField::Email,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
},
),
Some("username") => state.add_error_on_field(
RegisterFormField::Username,
FieldError::Policy {
message: violation.msg,
},
),
Some("username") => {
// If the homeserver denied the username, but we also had an error on the policy
// side, we don't want to show both, so we reset the state here
homeserver_denied_username = false;
state.add_error_on_field(
RegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
},
);
}
Some("password") => state.add_error_on_field(
RegisterFormField::Password,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
},
),
_ => state.add_error_on_form(FormError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
}),
}
}

if homeserver_denied_username {
// XXX: we may want to return different errors like "this username is reserved"
state.add_error_on_field(RegisterFormField::Username, FieldError::Exists);
}

if state.is_valid() {
// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_registration(requester) {
Expand Down Expand Up @@ -481,7 +498,7 @@ mod tests {
}

#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_register_username_too_short(pool: PgPool) {
async fn test_register_username_too_long(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();
Expand All @@ -507,7 +524,7 @@ mod tests {
let request = Request::post(&*mas_router::Register::default().path_and_query()).form(
serde_json::json!({
"csrf": csrf_token,
"username": "a",
"username": "a".repeat(256),
"email": "[email protected]",
"password": "hunter2",
"password_confirm": "hunter2",
Expand All @@ -518,7 +535,11 @@ mod tests {
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
assert!(response.body().contains("username too short"));
assert!(
response.body().contains("Username is too long"),
"response body: {}",
response.body()
);
}

/// When the user already exists in the database, it should give an error
Expand Down
35 changes: 30 additions & 5 deletions crates/policy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ use opa_wasm::{
wasmtime::{Config, Engine, Module, OptLevel, Store},
Runtime,
};
use serde::Serialize;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt};

use self::model::{AuthorizationGrantInput, ClientRegistrationInput, EmailInput, RegisterInput};
pub use self::model::{EvaluationResult, Violation};
pub use self::model::{Code as ViolationCode, EvaluationResult, Violation};
use crate::model::GrantType;

#[derive(Debug, Error)]
Expand Down Expand Up @@ -69,18 +70,42 @@ impl Entrypoints {
}
}

#[derive(Serialize, Debug)]
pub struct Data {
server_name: String,

#[serde(flatten)]
rest: Option<serde_json::Value>,
}

impl Data {
#[must_use]
pub fn new(server_name: String) -> Self {
Self {
server_name,
rest: None,
}
}

#[must_use]
pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
self.rest = Some(rest);
self
}
}

pub struct PolicyFactory {
engine: Engine,
module: Module,
data: serde_json::Value,
data: Data,
entrypoints: Entrypoints,
}

impl PolicyFactory {
#[tracing::instrument(name = "policy.load", skip(source), err)]
pub async fn load(
mut source: impl AsyncRead + std::marker::Unpin,
data: serde_json::Value,
data: Data,
entrypoints: Entrypoints,
) -> Result<Self, LoadError> {
let mut config = Config::default();
Expand Down Expand Up @@ -364,10 +389,10 @@ mod tests {

#[tokio::test]
async fn test_register() {
let data = serde_json::json!({
let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
"allowed_domains": ["element.io", "*.element.io"],
"banned_domains": ["staging.element.io"],
});
}));

#[allow(clippy::disallowed_types)]
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
Expand Down
40 changes: 40 additions & 0 deletions crates/policy/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,53 @@ use mas_data_model::{Client, User};
use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
use serde::{Deserialize, Serialize};

/// A well-known policy code.
#[derive(Deserialize, Debug, Clone, Copy)]
#[serde(rename_all = "kebab-case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub enum Code {
/// The username is too short.
UsernameTooShort,

/// The username is too long.
UsernameTooLong,

/// The username contains invalid characters.
UsernameInvalidChars,

/// The username contains only numeric characters.
UsernameAllNumeric,

/// The email domain is not allowed.
EmailDomainNotAllowed,

/// The email domain is banned.
EmailDomainBanned,
}

impl Code {
/// Returns the code as a string
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
Self::UsernameTooShort => "username-too-short",
Self::UsernameTooLong => "username-too-long",
Self::UsernameInvalidChars => "username-invalid-chars",
Self::UsernameAllNumeric => "username-all-numeric",
Self::EmailDomainNotAllowed => "email-domain-not-allowed",
Self::EmailDomainBanned => "email-domain-banned",
}
}
}

/// A single violation of a policy.
#[derive(Deserialize, Debug)]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub struct Violation {
pub msg: String,
pub redirect_uri: Option<String>,
pub field: Option<String>,
pub code: Option<Code>,
}

/// The result of a policy evaluation.
Expand Down
1 change: 1 addition & 0 deletions crates/templates/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ impl TemplateContext for LoginContext {
.with_error_on_field(
LoginFormField::Password,
FieldError::Policy {
code: None,
message: "password too short".to_owned(),
},
),
Expand Down
Loading
Loading