Skip to content

Commit 1f2eccc

Browse files
committed
compat login (sso): support using client-provided device_id
1 parent 8ab4b7e commit 1f2eccc

16 files changed

+287
-177
lines changed

crates/data-model/src/compat/sso_login.rs

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,20 @@ use ulid::Ulid;
1010
use url::Url;
1111

1212
use super::CompatSession;
13-
use crate::InvalidTransitionError;
13+
use crate::{BrowserSession, InvalidTransitionError};
1414

1515
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
1616
pub enum CompatSsoLoginState {
1717
#[default]
1818
Pending,
1919
Fulfilled {
2020
fulfilled_at: DateTime<Utc>,
21-
session_id: Ulid,
21+
browser_session_id: Ulid,
2222
},
2323
Exchanged {
2424
fulfilled_at: DateTime<Utc>,
2525
exchanged_at: DateTime<Utc>,
26-
session_id: Ulid,
26+
compat_session_id: Ulid,
2727
},
2828
}
2929

@@ -80,18 +80,20 @@ impl CompatSsoLoginState {
8080
}
8181
}
8282

83-
/// Get the session ID associated with the login.
83+
/// Get the compat session ID associated with the login.
8484
///
85-
/// Returns `None` if the compat SSO login state is [`Pending`].
85+
/// Returns `None` if the compat SSO login state is [`Pending`] or
86+
/// [`Fulfilled`].
8687
///
8788
/// [`Pending`]: CompatSsoLoginState::Pending
8889
#[must_use]
8990
pub fn session_id(&self) -> Option<Ulid> {
9091
match self {
91-
Self::Pending => None,
92-
Self::Fulfilled { session_id, .. } | Self::Exchanged { session_id, .. } => {
93-
Some(*session_id)
94-
}
92+
Self::Pending | Self::Fulfilled { .. } => None,
93+
Self::Exchanged {
94+
compat_session_id: session_id,
95+
..
96+
} => Some(*session_id),
9597
}
9698
}
9799

