Skip to content

Commit 2558a52

Browse files
committed
Allow longer & shorter usernames, complying with the MXID length spec
1 parent 881c6df commit 2558a52

File tree

9 files changed

+89
-26
lines changed

9 files changed

+89
-26
lines changed

crates/cli/src/commands/debug.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::process::ExitCode;
88

99
use clap::Parser;
1010
use figment::Figment;
11-
use mas_config::{ConfigurationSectionExt, PolicyConfig};
11+
use mas_config::{ConfigurationSection, ConfigurationSectionExt, MatrixConfig, PolicyConfig};
1212
use tracing::{info, info_span};
1313

1414
use crate::util::policy_factory_from_config;
@@ -33,8 +33,9 @@ impl Options {
3333
SC::Policy => {
3434
let _span = info_span!("cli.debug.policy").entered();
3535
let config = PolicyConfig::extract_or_default(figment)?;
36+
let matrix_config = MatrixConfig::extract(figment)?;
3637
info!("Loading and compiling the policy module");
37-
let policy_factory = policy_factory_from_config(&config).await?;
38+
let policy_factory = policy_factory_from_config(&config, &matrix_config).await?;
3839

3940
let _instance = policy_factory.instantiate().await?;
4041
}

crates/cli/src/commands/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ impl Options {
123123

124124
// Load and compile the WASM policies (and fallback to the default embedded one)
125125
info!("Loading and compiling the policy module");
126-
let policy_factory = policy_factory_from_config(&config.policy).await?;
126+
let policy_factory = policy_factory_from_config(&config.policy, &config.matrix).await?;
127127
let policy_factory = Arc::new(policy_factory);
128128

129129
let url_builder = UrlBuilder::new(

crates/cli/src/util.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ pub fn mailer_from_config(
101101

102102
pub async fn policy_factory_from_config(
103103
config: &PolicyConfig,
104+
matrix_config: &MatrixConfig,
104105
) -> Result<PolicyFactory, anyhow::Error> {
105106
let policy_file = tokio::fs::File::open(&config.wasm_module)
106107
.await
@@ -113,7 +114,10 @@ pub async fn policy_factory_from_config(
113114
email: config.email_entrypoint.clone(),
114115
};
115116

116-
PolicyFactory::load(policy_file, config.data.clone(), entrypoints)
117+
let data =
118+
mas_policy::Data::new(matrix_config.homeserver.clone()).with_rest(config.data.clone());
119+
120+
PolicyFactory::load(policy_file, data, entrypoints)
117121
.await
118122
.context("failed to load the policy")
119123
}

crates/handlers/src/graphql/tests.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,12 @@ async fn test_oauth2_client_credentials(pool: PgPool) {
469469
// Now make the client admin and try again
470470
let state = {
471471
let mut state = state;
472-
state.policy_factory = test_utils::policy_factory(serde_json::json!({
473-
"admin_clients": [client_id],
474-
}))
472+
state.policy_factory = test_utils::policy_factory(
473+
"example.com",
474+
serde_json::json!({
475+
"admin_clients": [client_id],
476+
}),
477+
)
475478
.await
476479
.unwrap();
477480
state
@@ -593,9 +596,12 @@ async fn test_add_user(pool: PgPool) {
593596
// Make the client admin
594597
let state = {
595598
let mut state = state;
596-
state.policy_factory = test_utils::policy_factory(serde_json::json!({
597-
"admin_clients": [client_id],
598-
}))
599+
state.policy_factory = test_utils::policy_factory(
600+
"example.com",
601+
serde_json::json!({
602+
"admin_clients": [client_id],
603+
}),
604+
)
599605
.await
600606
.unwrap();
601607
state

crates/handlers/src/oauth2/token.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,9 +1475,12 @@ mod tests {
14751475
// Now, if we add the client to the admin list in the policy, it should work
14761476
let state = {
14771477
let mut state = state;
1478-
state.policy_factory = crate::test_utils::policy_factory(serde_json::json!({
1479-
"admin_clients": [client_id]
1480-
}))
1478+
state.policy_factory = crate::test_utils::policy_factory(
1479+
"example.com",
1480+
serde_json::json!({
1481+
"admin_clients": [client_id]
1482+
}),
1483+
)
14811484
.await
14821485
.unwrap();
14831486
state

crates/handlers/src/test_utils.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ pub(crate) fn setup() {
6969
}
7070

7171
pub(crate) async fn policy_factory(
72+
server_name: &str,
7273
data: serde_json::Value,
7374
) -> Result<Arc<PolicyFactory>, anyhow::Error> {
7475
let workspace_root = camino::Utf8Path::new(env!("CARGO_MANIFEST_DIR"))
@@ -84,6 +85,8 @@ pub(crate) async fn policy_factory(
8485
email: "email/violation".to_owned(),
8586
};
8687

88+
let data = mas_policy::Data::new(server_name.to_owned()).with_rest(data);
89+
8790
let policy_factory = PolicyFactory::load(file, data, entrypoints).await?;
8891
let policy_factory = Arc::new(policy_factory);
8992
Ok(policy_factory)
@@ -192,7 +195,8 @@ impl TestState {
192195
PasswordManager::disabled()
193196
};
194197

195-
let policy_factory = policy_factory(serde_json::json!({})).await?;
198+
let policy_factory =
199+
policy_factory(&site_config.server_name, serde_json::json!({})).await?;
196200

197201
let homeserver_connection =
198202
Arc::new(MockHomeserverConnection::new(&site_config.server_name));
@@ -297,9 +301,12 @@ impl TestState {
297301
// Make the client admin
298302
let state = {
299303
let mut state = self.clone();
300-
state.policy_factory = policy_factory(serde_json::json!({
301-
"admin_clients": [client_id],
302-
}))
304+
state.policy_factory = policy_factory(
305+
"example.com",
306+
serde_json::json!({
307+
"admin_clients": [client_id],
308+
}),
309+
)
303310
.await
304311
.unwrap();
305312
state

crates/policy/src/lib.rs

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use opa_wasm::{
1212
wasmtime::{Config, Engine, Module, OptLevel, Store},
1313
Runtime,
1414
};
15+
use serde::Serialize;
1516
use thiserror::Error;
1617
use tokio::io::{AsyncRead, AsyncReadExt};
1718

@@ -69,18 +70,42 @@ impl Entrypoints {
6970
}
7071
}
7172

73+
#[derive(Serialize, Debug)]
74+
pub struct Data {
75+
server_name: String,
76+
77+
#[serde(flatten)]
78+
rest: Option<serde_json::Value>,
79+
}
80+
81+
impl Data {
82+
#[must_use]
83+
pub fn new(server_name: String) -> Self {
84+
Self {
85+
server_name,
86+
rest: None,
87+
}
88+
}
89+
90+
#[must_use]
91+
pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
92+
self.rest = Some(rest);
93+
self
94+
}
95+
}
96+
7297
pub struct PolicyFactory {
7398
engine: Engine,
7499
module: Module,
75-
data: serde_json::Value,
100+
data: Data,
76101
entrypoints: Entrypoints,
77102
}
78103

79104
impl PolicyFactory {
80105
#[tracing::instrument(name = "policy.load", skip(source), err)]
81106
pub async fn load(
82107
mut source: impl AsyncRead + std::marker::Unpin,
83-
data: serde_json::Value,
108+
data: Data,
84109
entrypoints: Entrypoints,
85110
) -> Result<Self, LoadError> {
86111
let mut config = Config::default();
@@ -364,10 +389,10 @@ mod tests {
364389

365390
#[tokio::test]
366391
async fn test_register() {
367-
let data = serde_json::json!({
392+
let data = Data::new("example.com".to_owned()).with_rest(serde_json::json!({
368393
"allowed_domains": ["element.io", "*.element.io"],
369394
"banned_domains": ["staging.element.io"],
370-
});
395+
}));
371396

372397
#[allow(clippy::disallowed_types)]
373398
let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))

policies/register/register.rego

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,17 @@ allow if {
1313
count(violation) == 0
1414
}
1515

16+
mxid(username, server_name) := sprintf("@%s:%s", [username, server_name])
17+
1618
# METADATA
1719
# entrypoint: true
1820
violation contains {"field": "username", "msg": "username too short"} if {
19-
count(input.username) <= 2
21+
count(input.username) == 0
2022
}
2123

2224
violation contains {"field": "username", "msg": "username too long"} if {
23-
count(input.username) > 64
25+
user_id := mxid(input.username, data.server_name)
26+
count(user_id) > 255
2427
}
2528

2629
violation contains {"field": "username", "msg": "username contains invalid characters"} if {

policies/register/register_test.rego

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,29 @@ test_no_email if {
4242
register.allow with input as {"username": "hello", "registration_method": "upstream-oauth2"}
4343
}
4444

45-
test_short_username if {
46-
not register.allow with input as {"username": "a", "registration_method": "upstream-oauth2"}
45+
test_empty_username if {
46+
not register.allow with input as {"username": "", "registration_method": "upstream-oauth2"}
4747
}
4848

4949
test_long_username if {
5050
not register.allow with input as {
51-
"username": "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
51+
"username": concat("", ["a" | some x in numbers.range(1, 249)]),
5252
"registration_method": "upstream-oauth2",
5353
}
54+
with data.server_name as "matrix.org"
55+
56+
# This makes a MXID that is exactly 255 characters long
57+
register.allow with input as {
58+
"username": concat("", ["a" | some x in numbers.range(1, 249)]),
59+
"registration_method": "upstream-oauth2",
60+
}
61+
with data.server_name as "a.io"
62+
63+
not register.allow with input as {
64+
"username": concat("", ["a" | some x in numbers.range(1, 250)]),
65+
"registration_method": "upstream-oauth2",
66+
}
67+
with data.server_name as "a.io"
5468
}
5569

5670
test_invalid_username if {

0 commit comments

Comments
 (0)