diff --git a/crates/data-model/src/compat/sso_login.rs b/crates/data-model/src/compat/sso_login.rs index fa894146d..448601a92 100644 --- a/crates/data-model/src/compat/sso_login.rs +++ b/crates/data-model/src/compat/sso_login.rs @@ -10,7 +10,7 @@ use ulid::Ulid; use url::Url; use super::CompatSession; -use crate::InvalidTransitionError; +use crate::{BrowserSession, InvalidTransitionError}; #[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)] pub enum CompatSsoLoginState { @@ -18,12 +18,12 @@ pub enum CompatSsoLoginState { Pending, Fulfilled { fulfilled_at: DateTime, - session_id: Ulid, + browser_session_id: Ulid, }, Exchanged { fulfilled_at: DateTime, exchanged_at: DateTime, - session_id: Ulid, + compat_session_id: Ulid, }, } @@ -80,18 +80,21 @@ impl CompatSsoLoginState { } } - /// Get the session ID associated with the login. + /// Get the compat session ID associated with the login. /// - /// Returns `None` if the compat SSO login state is [`Pending`]. + /// Returns `None` if the compat SSO login state is [`Pending`] or + /// [`Fulfilled`]. /// /// [`Pending`]: CompatSsoLoginState::Pending + /// [`Fulfilled`]: CompatSsoLoginState::Fulfilled #[must_use] pub fn session_id(&self) -> Option { match self { - Self::Pending => None, - Self::Fulfilled { session_id, .. } | Self::Exchanged { session_id, .. } => { - Some(*session_id) - } + Self::Pending | Self::Fulfilled { .. } => None, + Self::Exchanged { + compat_session_id: session_id, + .. + } => Some(*session_id), } } @@ -106,12 +109,12 @@ impl CompatSsoLoginState { pub fn fulfill( self, fulfilled_at: DateTime, - session: &CompatSession, + browser_session: &BrowserSession, ) -> Result { match self { Self::Pending => Ok(Self::Fulfilled { fulfilled_at, - session_id: session.id, + browser_session_id: browser_session.id, }), Self::Fulfilled { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), } @@ -126,15 +129,19 @@ impl CompatSsoLoginState { /// /// [`Fulfilled`]: CompatSsoLoginState::Fulfilled /// [`Exchanged`]: CompatSsoLoginState::Exchanged - pub fn exchange(self, exchanged_at: DateTime) -> Result { + pub fn exchange( + self, + exchanged_at: DateTime, + compat_session: &CompatSession, + ) -> Result { match self { Self::Fulfilled { fulfilled_at, - session_id, + browser_session_id: _, } => Ok(Self::Exchanged { fulfilled_at, exchanged_at, - session_id, + compat_session_id: compat_session.id, }), Self::Pending { .. } | Self::Exchanged { .. } => Err(InvalidTransitionError), } @@ -171,9 +178,9 @@ impl CompatSsoLogin { pub fn fulfill( mut self, fulfilled_at: DateTime, - session: &CompatSession, + browser_session: &BrowserSession, ) -> Result { - self.state = self.state.fulfill(fulfilled_at, session)?; + self.state = self.state.fulfill(fulfilled_at, browser_session)?; Ok(self) } @@ -186,8 +193,12 @@ impl CompatSsoLogin { /// /// [`Fulfilled`]: CompatSsoLoginState::Fulfilled /// [`Exchanged`]: CompatSsoLoginState::Exchanged - pub fn exchange(mut self, exchanged_at: DateTime) -> Result { - self.state = self.state.exchange(exchanged_at)?; + pub fn exchange( + mut self, + exchanged_at: DateTime, + compat_session: &CompatSession, + ) -> Result { + self.state = self.state.exchange(exchanged_at, compat_session)?; Ok(self) } } diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 7c1a6d97a..3040d758c 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -103,6 +103,14 @@ pub struct RequestBody { #[serde(default)] refresh_token: bool, + + /// ID of the client device. + /// If this does not correspond to a known client device, a new device will + /// be created. The given device ID must not be the same as a + /// cross-signing key ID. The server will auto-generate a `device_id` if + /// this is not specified. + #[serde(default, skip_serializing_if = "Option::is_none")] + device_id: Option, } #[derive(Debug, Serialize, Deserialize)] @@ -162,9 +170,6 @@ pub enum RouteError { #[error("user not found")] UserNotFound, - #[error("session not found")] - SessionNotFound, - #[error("user has no password")] NoPassword, @@ -193,13 +198,11 @@ impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { let event_id = sentry::capture_error(&self); let response = match self { - Self::Internal(_) | Self::SessionNotFound | Self::ProvisionDeviceFailed(_) => { - MatrixError { - errcode: "M_UNKNOWN", - error: "Internal server error", - status: StatusCode::INTERNAL_SERVER_ERROR, - } - } + Self::Internal(_) | Self::ProvisionDeviceFailed(_) => MatrixError { + errcode: "M_UNKNOWN", + error: "Internal server error", + status: StatusCode::INTERNAL_SERVER_ERROR, + }, Self::RateLimited(_) => MatrixError { errcode: "M_LIMIT_EXCEEDED", error: "Too many login attempts", @@ -310,11 +313,22 @@ pub(crate) async fn post( &homeserver, user, password, + input.device_id, // TODO check for validity ) .await? } - (_, Credentials::Token { token }) => token_login(&mut repo, &clock, &token).await?, + (_, Credentials::Token { token }) => { + token_login( + &mut repo, + &clock, + &token, + input.device_id, + &homeserver, + &mut rng, + ) + .await? + } _ => { return Err(RouteError::Unsupported); @@ -373,6 +387,9 @@ async fn token_login( repo: &mut BoxRepository, clock: &dyn Clock, token: &str, + requested_device_id: Option, + homeserver: &dyn HomeserverConnection, + rng: &mut (dyn RngCore + Send), ) -> Result<(CompatSession, User), RouteError> { let login = repo .compat_sso_login() @@ -381,7 +398,7 @@ async fn token_login( .ok_or(RouteError::InvalidLoginToken)?; let now = clock.now(); - let session_id = match login.state { + let browser_session_id = match login.state { CompatSsoLoginState::Pending => { tracing::error!( compat_sso_login.id = %login.id, @@ -391,25 +408,25 @@ async fn token_login( } CompatSsoLoginState::Fulfilled { fulfilled_at, - session_id, + browser_session_id, .. } => { if now > fulfilled_at + Duration::microseconds(30 * 1000 * 1000) { return Err(RouteError::LoginTookTooLong); } - session_id + browser_session_id } CompatSsoLoginState::Exchanged { exchanged_at, - session_id, + compat_session_id, .. } => { if now > exchanged_at + Duration::microseconds(30 * 1000 * 1000) { // TODO: log that session out tracing::error!( compat_sso_login.id = %login.id, - compat_session.id = %session_id, + compat_session.id = %compat_session_id, "Login token exchanged a second time more than 30s after" ); } @@ -418,22 +435,60 @@ async fn token_login( } }; - let session = repo - .compat_session() - .lookup(session_id) - .await? - .ok_or(RouteError::SessionNotFound)?; + let Some(browser_session) = repo.browser_session().lookup(browser_session_id).await? else { + tracing::error!( + compat_sso_login.id = %login.id, + browser_session.id = %browser_session_id, + "Attempt to exchange login token but no associated browser session found" + ); + return Err(RouteError::InvalidLoginToken); + }; + if !browser_session.active() || !browser_session.user.is_valid() { + tracing::info!( + compat_sso_login.id = %login.id, + browser_session.id = %browser_session_id, + "Attempt to exchange login token but browser session is not active" + ); + return Err(RouteError::InvalidLoginToken); + } - let user = repo - .user() - .lookup(session.user_id) - .await? - .filter(mas_data_model::User::is_valid) - .ok_or(RouteError::UserNotFound)?; + // Lock the user sync to make sure we don't get into a race condition + repo.user() + .acquire_lock_for_sync(&browser_session.user) + .await?; + + let device = if let Some(requested_device_id) = requested_device_id { + Device::from(requested_device_id) + } else { + Device::generate(rng) + }; + let mxid = homeserver.mxid(&browser_session.user.username); + homeserver + .create_device(&mxid, device.as_str()) + .await + .map_err(RouteError::ProvisionDeviceFailed)?; - repo.compat_sso_login().exchange(clock, login).await?; + repo.app_session() + .finish_sessions_to_replace_device(clock, &browser_session.user, &device) + .await?; - Ok((session, user)) + let compat_session = repo + .compat_session() + .add( + rng, + clock, + &browser_session.user, + device, + Some(&browser_session), + false, + ) + .await?; + + repo.compat_sso_login() + .exchange(clock, login, &compat_session) + .await?; + + Ok((compat_session, browser_session.user)) } async fn user_password_login( @@ -446,6 +501,7 @@ async fn user_password_login( homeserver: &dyn HomeserverConnection, username: String, password: String, + requested_device_id: Option, ) -> Result<(CompatSession, User), RouteError> { // Try getting the localpart out of the MXID let username = homeserver.localpart(&username).unwrap_or(&username); @@ -498,14 +554,23 @@ async fn user_password_login( // Lock the user sync to make sure we don't get into a race condition repo.user().acquire_lock_for_sync(&user).await?; - // Now that the user credentials have been verified, start a new compat session - let device = Device::generate(&mut rng); let mxid = homeserver.mxid(&user.username); + + // Now that the user credentials have been verified, start a new compat session + let device = if let Some(requested_device_id) = requested_device_id { + Device::from(requested_device_id) + } else { + Device::generate(&mut rng) + }; homeserver .create_device(&mxid, device.as_str()) .await .map_err(RouteError::ProvisionDeviceFailed)?; + repo.app_session() + .finish_sessions_to_replace_device(clock, &user, &device) + .await?; + let session = repo .compat_session() .add(&mut rng, clock, &user, device, None, false) @@ -1000,7 +1065,7 @@ mod tests { } "###); - let (device, token) = get_login_token(&state, &user).await; + let token = get_login_token(&state, &user).await; // Try to login with the token. let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ @@ -1011,14 +1076,13 @@ mod tests { response.assert_status(StatusCode::OK); let body: serde_json::Value = response.json(); - insta::assert_json_snapshot!(body, @r###" + insta::assert_json_snapshot!(body, @r#" { - "access_token": "mct_uihy4bk51gxgUbUTa4XIh92RARTPTj_xADEE4", - "device_id": "Yp7FM44zJN", + "access_token": "mct_bnkWh1tPmm1MZOpygPaXwygX8PfxEY_hE6do1", + "device_id": "O3Ju1MUh3Z", "user_id": "@alice:example.com" } - "###); - assert_eq!(body["device_id"], device.to_string()); + "#); // Try again with the same token, it should fail. let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ @@ -1036,7 +1100,7 @@ mod tests { "###); // Try to login, but wait too long before sending the request. - let (_device, token) = get_login_token(&state, &user).await; + let token = get_login_token(&state, &user).await; // Advance the clock to make the token expire. state @@ -1064,14 +1128,13 @@ mod tests { /// # Panics /// /// Panics if the repository fails. - async fn get_login_token(state: &TestState, user: &User) -> (Device, String) { + async fn get_login_token(state: &TestState, user: &User) -> String { // XXX: This is a bit manual, but this is what basically the SSO login flow // does. let mut repo = state.repository().await.unwrap(); - // Generate a device and a token randomly + // Generate a token randomly let token = Alphanumeric.sample_string(&mut state.rng(), 32); - let device = Device::generate(&mut state.rng()); // Start a compat SSO login flow let login = repo @@ -1085,27 +1148,20 @@ mod tests { .await .unwrap(); - // Complete the flow by fulfilling it with a session - let compat_session = repo - .compat_session() - .add( - &mut state.rng(), - &state.clock, - user, - device.clone(), - None, - false, - ) + // Advance the flow by fulfilling it with a browser session + let browser_session = repo + .browser_session() + .add(&mut state.rng(), &state.clock, user, None) .await .unwrap(); - - repo.compat_sso_login() - .fulfill(&state.clock, login, &compat_session) + let _login = repo + .compat_sso_login() + .fulfill(&state.clock, login, &browser_session) .await .unwrap(); repo.save().await.unwrap(); - (device, token) + token } } diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 856d5356b..8da507d70 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -4,7 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use anyhow::Context; use axum::{ @@ -17,12 +17,9 @@ use mas_axum_utils::{ cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, }; -use mas_data_model::Device; -use mas_matrix::HomeserverConnection; use mas_router::{CompatLoginSsoAction, UrlBuilder}; use mas_storage::{ - BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, - compat::{CompatSessionRepository, CompatSsoLoginRepository}, + BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, compat::CompatSsoLoginRepository, }; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; use serde::{Deserialize, Serialize}; @@ -133,7 +130,6 @@ pub async fn post( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, - State(homeserver): State>, cookie_jar: CookieJar, Path(id): Path, Query(params): Query, @@ -174,8 +170,10 @@ pub async fn post( .await? .context("Could not find compat SSO login")?; - // Bail out if that login session is more than 30min old - if clock.now() > login.created_at + Duration::microseconds(30 * 60 * 1000 * 1000) { + // Bail out if that login session isn't pending, or is more than 30min old + if !login.is_pending() + || clock.now() > login.created_at + Duration::microseconds(30 * 60 * 1000 * 1000) + { let ctx = ErrorContext::new() .with_code("compat_sso_login_expired") .with_description("This login session expired.".to_owned()) @@ -202,30 +200,10 @@ pub async fn post( redirect_uri }; - // Lock the user sync to make sure we don't get into a race condition - repo.user().acquire_lock_for_sync(&session.user).await?; - - let device = Device::generate(&mut rng); - let mxid = homeserver.mxid(&session.user.username); - homeserver - .create_device(&mxid, device.as_str()) - .await - .context("Failed to provision device")?; - - let compat_session = repo - .compat_session() - .add( - &mut rng, - &clock, - &session.user, - device, - Some(&session), - false, - ) - .await?; - + // Note that if the login is not Pending, + // this fails and aborts the transaction. repo.compat_sso_login() - .fulfill(&clock, login, &compat_session) + .fulfill(&clock, login, &session) .await?; repo.save().await?; diff --git a/crates/storage-pg/.sqlx/query-373f7eb215b0e515b000a37e55bd055954f697f257de026b74ec408938a52a1a.json b/crates/storage-pg/.sqlx/query-373f7eb215b0e515b000a37e55bd055954f697f257de026b74ec408938a52a1a.json new file mode 100644 index 000000000..9ebd78f6f --- /dev/null +++ b/crates/storage-pg/.sqlx/query-373f7eb215b0e515b000a37e55bd055954f697f257de026b74ec408938a52a1a.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE oauth2_sessions SET finished_at = $3 WHERE user_id = $1 AND $2 = ANY(scope_list) AND finished_at IS NULL\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "373f7eb215b0e515b000a37e55bd055954f697f257de026b74ec408938a52a1a" +} diff --git a/crates/storage-pg/.sqlx/query-3f9d76f442c82a1631da931950b83b80c9620e1825ab07ab6c52f3f1a32d2527.json b/crates/storage-pg/.sqlx/query-3f9d76f442c82a1631da931950b83b80c9620e1825ab07ab6c52f3f1a32d2527.json new file mode 100644 index 000000000..231084544 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-3f9d76f442c82a1631da931950b83b80c9620e1825ab07ab6c52f3f1a32d2527.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE compat_sso_logins\n SET\n user_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "3f9d76f442c82a1631da931950b83b80c9620e1825ab07ab6c52f3f1a32d2527" +} diff --git a/crates/storage-pg/.sqlx/query-1787a5e86b60f57295fe5111259a29ffb15aa31e707cb7f2ad4269d125f6d8c9.json b/crates/storage-pg/.sqlx/query-933d2bed9c00eb9b37bfe757266ead15df5e0a4209ff47dcf4a5f19d35154e89.json similarity index 77% rename from crates/storage-pg/.sqlx/query-1787a5e86b60f57295fe5111259a29ffb15aa31e707cb7f2ad4269d125f6d8c9.json rename to crates/storage-pg/.sqlx/query-933d2bed9c00eb9b37bfe757266ead15df5e0a4209ff47dcf4a5f19d35154e89.json index 56b805e1f..8a040439c 100644 --- a/crates/storage-pg/.sqlx/query-1787a5e86b60f57295fe5111259a29ffb15aa31e707cb7f2ad4269d125f6d8c9.json +++ b/crates/storage-pg/.sqlx/query-933d2bed9c00eb9b37bfe757266ead15df5e0a4209ff47dcf4a5f19d35154e89.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE compat_session_id = $1\n ", + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n , user_session_id\n\n FROM compat_sso_logins\n WHERE compat_session_id = $1\n ", "describe": { "columns": [ { @@ -37,6 +37,11 @@ "ordinal": 6, "name": "compat_session_id", "type_info": "Uuid" + }, + { + "ordinal": 7, + "name": "user_session_id", + "type_info": "Uuid" } ], "parameters": { @@ -51,8 +56,9 @@ false, true, true, + true, true ] }, - "hash": "1787a5e86b60f57295fe5111259a29ffb15aa31e707cb7f2ad4269d125f6d8c9" + "hash": "933d2bed9c00eb9b37bfe757266ead15df5e0a4209ff47dcf4a5f19d35154e89" } diff --git a/crates/storage-pg/.sqlx/query-9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2.json b/crates/storage-pg/.sqlx/query-9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2.json deleted file mode 100644 index 81df89675..000000000 --- a/crates/storage-pg/.sqlx/query-9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2\n WHERE\n compat_sso_login_id = $1\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "9348d87f9e06b614c7e90bdc93bcf38236766aaf4d894bf768debdff2b59fae2" -} diff --git a/crates/storage-pg/.sqlx/query-ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82.json b/crates/storage-pg/.sqlx/query-a7094d84d313602729fde155cfbe63041fca7cbab407f98452462ec45e3cfd16.json similarity index 77% rename from crates/storage-pg/.sqlx/query-ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82.json rename to crates/storage-pg/.sqlx/query-a7094d84d313602729fde155cfbe63041fca7cbab407f98452462ec45e3cfd16.json index 058df6f64..effac88b3 100644 --- a/crates/storage-pg/.sqlx/query-ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82.json +++ b/crates/storage-pg/.sqlx/query-a7094d84d313602729fde155cfbe63041fca7cbab407f98452462ec45e3cfd16.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE compat_sso_login_id = $1\n ", + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n , user_session_id\n\n FROM compat_sso_logins\n WHERE compat_sso_login_id = $1\n ", "describe": { "columns": [ { @@ -37,6 +37,11 @@ "ordinal": 6, "name": "compat_session_id", "type_info": "Uuid" + }, + { + "ordinal": 7, + "name": "user_session_id", + "type_info": "Uuid" } ], "parameters": { @@ -51,8 +56,9 @@ false, true, true, + true, true ] }, - "hash": "ddb22dd9ae9367af65a607e1fdc48b3d9581d67deea0c168f24e02090082bb82" + "hash": "a7094d84d313602729fde155cfbe63041fca7cbab407f98452462ec45e3cfd16" } diff --git a/crates/storage-pg/.sqlx/query-b74e4d620bed4832a4e8e713a346691f260a7eca4bf494d6fb11c7cf699adaad.json b/crates/storage-pg/.sqlx/query-b74e4d620bed4832a4e8e713a346691f260a7eca4bf494d6fb11c7cf699adaad.json new file mode 100644 index 000000000..68f1b1764 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-b74e4d620bed4832a4e8e713a346691f260a7eca4bf494d6fb11c7cf699adaad.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE compat_sessions SET finished_at = $3 WHERE user_id = $1 AND device_id = $2 AND finished_at IS NULL\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "b74e4d620bed4832a4e8e713a346691f260a7eca4bf494d6fb11c7cf699adaad" +} diff --git a/crates/storage-pg/.sqlx/query-478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4.json b/crates/storage-pg/.sqlx/query-ce36eb8d3e4478a4e8520919ff41f1a5e6470cef581b1638f5578546dd28c4df.json similarity index 77% rename from crates/storage-pg/.sqlx/query-478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4.json rename to crates/storage-pg/.sqlx/query-ce36eb8d3e4478a4e8520919ff41f1a5e6470cef581b1638f5578546dd28c4df.json index f7bb7f438..1a425c860 100644 --- a/crates/storage-pg/.sqlx/query-478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4.json +++ b/crates/storage-pg/.sqlx/query-ce36eb8d3e4478a4e8520919ff41f1a5e6470cef581b1638f5578546dd28c4df.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n\n FROM compat_sso_logins\n WHERE login_token = $1\n ", + "query": "\n SELECT compat_sso_login_id\n , login_token\n , redirect_uri\n , created_at\n , fulfilled_at\n , exchanged_at\n , compat_session_id\n , user_session_id\n\n FROM compat_sso_logins\n WHERE login_token = $1\n ", "describe": { "columns": [ { @@ -37,6 +37,11 @@ "ordinal": 6, "name": "compat_session_id", "type_info": "Uuid" + }, + { + "ordinal": 7, + "name": "user_session_id", + "type_info": "Uuid" } ], "parameters": { @@ -51,8 +56,9 @@ false, true, true, + true, true ] }, - "hash": "478f0ad710da8bfd803c6cddd982bc504d1b6bd0f5283de53c8c7b1b4b7dafd4" + "hash": "ce36eb8d3e4478a4e8520919ff41f1a5e6470cef581b1638f5578546dd28c4df" } diff --git a/crates/storage-pg/.sqlx/query-4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b.json b/crates/storage-pg/.sqlx/query-e8e48db74ac1ab5baa1e4b121643cfa33a0bf3328df6e869464fe7f31429b81e.json similarity index 52% rename from crates/storage-pg/.sqlx/query-4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b.json rename to crates/storage-pg/.sqlx/query-e8e48db74ac1ab5baa1e4b121643cfa33a0bf3328df6e869464fe7f31429b81e.json index ddf67cfd9..6f461c850 100644 --- a/crates/storage-pg/.sqlx/query-4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b.json +++ b/crates/storage-pg/.sqlx/query-e8e48db74ac1ab5baa1e4b121643cfa33a0bf3328df6e869464fe7f31429b81e.json @@ -1,16 +1,16 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE compat_sso_logins\n SET\n compat_session_id = $2,\n fulfilled_at = $3\n WHERE\n compat_sso_login_id = $1\n ", + "query": "\n UPDATE compat_sso_logins\n SET\n exchanged_at = $2,\n compat_session_id = $3\n WHERE\n compat_sso_login_id = $1\n ", "describe": { "columns": [], "parameters": { "Left": [ "Uuid", - "Uuid", - "Timestamptz" + "Timestamptz", + "Uuid" ] }, "nullable": [] }, - "hash": "4d79ce892e4595edb8b801e94fb0cbef28facdfd2e45d1c72c57f47418fbe24b" + "hash": "e8e48db74ac1ab5baa1e4b121643cfa33a0bf3328df6e869464fe7f31429b81e" } diff --git a/crates/storage-pg/.sqlx/query-fcd8b4b9e003d1540357c6bf1ff9c715560d011d4c01112703a9c046170c84f1.json b/crates/storage-pg/.sqlx/query-fcd8b4b9e003d1540357c6bf1ff9c715560d011d4c01112703a9c046170c84f1.json index ef1ac0372..f5503fa0e 100644 --- a/crates/storage-pg/.sqlx/query-fcd8b4b9e003d1540357c6bf1ff9c715560d011d4c01112703a9c046170c84f1.json +++ b/crates/storage-pg/.sqlx/query-fcd8b4b9e003d1540357c6bf1ff9c715560d011d4c01112703a9c046170c84f1.json @@ -23,7 +23,7 @@ "Left": [] }, "nullable": [ - true, + false, true, null ] diff --git a/crates/storage-pg/migrations/20250404105103_compat_sso_login_browser_session.sql b/crates/storage-pg/migrations/20250404105103_compat_sso_login_browser_session.sql new file mode 100644 index 000000000..4b63590d6 --- /dev/null +++ b/crates/storage-pg/migrations/20250404105103_compat_sso_login_browser_session.sql @@ -0,0 +1,23 @@ +-- Copyright 2025 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + + +-- Compat SSO Logins in the 'fulfilled' state will now be attached to +-- browser sessions, not compat sessions. +-- Only those in the 'exchanged' state will now have a compat session. +-- +-- Rationale: We can't create the compat session without the client +-- being given an opportunity to specify the device_id, which does not happen +-- until the exchange phase. + +-- Empty the table because we don't want to need to think about backwards +-- compatibility for fulfilled logins that don't have an attached +-- browser session ID. +TRUNCATE compat_sso_logins; + +ALTER TABLE compat_sso_logins + -- browser sessions and user sessions are the same thing + ADD COLUMN user_session_id UUID + REFERENCES user_sessions(user_session_id) ON DELETE CASCADE; diff --git a/crates/storage-pg/src/app_session.rs b/crates/storage-pg/src/app_session.rs index ea97ec2d4..dc8e5a479 100644 --- a/crates/storage-pg/src/app_session.rs +++ b/crates/storage-pg/src/app_session.rs @@ -7,20 +7,25 @@ //! A module containing PostgreSQL implementation of repositories for sessions use async_trait::async_trait; -use mas_data_model::{CompatSession, CompatSessionState, Device, Session, SessionState, UserAgent}; +use mas_data_model::{ + CompatSession, CompatSessionState, Device, Session, SessionState, User, UserAgent, +}; use mas_storage::{ - Page, Pagination, + Clock, Page, Pagination, app_session::{AppSession, AppSessionFilter, AppSessionRepository, AppSessionState}, compat::CompatSessionFilter, oauth2::OAuth2SessionFilter, }; use oauth2_types::scope::{Scope, ScopeToken}; +use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT; use sea_query::{ Alias, ColumnRef, CommonTableExpression, Expr, PostgresQueryBuilder, Query, UnionType, }; use sea_query_binder::SqlxBinder; use sqlx::PgConnection; +use tracing::Instrument; use ulid::Ulid; +use uuid::Uuid; use crate::{ DatabaseError, ExecuteExt, @@ -457,6 +462,63 @@ impl AppSessionRepository for PgAppSessionRepository<'_> { .try_into() .map_err(DatabaseError::to_invalid_operation) } + + #[tracing::instrument( + name = "db.app_session.finish_sessions_to_replace_device", + fields( + db.query.text, + %user.id, + %device_id = device.as_str() + ), + skip_all, + err, + )] + async fn finish_sessions_to_replace_device( + &mut self, + clock: &dyn Clock, + user: &User, + device: &Device, + ) -> Result<(), Self::Error> { + // TODO need to invoke this from all the oauth2 login sites + let span = tracing::info_span!( + "db.app_session.finish_sessions_to_replace_device.compat_sessions", + { DB_QUERY_TEXT } = tracing::field::Empty, + ); + let finished_at = clock.now(); + sqlx::query!( + " + UPDATE compat_sessions SET finished_at = $3 WHERE user_id = $1 AND device_id = $2 AND finished_at IS NULL + ", + Uuid::from(user.id), + device.as_str(), + finished_at + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + + if let Ok(device_as_scope_token) = device.to_scope_token() { + let span = tracing::info_span!( + "db.app_session.finish_sessions_to_replace_device.oauth2_sessions", + { DB_QUERY_TEXT } = tracing::field::Empty, + ); + sqlx::query!( + " + UPDATE oauth2_sessions SET finished_at = $3 WHERE user_id = $1 AND $2 = ANY(scope_list) AND finished_at IS NULL + ", + Uuid::from(user.id), + device_as_scope_token.as_str(), + finished_at + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + Ok(()) + } } #[cfg(test)] diff --git a/crates/storage-pg/src/compat/mod.rs b/crates/storage-pg/src/compat/mod.rs index 4a50cb875..1d0e40426 100644 --- a/crates/storage-pg/src/compat/mod.rs +++ b/crates/storage-pg/src/compat/mod.rs @@ -209,21 +209,41 @@ mod tests { .unwrap(); assert!(login.is_pending()); + // Start a browser session for the user + let browser_session = repo + .browser_session() + .add(&mut rng, &clock, &user, None) + .await + .unwrap(); + // Start a compat session for that user let device = Device::generate(&mut rng); let sso_login_session = repo .compat_session() - .add(&mut rng, &clock, &user, device, None, false) + .add( + &mut rng, + &clock, + &user, + device, + Some(&browser_session), + false, + ) .await .unwrap(); // Associate the login with the session let login = repo .compat_sso_login() - .fulfill(&clock, login, &sso_login_session) + .fulfill(&clock, login, &browser_session) .await .unwrap(); assert!(login.is_fulfilled()); + let login = repo + .compat_sso_login() + .exchange(&clock, login, &sso_login_session) + .await + .unwrap(); + assert!(login.is_exchanged()); // Now query the session list with both the unknown and SSO login session type // filter @@ -594,26 +614,33 @@ mod tests { .expect("login not found"); assert_eq!(login_lookup, login); + // Start a compat session for that user + let device = Device::generate(&mut rng); + let compat_session = repo + .compat_session() + .add(&mut rng, &clock, &user, device, None, false) + .await + .unwrap(); + // Exchanging before fulfilling should not work // Note: It should also not poison the SQL transaction let res = repo .compat_sso_login() - .exchange(&clock, login.clone()) + .exchange(&clock, login.clone(), &compat_session) .await; assert!(res.is_err()); - // Start a compat session for that user - let device = Device::generate(&mut rng); - let session = repo - .compat_session() - .add(&mut rng, &clock, &user, device, None, false) + // Start a browser session for that user + let browser_session = repo + .browser_session() + .add(&mut rng, &clock, &user, None) .await .unwrap(); // Associate the login with the session let login = repo .compat_sso_login() - .fulfill(&clock, login, &session) + .fulfill(&clock, login, &browser_session) .await .unwrap(); assert!(login.is_fulfilled()); @@ -629,14 +656,14 @@ mod tests { // Note: It should also not poison the SQL transaction let res = repo .compat_sso_login() - .fulfill(&clock, login.clone(), &session) + .fulfill(&clock, login.clone(), &browser_session) .await; assert!(res.is_err()); // Exchange that login let login = repo .compat_sso_login() - .exchange(&clock, login) + .exchange(&clock, login, &compat_session) .await .unwrap(); assert!(login.is_exchanged()); @@ -652,7 +679,7 @@ mod tests { // Note: It should also not poison the SQL transaction let res = repo .compat_sso_login() - .exchange(&clock, login.clone()) + .exchange(&clock, login.clone(), &compat_session) .await; assert!(res.is_err()); @@ -660,7 +687,7 @@ mod tests { // Note: It should also not poison the SQL transaction let res = repo .compat_sso_login() - .fulfill(&clock, login.clone(), &session) + .fulfill(&clock, login.clone(), &browser_session) .await; assert!(res.is_err()); diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index fe8c35b3c..10c9fd9ad 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -157,14 +157,10 @@ impl TryFrom for (CompatSession, Option CompatSsoLoginState::Fulfilled { - fulfilled_at, - session_id: session.id, - }, (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged { fulfilled_at, exchanged_at, - session_id: session.id, + compat_session_id: session.id, }, _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), }; diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs index 34da9f093..2c794921b 100644 --- a/crates/storage-pg/src/compat/sso_login.rs +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -6,7 +6,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; -use mas_data_model::{CompatSession, CompatSsoLogin, CompatSsoLoginState}; +use mas_data_model::{BrowserSession, CompatSession, CompatSsoLogin, CompatSsoLoginState}; use mas_storage::{ Clock, Page, Pagination, compat::{CompatSsoLoginFilter, CompatSsoLoginRepository}, @@ -22,7 +22,7 @@ use uuid::Uuid; use crate::{ DatabaseError, DatabaseInconsistencyError, filter::{Filter, StatementExt}, - iden::{CompatSessions, CompatSsoLogins}, + iden::{CompatSsoLogins, UserSessions}, pagination::QueryBuilderExt, tracing::ExecuteExt, }; @@ -41,7 +41,7 @@ impl<'c> PgCompatSsoLoginRepository<'c> { } } -#[derive(sqlx::FromRow)] +#[derive(sqlx::FromRow, Debug)] #[enum_def] struct CompatSsoLoginLookup { compat_sso_login_id: Uuid, @@ -50,6 +50,7 @@ struct CompatSsoLoginLookup { created_at: DateTime, fulfilled_at: Option>, exchanged_at: Option>, + user_session_id: Option, compat_session_id: Option, } @@ -65,17 +66,24 @@ impl TryFrom for CompatSsoLogin { .source(e) })?; - let state = match (res.fulfilled_at, res.exchanged_at, res.compat_session_id) { - (None, None, None) => CompatSsoLoginState::Pending, - (Some(fulfilled_at), None, Some(session_id)) => CompatSsoLoginState::Fulfilled { - fulfilled_at, - session_id: session_id.into(), - }, - (Some(fulfilled_at), Some(exchanged_at), Some(session_id)) => { + let state = match ( + res.fulfilled_at, + res.exchanged_at, + res.user_session_id, + res.compat_session_id, + ) { + (None, None, None, None) => CompatSsoLoginState::Pending, + (Some(fulfilled_at), None, Some(browser_session_id), None) => { + CompatSsoLoginState::Fulfilled { + fulfilled_at, + browser_session_id: browser_session_id.into(), + } + } + (Some(fulfilled_at), Some(exchanged_at), _, Some(compat_session_id)) => { CompatSsoLoginState::Exchanged { fulfilled_at, exchanged_at, - session_id: session_id.into(), + compat_session_id: compat_session_id.into(), } } _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)), @@ -98,14 +106,14 @@ impl Filter for CompatSsoLoginFilter<'_> { Expr::exists( Query::select() .expr(Expr::cust("1")) - .from(CompatSessions::Table) + .from(UserSessions::Table) .and_where( - Expr::col((CompatSessions::Table, CompatSessions::UserId)) + Expr::col((UserSessions::Table, UserSessions::UserId)) .eq(Uuid::from(user.id)), ) .and_where( - Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)) - .equals((CompatSessions::Table, CompatSessions::CompatSessionId)), + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId)) + .equals((UserSessions::Table, UserSessions::UserSessionId)), ) .take(), ) @@ -151,6 +159,7 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { , fulfilled_at , exchanged_at , compat_session_id + , user_session_id FROM compat_sso_logins WHERE compat_sso_login_id = $1 @@ -189,6 +198,7 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { , fulfilled_at , exchanged_at , compat_session_id + , user_session_id FROM compat_sso_logins WHERE compat_session_id = $1 @@ -226,6 +236,7 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { , fulfilled_at , exchanged_at , compat_session_id + , user_session_id FROM compat_sso_logins WHERE login_token = $1 @@ -292,9 +303,8 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { fields( db.query.text, %compat_sso_login.id, - %compat_session.id, - compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str), - user.id = %compat_session.user_id, + %browser_session.id, + user.id = %browser_session.user.id, ), err, )] @@ -302,24 +312,24 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { &mut self, clock: &dyn Clock, compat_sso_login: CompatSsoLogin, - compat_session: &CompatSession, + browser_session: &BrowserSession, ) -> Result { let fulfilled_at = clock.now(); let compat_sso_login = compat_sso_login - .fulfill(fulfilled_at, compat_session) + .fulfill(fulfilled_at, browser_session) .map_err(DatabaseError::to_invalid_operation)?; let res = sqlx::query!( r#" UPDATE compat_sso_logins SET - compat_session_id = $2, + user_session_id = $2, fulfilled_at = $3 WHERE compat_sso_login_id = $1 "#, Uuid::from(compat_sso_login.id), - Uuid::from(compat_session.id), + Uuid::from(browser_session.id), fulfilled_at, ) .traced() @@ -337,6 +347,8 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { fields( db.query.text, %compat_sso_login.id, + %compat_session.id, + compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str), ), err, )] @@ -344,22 +356,25 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { &mut self, clock: &dyn Clock, compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, ) -> Result { let exchanged_at = clock.now(); let compat_sso_login = compat_sso_login - .exchange(exchanged_at) + .exchange(exchanged_at, compat_session) .map_err(DatabaseError::to_invalid_operation)?; let res = sqlx::query!( r#" UPDATE compat_sso_logins SET - exchanged_at = $2 + exchanged_at = $2, + compat_session_id = $3 WHERE compat_sso_login_id = $1 "#, Uuid::from(compat_sso_login.id), exchanged_at, + Uuid::from(compat_session.id), ) .traced() .execute(&mut *self.conn) @@ -392,6 +407,10 @@ impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> { Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)), CompatSsoLoginLookupIden::CompatSessionId, ) + .expr_as( + Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId)), + CompatSsoLoginLookupIden::UserSessionId, + ) .expr_as( Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)), CompatSsoLoginLookupIden::LoginToken, diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index 89d79c90b..71e6f7591 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -61,6 +61,7 @@ pub enum CompatSsoLogins { RedirectUri, LoginToken, CompatSessionId, + UserSessionId, CreatedAt, FulfilledAt, ExchangedAt, diff --git a/crates/storage/src/app_session.rs b/crates/storage/src/app_session.rs index 52fc4483d..fd1850d3d 100644 --- a/crates/storage/src/app_session.rs +++ b/crates/storage/src/app_session.rs @@ -10,7 +10,7 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{BrowserSession, CompatSession, Device, Session, User}; -use crate::{Page, Pagination, repository_impl}; +use crate::{Clock, Page, Pagination, repository_impl}; /// The state of a session #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -188,6 +188,20 @@ pub trait AppSessionRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result; + + /// Finishes any application sessions that are using the specified device's + /// ID. + /// + /// This is intended for logging in using an existing device ID (i.e. + /// replacing a device). + /// + /// Should be called *before* creating a new session for the device. + async fn finish_sessions_to_replace_device( + &mut self, + clock: &dyn Clock, + user: &User, + device: &Device, + ) -> Result<(), Self::Error>; } repository_impl!(AppSessionRepository: @@ -198,4 +212,11 @@ repository_impl!(AppSessionRepository: ) -> Result, Self::Error>; async fn count(&mut self, filter: AppSessionFilter<'_>) -> Result; + + async fn finish_sessions_to_replace_device( + &mut self, + clock: &dyn Clock, + user: &User, + device: &Device, + ) -> Result<(), Self::Error>; ); diff --git a/crates/storage/src/compat/sso_login.rs b/crates/storage/src/compat/sso_login.rs index 425541d65..08e8c5491 100644 --- a/crates/storage/src/compat/sso_login.rs +++ b/crates/storage/src/compat/sso_login.rs @@ -5,7 +5,7 @@ // Please see LICENSE in the repository root for full details. use async_trait::async_trait; -use mas_data_model::{CompatSession, CompatSsoLogin, User}; +use mas_data_model::{BrowserSession, CompatSession, CompatSsoLogin, User}; use rand_core::RngCore; use ulid::Ulid; use url::Url; @@ -168,7 +168,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { redirect_uri: Url, ) -> Result; - /// Fulfill a compat SSO login by providing a compat session + /// Fulfill a compat SSO login by providing a browser session /// /// Returns the fulfilled compat SSO login /// @@ -176,8 +176,8 @@ pub trait CompatSsoLoginRepository: Send + Sync { /// /// * `clock`: The clock used to generate the timestamps /// * `compat_sso_login`: The compat SSO login to fulfill - /// * `compat_session`: The compat session to associate with the compat SSO - /// login + /// * `browser_session`: The browser session to associate with the compat + /// SSO login /// /// # Errors /// @@ -186,7 +186,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { &mut self, clock: &dyn Clock, compat_sso_login: CompatSsoLogin, - compat_session: &CompatSession, + browser_session: &BrowserSession, ) -> Result; /// Mark a compat SSO login as exchanged @@ -197,6 +197,8 @@ pub trait CompatSsoLoginRepository: Send + Sync { /// /// * `clock`: The clock used to generate the timestamps /// * `compat_sso_login`: The compat SSO login to mark as exchanged + /// * `compat_session`: The compat session created as a result of the + /// exchange /// /// # Errors /// @@ -205,6 +207,7 @@ pub trait CompatSsoLoginRepository: Send + Sync { &mut self, clock: &dyn Clock, compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, ) -> Result; /// List [`CompatSsoLogin`] with the given filter and pagination @@ -262,13 +265,14 @@ repository_impl!(CompatSsoLoginRepository: &mut self, clock: &dyn Clock, compat_sso_login: CompatSsoLogin, - compat_session: &CompatSession, + browser_session: &BrowserSession, ) -> Result; async fn exchange( &mut self, clock: &dyn Clock, compat_sso_login: CompatSsoLogin, + compat_session: &CompatSession, ) -> Result; async fn list(