@@ -106,12 +108,12 @@ impl CompatSsoLoginState {
106108
pub fn fulfill(
107109
self,
108110
fulfilled_at: DateTime<Utc>,
109-
session: &CompatSession,
111+
browser_session: &BrowserSession,
110112
) -> Result<Self, InvalidTransitionError> {
111113
match self {
112114
Self::Pending => Ok(Self::Fulfilled {
113115
fulfilled_at,
114-
session_id: session.id,
116+
browser_session_id: browser_session.id,
115117
}),
116118
Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError),
117119
}
@@ -126,15 +128,19 @@ impl CompatSsoLoginState {
126128
///
127129
/// [`Fulfilled`]: CompatSsoLoginState::Fulfilled
128130
/// [`Exchanged`]: CompatSsoLoginState::Exchanged
129-
pub fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
131+
pub fn exchange(
132+
self,
133+
exchanged_at: DateTime<Utc>,
134+
compat_session: &CompatSession,
135+
) -> Result<Self, InvalidTransitionError> {
130136
match self {
131137
Self::Fulfilled {
132138
fulfilled_at,
133-
session_id,
139+
browser_session_id: _,
134140
} => Ok(Self::Exchanged {
135141
fulfilled_at,
136142
exchanged_at,
137-
session_id,
143+
compat_session_id: compat_session.id,
138144
}),
139145
Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError),
140146
}
@@ -171,9 +177,9 @@ impl CompatSsoLogin {
171177
pub fn fulfill(
172178
mut self,
173179
fulfilled_at: DateTime<Utc>,
174-
session: &CompatSession,
180+
browser_session: &BrowserSession,
175181
) -> Result<Self, InvalidTransitionError> {
176-
self.state = self.state.fulfill(fulfilled_at, session)?;
182+
self.state = self.state.fulfill(fulfilled_at, browser_session)?;
177183
Ok(self)
178184
}
179185

@@ -186,8 +192,12 @@ impl CompatSsoLogin {
186192
///
187193
/// [`Fulfilled`]: CompatSsoLoginState::Fulfilled
188194
/// [`Exchanged`]: CompatSsoLoginState::Exchanged
189-
pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
190-
self.state = self.state.exchange(exchanged_at)?;
195+
pub fn exchange(
196+
mut self,
197+
exchanged_at: DateTime<Utc>,
198+
compat_session: &CompatSession,
199+
) -> Result<Self, InvalidTransitionError> {
200+
self.state = self.state.exchange(exchanged_at, compat_session)?;
191201
Ok(self)
192202
}
193203
}

crates/handlers/src/compat/login.rs

Lines changed: 87 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,6 @@ pub enum RouteError {
170170
#[error("user not found")]
171171
UserNotFound,
172172

173-
#[error("session not found")]
174-
SessionNotFound,
175-
176173
#[error("user has no password")]
177174
NoPassword,
178175

@@ -201,13 +198,11 @@ impl IntoResponse for RouteError {
201198
fn into_response(self) -> axum::response::Response {
202199
let event_id = sentry::capture_error(&self);
203200
let response = match self {
204-
Self::Internal(_) | Self::SessionNotFound | Self::ProvisionDeviceFailed(_) => {
205-
MatrixError {
206-
errcode: "M_UNKNOWN",
207-
error: "Internal server error",
208-
status: StatusCode::INTERNAL_SERVER_ERROR,
209-
}
210-
}
201+
Self::Internal(_) | Self::ProvisionDeviceFailed(_) => MatrixError {
202+
errcode: "M_UNKNOWN",
203+
error: "Internal server error",
204+
status: StatusCode::INTERNAL_SERVER_ERROR,
205+
},
211206
Self::RateLimited(_) => MatrixError {
212207
errcode: "M_LIMIT_EXCEEDED",
213208
error: "Too many login attempts",
@@ -323,7 +318,17 @@ pub(crate) async fn post(
323318
.await?
324319
}
325320

326-
(_, Credentials::Token { token }) => token_login(&mut repo, &clock, &token).await?,
321+
(_, Credentials::Token { token }) => {
322+
token_login(
323+
&mut repo,
324+
&clock,
325+
&token,
326+
input.device_id,
327+
&homeserver,
328+
&mut rng,
329+
)
330+
.await?
331+
}
327332

328333
_ => {
329334
return Err(RouteError::Unsupported);
@@ -382,6 +387,9 @@ async fn token_login(
382387
repo: &mut BoxRepository,
383388
clock: &dyn Clock,
384389
token: &str,
390+
requested_device_id: Option<String>,
391+
homeserver: &dyn HomeserverConnection,
392+
rng: &mut (dyn RngCore + Send),
385393
) -> Result<(CompatSession, User), RouteError> {
386394
let login = repo
387395
.compat_sso_login()
@@ -390,7 +398,7 @@ async fn token_login(
390398
.ok_or(RouteError::InvalidLoginToken)?;
391399

392400
let now = clock.now();
393-
let session_id = match login.state {
401+
let browser_session_id = match login.state {
394402
CompatSsoLoginState::Pending => {
395403
tracing::error!(
396404
compat_sso_login.id = %login.id,
@@ -400,25 +408,25 @@ async fn token_login(
400408
}
401409
CompatSsoLoginState::Fulfilled {
402410
fulfilled_at,
403-
session_id,
411+
browser_session_id,
404412
..
405413
} => {
406414
if now > fulfilled_at + Duration::microseconds(30 * 1000 * 1000) {
407415
return Err(RouteError::LoginTookTooLong);
408416
}
409417

410-
session_id
418+
browser_session_id
411419
}
412420
CompatSsoLoginState::Exchanged {
413421
exchanged_at,
414-
session_id,
422+
compat_session_id,
415423
..
416424
} => {
417425
if now > exchanged_at + Duration::microseconds(30 * 1000 * 1000) {
418426
// TODO: log that session out
419427
tracing::error!(
420428
compat_sso_login.id = %login.id,
421-
compat_session.id = %session_id,
429+
compat_session.id = %compat_session_id,
422430
"Login token exchanged a second time more than 30s after"
423431
);
424432
}
@@ -427,22 +435,56 @@ async fn token_login(
427435
}
428436
};
429437

430-
let session = repo
431-
.compat_session()
432-
.lookup(session_id)
433-
.await?
434-
.ok_or(RouteError::SessionNotFound)?;
438+
let Some(browser_session) = repo.browser_session().lookup(browser_session_id).await? else {
439+
tracing::error!(
440+
compat_sso_login.id = %login.id,
441+
browser_session.id = %browser_session_id,
442+
"Attempt to exchange login token but no associated browser session found"
443+
);
444+
return Err(RouteError::InvalidLoginToken);
445+
};
446+
if !browser_session.active() || !browser_session.user.is_valid() {
447+
tracing::info!(
448+
compat_sso_login.id = %login.id,
449+
browser_session.id = %browser_session_id,
450+
"Attempt to exchange login token but browser session is not active"
451+
);
452+
return Err(RouteError::InvalidLoginToken);
453+
}
435454

436-
let user = repo
437-
.user()
438-
.lookup(session.user_id)
439-
.await?
440-
.filter(mas_data_model::User::is_valid)
441-
.ok_or(RouteError::UserNotFound)?;
455+
// Lock the user sync to make sure we don't get into a race condition
456+
repo.user()
457+
.acquire_lock_for_sync(&browser_session.user)
458+
.await?;
442459

443-
repo.compat_sso_login().exchange(clock, login).await?;
460+
let device = if let Some(requested_device_id) = requested_device_id {
461+
Device::from(requested_device_id)
462+
} else {
463+
Device::generate(rng)
464+
};
465+
let mxid = homeserver.mxid(&browser_session.user.username);
466+
homeserver
467+
.create_device(&mxid, device.as_str())
468+
.await
469+
.map_err(RouteError::ProvisionDeviceFailed)?;
444470

445-
Ok((session, user))
471+
let compat_session = repo
472+
.compat_session()
473+
.add(
474+
rng,
475+
clock,
476+
&browser_session.user,
477+
device,
478+
Some(&browser_session),
479+
false,
480+
)
481+
.await?;
482+
483+
repo.compat_sso_login()
484+
.exchange(clock, login, &compat_session)
485+
.await?;
486+
487+
Ok((compat_session, browser_session.user))
446488
}
447489

448490
async fn user_password_login(
@@ -1015,7 +1057,7 @@ mod tests {
10151057
}
10161058
"###);
10171059

1018-
let (device, token) = get_login_token(&state, &user).await;
1060+
let token = get_login_token(&state, &user).await;
10191061

10201062
// Try to login with the token.
10211063
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
@@ -1026,14 +1068,13 @@ mod tests {
10261068
response.assert_status(StatusCode::OK);
10271069

10281070
let body: serde_json::Value = response.json();
1029-
insta::assert_json_snapshot!(body, @r###"
1071+
insta::assert_json_snapshot!(body, @r#"
10301072
{
1031-
"access_token": "mct_uihy4bk51gxgUbUTa4XIh92RARTPTj_xADEE4",
1032-
"device_id": "Yp7FM44zJN",
1073+
"access_token": "mct_bnkWh1tPmm1MZOpygPaXwygX8PfxEY_hE6do1",
1074+
"device_id": "O3Ju1MUh3Z",
10331075
"user_id": "@alice:example.com"
10341076
}
1035-
"###);
1036-
assert_eq!(body["device_id"], device.to_string());
1077+
"#);
10371078

10381079
// Try again with the same token, it should fail.
10391080
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
@@ -1051,7 +1092,7 @@ mod tests {
10511092
"###);
10521093

10531094
// Try to login, but wait too long before sending the request.
1054-
let (_device, token) = get_login_token(&state, &user).await;
1095+
let token = get_login_token(&state, &user).await;
10551096

10561097
// Advance the clock to make the token expire.
10571098
state
@@ -1079,14 +1120,13 @@ mod tests {
10791120
/// # Panics
10801121
///
10811122
/// Panics if the repository fails.
1082-
async fn get_login_token(state: &TestState, user: &User) -> (Device, String) {
1123+
async fn get_login_token(state: &TestState, user: &User) -> String {
10831124
// XXX: This is a bit manual, but this is what basically the SSO login flow
10841125
// does.
10851126
let mut repo = state.repository().await.unwrap();
10861127

1087-
// Generate a device and a token randomly
1128+
// Generate a token randomly
10881129
let token = Alphanumeric.sample_string(&mut state.rng(), 32);
1089-
let device = Device::generate(&mut state.rng());
10901130

10911131
// Start a compat SSO login flow
10921132
let login = repo
@@ -1100,27 +1140,20 @@ mod tests {
11001140
.await
11011141
.unwrap();
11021142

1103-
// Complete the flow by fulfilling it with a session
1104-
let compat_session = repo
1105-
.compat_session()
1106-
.add(
1107-
&mut state.rng(),
1108-
&state.clock,
1109-
user,
1110-
device.clone(),
1111-
None,
1112-
false,
1113-
)
1143+
// Advance the flow by fulfilling it with a browser session
1144+
let browser_session = repo
1145+
.browser_session()
1146+
.add(&mut state.rng(), &state.clock, user, None)
11141147
.await
11151148
.unwrap();
1116-
1117-
repo.compat_sso_login()
1118-
.fulfill(&state.clock, login, &compat_session)
1149+
let _login = repo
1150+
.compat_sso_login()
1151+
.fulfill(&state.clock, login, &browser_session)
11191152
.await
11201153
.unwrap();
11211154

11221155
repo.save().await.unwrap();
11231156

1124-
(device, token)
1157+
token
11251158
}
11261159
}

0 commit comments

Comments
 (0)