Skip to content

Commit 00b31b8

Browse files
authored
Merge pull request #3908 from element-hq/quenting/mxid-in-login
Allow logging in with the full MXID
2 parents 931de22 + 76ba8e1 commit 00b31b8

File tree

8 files changed

+204
-46
lines changed

8 files changed

+204
-46
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ version = "0.12.12"
248248
default-features = false
249249
features = ["http2", "rustls-tls-manual-roots", "charset", "json", "socks"]
250250

251+
# Matrix-related types
252+
[workspace.dependencies.ruma-common]
253+
version = "0.15.0"
254+
251255
# TLS stack
252256
[workspace.dependencies.rustls]
253257
version = "0.23.21"

crates/data-model/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ rand.workspace = true
2424
rand_chacha = "0.3.1"
2525
regex = "1.11.1"
2626
woothee = "0.13.0"
27-
ruma-common = "0.15.0"
27+
ruma-common.workspace = true
2828

2929
mas-iana.workspace = true
3030
mas-jose.workspace = true

crates/data-model/src/oauth2/authorization_grant.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use rand::{
1717
distributions::{Alphanumeric, DistString},
1818
RngCore,
1919
};
20-
use ruma_common::{OwnedUserId, UserId};
20+
use ruma_common::UserId;
2121
use serde::Serialize;
2222
use ulid::Ulid;
2323
use url::Url;
@@ -142,8 +142,8 @@ impl AuthorizationGrantStage {
142142
}
143143
}
144144

145-
pub enum LoginHint {
146-
MXID(OwnedUserId),
145+
pub enum LoginHint<'a> {
146+
MXID(&'a UserId),
147147
None,
148148
}
149149

@@ -200,7 +200,7 @@ impl AuthorizationGrant {
200200
match prefix {
201201
"mxid" => {
202202
// Instead of erroring just return none
203-
let Ok(mxid) = UserId::parse(value) else {
203+
let Ok(mxid) = <&UserId>::try_from(value) else {
204204
return LoginHint::None;
205205
};
206206

crates/handlers/src/compat/login.rs

Lines changed: 69 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -386,10 +386,13 @@ async fn user_password_login(
386386
username: String,
387387
password: String,
388388
) -> Result<(CompatSession, User), RouteError> {
389+
// Try getting the localpart out of the MXID
390+
let username = homeserver.localpart(&username).unwrap_or(&username);
391+
389392
// Find the user
390393
let user = repo
391394
.user()
392-
.find_by_username(&username)
395+
.find_by_username(username)
393396
.await?
394397
.filter(mas_data_model::User::is_valid)
395398
.ok_or(RouteError::UserNotFound)?;
@@ -539,52 +542,43 @@ mod tests {
539542
assert_eq!(body["errcode"], "M_UNRECOGNIZED");
540543
}
541544

542-
/// Test that a user can login with a password using the Matrix
543-
/// compatibility API.
544-
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
545-
async fn test_user_password_login(pool: PgPool) {
546-
setup();
547-
let state = TestState::from_pool(pool).await.unwrap();
548-
549-
// Let's provision a user and add a password to it. This part is hard to test
550-
// with just HTTP requests, so we'll use the repository directly.
545+
async fn user_with_password(state: &TestState, username: &str, password: &str) {
546+
let mut rng = state.rng();
551547
let mut repo = state.repository().await.unwrap();
552548

553549
let user = repo
554550
.user()
555-
.add(&mut state.rng(), &state.clock, "alice".to_owned())
551+
.add(&mut rng, &state.clock, username.to_owned())
552+
.await
553+
.unwrap();
554+
let (version, hash) = state
555+
.password_manager
556+
.hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec()))
556557
.await
557558
.unwrap();
558559

560+
repo.user_password()
561+
.add(&mut rng, &state.clock, &user, version, hash, None)
562+
.await
563+
.unwrap();
559564
let mxid = state.homeserver_connection.mxid(&user.username);
560565
state
561566
.homeserver_connection
562567
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
563568
.await
564569
.unwrap();
565570

566-
let (version, hashed_password) = state
567-
.password_manager
568-
.hash(
569-
&mut state.rng(),
570-
Zeroizing::new("password".to_owned().into_bytes()),
571-
)
572-
.await
573-
.unwrap();
571+
repo.save().await.unwrap();
572+
}
574573

575-
repo.user_password()
576-
.add(
577-
&mut state.rng(),
578-
&state.clock,
579-
&user,
580-
version,
581-
hashed_password,
582-
None,
583-
)
584-
.await
585-
.unwrap();
574+
/// Test that a user can login with a password using the Matrix
575+
/// compatibility API.
576+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
577+
async fn test_user_password_login(pool: PgPool) {
578+
setup();
579+
let state = TestState::from_pool(pool).await.unwrap();
586580

587-
repo.save().await.unwrap();
581+
user_with_password(&state, "alice", "password").await;
588582

589583
// Now let's try to login with the password, without asking for a refresh token.
590584
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
@@ -662,6 +656,50 @@ mod tests {
662656
assert_eq!(body, old_body);
663657
}
664658

659+
/// Test that a user can login with a password using the Matrix
660+
/// compatibility API, using a MXID as identifier
661+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
662+
async fn test_user_password_login_mxid(pool: PgPool) {
663+
setup();
664+
let state = TestState::from_pool(pool).await.unwrap();
665+
666+
user_with_password(&state, "alice", "password").await;
667+
668+
// Login with a full MXID as identifier
669+
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
670+
"type": "m.login.password",
671+
"identifier": {
672+
"type": "m.id.user",
673+
"user": "@alice:example.com",
674+
},
675+
"password": "password",
676+
}));
677+
678+
let response = state.request(request).await;
679+
response.assert_status(StatusCode::OK);
680+
let body: ResponseBody = response.json();
681+
assert!(!body.access_token.is_empty());
682+
assert_eq!(body.device_id.as_ref().unwrap().as_str().len(), 10);
683+
assert_eq!(body.user_id, "@alice:example.com");
684+
assert_eq!(body.refresh_token, None);
685+
assert_eq!(body.expires_in_ms, None);
686+
687+
// With a MXID, but with the wrong server name
688+
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
689+
"type": "m.login.password",
690+
"identifier": {
691+
"type": "m.id.user",
692+
"user": "@alice:something.corp",
693+
},
694+
"password": "password",
695+
}));
696+
697+
let response = state.request(request).await;
698+
response.assert_status(StatusCode::FORBIDDEN);
699+
let body: serde_json::Value = response.json();
700+
assert_eq!(body["errcode"], "M_FORBIDDEN");
701+
}
702+
665703
/// Test that password logins are rate limited.
666704
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
667705
async fn test_password_login_rate_limit(pool: PgPool) {

crates/handlers/src/views/login.rs

Lines changed: 106 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,19 @@ pub(crate) async fn post(
168168
return Ok((cookie_jar, Html(content)).into_response());
169169
}
170170

171+
// Extract the localpart of the MXID, fallback to the bare username
172+
let username = homeserver
173+
.localpart(&form.username)
174+
.unwrap_or(&form.username);
175+
171176
match login(
172177
password_manager,
173178
&mut repo,
174179
rng,
175180
&clock,
176181
limiter,
177182
requester,
178-
&form.username,
183+
username,
179184
&form.password,
180185
user_agent,
181186
)
@@ -479,30 +484,34 @@ mod test {
479484
.contains(&escape_html(&second_provider_login.path_and_query())));
480485
}
481486

