diff --git a/crates/handlers/src/activity_tracker/mod.rs b/crates/handlers/src/activity_tracker/mod.rs index 71d6a8fc2..1cbaec877 100644 --- a/crates/handlers/src/activity_tracker/mod.rs +++ b/crates/handlers/src/activity_tracker/mod.rs @@ -10,7 +10,9 @@ mod worker; use std::net::IpAddr; use chrono::{DateTime, Utc}; -use mas_data_model::{BrowserSession, Clock, CompatSession, Session}; +use mas_data_model::{ + BrowserSession, Clock, CompatSession, Session, personal::session::PersonalSession, +}; use mas_storage::BoxRepositoryFactory; use tokio_util::{sync::CancellationToken, task::TaskTracker}; use ulid::Ulid; @@ -115,7 +117,7 @@ impl ActivityTracker { pub async fn record_personal_access_token_session( &self, clock: &dyn Clock, - session: &Session, + session: &PersonalSession, ip: Option, ) { let res = self diff --git a/crates/handlers/src/activity_tracker/worker.rs b/crates/handlers/src/activity_tracker/worker.rs index 6fa51fce3..9405eab41 100644 --- a/crates/handlers/src/activity_tracker/worker.rs +++ b/crates/handlers/src/activity_tracker/worker.rs @@ -257,7 +257,9 @@ impl Worker { repo.compat_session() .record_batch_activity(compat_sessions) .await?; - // TODO: personal sessions: record + repo.personal_session() + .record_batch_activity(personal_sessions) + .await?; repo.save().await?; self.pending_records.clear(); diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 27ee3fdbc..754eaa942 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -15,7 +15,9 @@ use mas_axum_utils::{ client_authorization::{ClientAuthorization, CredentialsVerificationError}, record_error, }; -use mas_data_model::{BoxClock, Clock, Device, TokenFormatError, TokenType}; +use mas_data_model::{ + BoxClock, Clock, Device, TokenFormatError, TokenType, personal::session::PersonalSessionOwner, +}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_keystore::Encrypter; use mas_matrix::HomeserverConnection; @@ -93,6 +95,14 @@ pub enum RouteError { #[error("unknown compat session {0}")] CantLoadCompatSession(Ulid), + /// The personal access token session is not valid. + #[error("invalid personal access token session {0}")] + InvalidPersonalSession(Ulid), + + /// The personal access token session could not be found in the database. + #[error("unknown personal access token session {0}")] + CantLoadPersonalSession(Ulid), + /// The Device ID in the compat session can't be encoded as a scope #[error("device ID contains characters that are not allowed in a scope")] CantEncodeDeviceID(#[from] mas_data_model::ToScopeTokenError), @@ -103,6 +113,9 @@ pub enum RouteError { #[error("unknown user {0}")] CantLoadUser(Ulid), + #[error("unknown OAuth2 client {0}")] + CantLoadOAuth2Client(Ulid), + #[error("bad request")] BadRequest, @@ -131,7 +144,9 @@ impl IntoResponse for RouteError { e @ (Self::Internal(_) | Self::CantLoadCompatSession(_) | Self::CantLoadOAuthSession(_) + | Self::CantLoadPersonalSession(_) | Self::CantLoadUser(_) + | Self::CantLoadOAuth2Client(_) | Self::FailedToVerifyToken(_)) => ( StatusCode::INTERNAL_SERVER_ERROR, Json( @@ -167,6 +182,7 @@ impl IntoResponse for RouteError { | Self::InvalidUser(_) | Self::InvalidCompatSession(_) | Self::InvalidOAuthSession(_) + | Self::InvalidPersonalSession(_) | Self::InvalidTokenFormat(_) | Self::CantEncodeDeviceID(_) => { INTROSPECTION_COUNTER.add(1, &[KeyValue::new(ACTIVE.clone(), false)]); @@ -627,8 +643,94 @@ pub(crate) async fn post( } TokenType::PersonalAccessToken => { - // TODO - return Err(RouteError::UnknownToken(TokenType::PersonalAccessToken)); + let access_token = repo + .personal_access_token() + .find_by_token(token) + .await? + .ok_or(RouteError::UnknownToken(TokenType::AccessToken))?; + + if !access_token.is_valid(clock.now()) { + return Err(RouteError::InvalidToken(TokenType::AccessToken)); + } + + let session = repo + .personal_session() + .lookup(access_token.session_id) + .await? + .ok_or(RouteError::CantLoadPersonalSession(access_token.session_id))?; + + if !session.is_valid() { + return Err(RouteError::InvalidPersonalSession(session.id)); + } + + let actor_user = repo + .user() + .lookup(session.actor_user_id) + .await? + .ok_or(RouteError::CantLoadUser(session.actor_user_id))?; + + if !actor_user.is_valid() { + return Err(RouteError::InvalidUser(actor_user.id)); + } + + let client_id = match session.owner { + PersonalSessionOwner::User(owner_user_id) => { + let owner_user = repo + .user() + .lookup(owner_user_id) + .await? + .ok_or(RouteError::CantLoadUser(owner_user_id))?; + + if !owner_user.is_valid() { + return Err(RouteError::InvalidUser(owner_user.id)); + } + + None + } + PersonalSessionOwner::OAuth2Client(owner_client_id) => { + let owner_client = repo + .oauth2_client() + .lookup(owner_client_id) + .await? + .ok_or(RouteError::CantLoadOAuth2Client(owner_client_id))?; + + // OAuth2 clients are always valid if they're in the database + Some(owner_client.client_id.clone()) + } + }; + + activity_tracker + .record_personal_access_token_session(&clock, &session, ip) + .await; + + INTROSPECTION_COUNTER.add( + 1, + &[ + KeyValue::new(KIND, "personal_access_token"), + KeyValue::new(ACTIVE, true), + ], + ); + + let scope = normalize_scope(session.scope); + + IntrospectionResponse { + active: true, + scope: Some(scope), + client_id, + username: Some(actor_user.username), + token_type: Some(OAuthTokenTypeHint::AccessToken), + exp: access_token.expires_at, + expires_in: access_token + .expires_at + .map(|expires_at| expires_at.signed_duration_since(clock.now())), + iat: Some(access_token.created_at), + nbf: Some(access_token.created_at), + sub: Some(actor_user.sub), + aud: None, + iss: None, + jti: None, + device_id: None, + } } }; @@ -641,7 +743,9 @@ pub(crate) async fn post( mod tests { use chrono::Duration; use hyper::{Request, StatusCode}; - use mas_data_model::{AccessToken, Clock, RefreshToken}; + use mas_data_model::{ + AccessToken, Clock, RefreshToken, TokenType, personal::session::PersonalSessionOwner, + }; use mas_iana::oauth::OAuthTokenTypeHint; use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest}; use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute}; @@ -1074,4 +1178,125 @@ mod tests { let response: ClientError = response.json(); assert_eq!(response.error, ClientErrorCode::AccessDenied); } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_introspect_personal_access_tokens(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + + // Provision a client which will be used to do introspection requests + let request = Request::post(OAuth2RegistrationEndpoint::PATH).json(json!({ + "client_uri": "https://introspecting.com/", + "grant_types": [], + "token_endpoint_auth_method": "client_secret_basic", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::CREATED); + let client: ClientRegistrationResponse = response.json(); + let introspecting_client_id = client.client_id; + let introspecting_client_secret = client.client_secret.unwrap(); + + let mut repo = state.repository().await.unwrap(); + + // Provision an owner user (who provisions the personal session) + let owner_user = repo + .user() + .add(&mut state.rng(), &state.clock, "admin".to_owned()) + .await + .unwrap(); + + // Provision an actor user (which the token represents) + let actor_user = repo + .user() + .add(&mut state.rng(), &state.clock, "bruce".to_owned()) + .await + .unwrap(); + + // admin creates a personal session to control bruce's account + let personal_session = repo + .personal_session() + .add( + &mut state.rng(), + &state.clock, + PersonalSessionOwner::User(owner_user.id), + &actor_user, + "Test Personal Access Token".to_owned(), + Scope::from_iter([OPENID]), + ) + .await + .unwrap(); + + // Generate a personal access token with proper token format + let token_string = TokenType::PersonalAccessToken.generate(&mut state.rng()); + let _personal_access_token = repo + .personal_access_token() + .add( + &mut state.rng(), + &state.clock, + &personal_session, + &token_string, + Some(Duration::try_hours(1).unwrap()), + ) + .await + .unwrap(); + + repo.save().await.unwrap(); + + // Now that we have a personal access token, we can introspect it + let request = Request::post(OAuth2Introspection::PATH) + .basic_auth(&introspecting_client_id, &introspecting_client_secret) + .form(json!({ "token": token_string })); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let response: IntrospectionResponse = response.json(); + assert!(response.active); + // Actor user + assert_eq!(response.username, Some("bruce".to_owned())); + // Not owned by a client + assert_eq!(response.client_id, None); + assert_eq!(response.token_type, Some(OAuthTokenTypeHint::AccessToken)); + assert_eq!(response.scope, Some(Scope::from_iter([OPENID]))); + + // Do the same request, but with a token_type_hint + let last_active = state.clock.now(); + let request = Request::post(OAuth2Introspection::PATH) + .basic_auth(&introspecting_client_id, &introspecting_client_secret) + .form(json!({"token": token_string, "token_type_hint": "access_token"})); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let response: IntrospectionResponse = response.json(); + assert!(response.active); + + // Do the same request, but with the wrong token_type_hint + let request = Request::post(OAuth2Introspection::PATH) + .basic_auth(&introspecting_client_id, &introspecting_client_secret) + .form(json!({"token": token_string, "token_type_hint": "refresh_token"})); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let response: IntrospectionResponse = response.json(); + assert!(!response.active); // It shouldn't be active with wrong hint + + // Advance the clock to invalidate the access token + state.clock.advance(Duration::try_hours(2).unwrap()); + + let request = Request::post(OAuth2Introspection::PATH) + .basic_auth(&introspecting_client_id, &introspecting_client_secret) + .form(json!({ "token": token_string })); + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let response: IntrospectionResponse = response.json(); + assert!(!response.active); // It shouldn't be active anymore + + state.activity_tracker.flush().await; + let mut repo = state.repository().await.unwrap(); + let session = repo + .personal_session() + .lookup(personal_session.id) + .await + .unwrap() + .unwrap(); + assert_eq!(session.last_active_at, Some(last_active)); + repo.save().await.unwrap(); + } } diff --git a/crates/oauth2-types/src/requests.rs b/crates/oauth2-types/src/requests.rs index ac0770411..28203a973 100644 --- a/crates/oauth2-types/src/requests.rs +++ b/crates/oauth2-types/src/requests.rs @@ -807,6 +807,7 @@ pub struct IntrospectionResponse { pub jti: Option, /// MAS extension: explicit device ID + /// Only used for compatibility access and refresh tokens. pub device_id: Option, } diff --git a/crates/storage-pg/.sqlx/query-64b6e274e2bed6814f5ae41ddf57093589f7d1b2b8458521b635546b8012041e.json b/crates/storage-pg/.sqlx/query-64b6e274e2bed6814f5ae41ddf57093589f7d1b2b8458521b635546b8012041e.json new file mode 100644 index 000000000..6b2e85bf1 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-64b6e274e2bed6814f5ae41ddf57093589f7d1b2b8458521b635546b8012041e.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE personal_sessions\n SET last_active_at = GREATEST(t.last_active_at, personal_sessions.last_active_at)\n , last_active_ip = COALESCE(t.last_active_ip, personal_sessions.last_active_ip)\n FROM (\n SELECT *\n FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])\n AS t(personal_session_id, last_active_at, last_active_ip)\n ) AS t\n WHERE personal_sessions.personal_session_id = t.personal_session_id\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "UuidArray", + "TimestamptzArray", + "InetArray" + ] + }, + "nullable": [] + }, + "hash": "64b6e274e2bed6814f5ae41ddf57093589f7d1b2b8458521b635546b8012041e" +} diff --git a/crates/storage-pg/src/personal/session.rs b/crates/storage-pg/src/personal/session.rs index 28c725a24..f2b721b70 100644 --- a/crates/storage-pg/src/personal/session.rs +++ b/crates/storage-pg/src/personal/session.rs @@ -361,6 +361,56 @@ impl PersonalSessionRepository for PgPersonalSessionRepository<'_> { .try_into() .map_err(DatabaseError::to_invalid_operation) } + + #[tracing::instrument( + name = "db.personal_session.record_batch_activity", + skip_all, + fields( + db.query.text, + ), + err, + )] + async fn record_batch_activity( + &mut self, + mut activities: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error> { + // Sort the activity by ID, so that when batching the updates, Postgres + // locks the rows in a stable order, preventing deadlocks + activities.sort_unstable(); + let mut ids = Vec::with_capacity(activities.len()); + let mut last_activities = Vec::with_capacity(activities.len()); + let mut ips = Vec::with_capacity(activities.len()); + + for (id, last_activity, ip) in activities { + ids.push(Uuid::from(id)); + last_activities.push(last_activity); + ips.push(ip); + } + + let res = sqlx::query!( + r#" + UPDATE personal_sessions + SET last_active_at = GREATEST(t.last_active_at, personal_sessions.last_active_at) + , last_active_ip = COALESCE(t.last_active_ip, personal_sessions.last_active_ip) + FROM ( + SELECT * + FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[]) + AS t(personal_session_id, last_active_at, last_active_ip) + ) AS t + WHERE personal_sessions.personal_session_id = t.personal_session_id + "#, + &ids, + &last_activities, + &ips as &[Option], + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?; + + Ok(()) + } } impl Filter for PersonalSessionFilter<'_> { diff --git a/crates/storage/src/personal/session.rs b/crates/storage/src/personal/session.rs index c090efa30..b4f0bbba4 100644 --- a/crates/storage/src/personal/session.rs +++ b/crates/storage/src/personal/session.rs @@ -3,6 +3,8 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. +use std::net::IpAddr; + use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{ @@ -109,6 +111,21 @@ pub trait PersonalSessionRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result; + + /// Record a batch of [`PersonalSession`] activity + /// + /// # Parameters + /// + /// * `activity`: A list of tuples containing the session ID, the last + /// activity timestamp and the IP address of the client + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; } repository_impl!(PersonalSessionRepository: @@ -137,6 +154,11 @@ repository_impl!(PersonalSessionRepository: ) -> Result, Self::Error>; async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result; + + async fn record_batch_activity( + &mut self, + activity: Vec<(Ulid, DateTime, Option)>, + ) -> Result<(), Self::Error>; ); /// Filter parameters for listing personal sessions