482-
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
483-
async fn test_password_login(pool: PgPool) {
484-
setup();
485-
let state = TestState::from_pool(pool).await.unwrap();
487+
async fn user_with_password(state: &TestState, username: &str, password: &str) {
486488
let mut rng = state.rng();
487-
let cookies = CookieHelper::new();
488-
489-
// Provision a user with a password
490489
let mut repo = state.repository().await.unwrap();
491490
let user = repo
492491
.user()
493-
.add(&mut rng, &state.clock, "john".to_owned())
492+
.add(&mut rng, &state.clock, username.to_owned())
494493
.await
495494
.unwrap();
496495
let (version, hash) = state
497496
.password_manager
498-
.hash(&mut rng, Zeroizing::new("hunter2".as_bytes().to_vec()))
497+
.hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec()))
499498
.await
500499
.unwrap();
501500
repo.user_password()
502501
.add(&mut rng, &state.clock, &user, version, hash, None)
503502
.await
504503
.unwrap();
505504
repo.save().await.unwrap();
505+
}
506+
507+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
508+
async fn test_password_login(pool: PgPool) {
509+
setup();
510+
let state = TestState::from_pool(pool).await.unwrap();
511+
let cookies = CookieHelper::new();
512+
513+
// Provision a user with a password
514+
user_with_password(&state, "john", "hunter2").await;
506515

507516
// Render the login page to get a CSRF token
508517
let request = Request::get("/login").empty();
@@ -542,6 +551,93 @@ mod test {
542551
assert!(response.body().contains("john"));
543552
}
544553

554+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
555+
async fn test_password_login_with_mxid(pool: PgPool) {
556+
setup();
557+
let state = TestState::from_pool(pool).await.unwrap();
558+
let cookies = CookieHelper::new();
559+
560+
// Provision a user with a password
561+
user_with_password(&state, "john", "hunter2").await;
562+
563+
// Render the login page to get a CSRF token
564+
let request = Request::get("/login").empty();
565+
let request = cookies.with_cookies(request);
566+
let response = state.request(request).await;
567+
cookies.save_cookies(&response);
568+
response.assert_status(StatusCode::OK);
569+
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
570+
// Extract the CSRF token from the response body
571+
let csrf_token = response
572+
.body()
573+
.split("name=\"csrf\" value=\"")
574+
.nth(1)
575+
.unwrap()
576+
.split('\"')
577+
.next()
578+
.unwrap();
579+
580+
// Submit the login form
581+
let request = Request::post("/login").form(serde_json::json!({
582+
"csrf": csrf_token,
583+
"username": "@john:example.com",
584+
"password": "hunter2",
585+
}));
586+
let request = cookies.with_cookies(request);
587+
let response = state.request(request).await;
588+
cookies.save_cookies(&response);
589+
response.assert_status(StatusCode::SEE_OTHER);
590+
591+
// Now if we get to the home page, we should see the user's username
592+
let request = Request::get("/").empty();
593+
let request = cookies.with_cookies(request);
594+
let response = state.request(request).await;
595+
cookies.save_cookies(&response);
596+
response.assert_status(StatusCode::OK);
597+
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
598+
assert!(response.body().contains("john"));
599+
}
600+
601+
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
602+
async fn test_password_login_with_mxid_wrong_server(pool: PgPool) {
603+
setup();
604+
let state = TestState::from_pool(pool).await.unwrap();
605+
let cookies = CookieHelper::new();
606+
607+
// Provision a user with a password
608+
user_with_password(&state, "john", "hunter2").await;
609+
610+
// Render the login page to get a CSRF token
611+
let request = Request::get("/login").empty();
612+
let request = cookies.with_cookies(request);
613+
let response = state.request(request).await;
614+
cookies.save_cookies(&response);
615+
response.assert_status(StatusCode::OK);
616+
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
617+
// Extract the CSRF token from the response body
618+
let csrf_token = response
619+
.body()
620+
.split("name=\"csrf\" value=\"")
621+
.nth(1)
622+
.unwrap()
623+
.split('\"')
624+
.next()
625+
.unwrap();
626+
627+
// Submit the login form
628+
let request = Request::post("/login").form(serde_json::json!({
629+
"csrf": csrf_token,
630+
"username": "@john:something.corp",
631+
"password": "hunter2",
632+
}));
633+
let request = cookies.with_cookies(request);
634+
let response = state.request(request).await;
635+
636+
// This shouldn't have worked, we're back on the login page
637+
response.assert_status(StatusCode::OK);
638+
assert!(response.body().contains("Invalid credentials"));
639+
}
640+
545641
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
546642
async fn test_password_login_rate_limit(pool: PgPool) {
547643
setup();

crates/matrix/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ workspace = true
1515
anyhow.workspace = true
1616
async-trait.workspace = true
1717
tokio.workspace = true
18+
ruma-common.workspace = true

crates/matrix/src/lib.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ mod mock;
88

99
use std::{collections::HashSet, sync::Arc};
1010

11+
use ruma_common::UserId;
12+
1113
pub use self::mock::HomeserverConnection as MockHomeserverConnection;
1214

1315
// TODO: this should probably be another error type by default
@@ -193,6 +195,22 @@ pub trait HomeserverConnection: Send + Sync {
193195
format!("@{}:{}", localpart, self.homeserver())
194196
}
195197

198+
/// Get the localpart of a Matrix ID if it has the right server name
199+
///
200+
/// Returns [`None`] if the input isn't a valid MXID, or if the server name
201+
/// doesn't match
202+
///
203+
/// # Parameters
204+
///
205+
/// * `mxid` - The MXID of the user
206+
fn localpart<'a>(&self, mxid: &'a str) -> Option<&'a str> {
207+
let mxid = <&UserId>::try_from(mxid).ok()?;
208+
if mxid.server_name() != self.homeserver() {
209+
return None;
210+
}
211+
Some(mxid.localpart())
212+
}
213+
196214
/// Query the state of a user on the homeserver.
197215
///
198216
/// # Parameters

0 commit comments

Comments
 (0)