diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index 6f8f059f0..f89eb13e1 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -15,8 +15,7 @@ use mas_handlers::{ }; use mas_i18n::Translator; use mas_keystore::{Encrypter, Keystore}; -use mas_matrix::BoxHomeserverConnection; -use mas_matrix_synapse::SynapseConnection; +use mas_matrix::HomeserverConnection; use mas_policy::{Policy, PolicyFactory}; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock}; @@ -37,7 +36,7 @@ pub struct AppState { pub cookie_manager: CookieManager, pub encrypter: Encrypter, pub url_builder: UrlBuilder, - pub homeserver_connection: SynapseConnection, + pub homeserver_connection: Arc, pub policy_factory: Arc, pub graphql_schema: GraphQLSchema, pub http_client: reqwest::Client, @@ -204,9 +203,9 @@ impl FromRef for Limiter { } } -impl FromRef for BoxHomeserverConnection { +impl FromRef for Arc { fn from_ref(input: &AppState) -> Self { - Box::new(input.homeserver_connection.clone()) + Arc::clone(&input.homeserver_connection) } } diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 03df778a8..003841a35 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -17,7 +17,6 @@ use mas_config::{ use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; use mas_email::Address; use mas_matrix::HomeserverConnection; -use mas_matrix_synapse::SynapseConnection; use mas_storage::{ Clock, RepositoryAccess, SystemClock, compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository}, @@ -33,7 +32,10 @@ use rand::{RngCore, SeedableRng}; use sqlx::{Acquire, types::Uuid}; use tracing::{error, info, info_span, warn}; -use crate::util::{database_connection_from_config, password_manager_from_config}; +use crate::util::{ + database_connection_from_config, homeserver_connection_from_config, + password_manager_from_config, +}; const USER_ATTRIBUTES_HEADING: &str = "User attributes"; @@ -491,12 +493,7 @@ impl Options { let matrix_config = MatrixConfig::extract(figment)?; let password_manager = password_manager_from_config(&password_config).await?; - let homeserver = SynapseConnection::new( - matrix_config.homeserver, - matrix_config.endpoint, - matrix_config.secret, - http_client, - ); + let homeserver = homeserver_connection_from_config(&matrix_config, http_client); let mut conn = database_connection_from_config(&database_config).await?; let txn = conn.begin().await?; let mut repo = PgRepository::from_conn(txn); @@ -746,7 +743,7 @@ impl std::fmt::Display for HumanReadable<&UpstreamOAuthProvider> { async fn check_and_normalize_username<'a>( localpart_or_mxid: &'a str, repo: &mut dyn RepositoryAccess, - homeserver: &SynapseConnection, + homeserver: &dyn HomeserverConnection, ) -> anyhow::Result<&'a str> { // XXX: this is a very basic MXID to localpart conversion // Strip any leading '@' @@ -828,7 +825,7 @@ impl UserCreationRequest<'_> { } /// Show the user creation request in a human-readable format - fn show(&self, term: &Term, homeserver: &SynapseConnection) -> std::io::Result<()> { + fn show(&self, term: &Term, homeserver: &dyn HomeserverConnection) -> std::io::Result<()> { let value_style = Style::new().green(); let key_style = Style::new().bold(); let warning_style = Style::new().italic().red().bright(); diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index d58fcb9da..892945de9 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -15,7 +15,6 @@ use mas_config::{ }; use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache}; use mas_listener::server::Server; -use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; use mas_storage::SystemClock; use mas_storage_pg::MIGRATOR; @@ -26,9 +25,9 @@ use crate::{ app_state::AppState, lifecycle::LifecycleManager, util::{ - database_pool_from_config, mailer_from_config, password_manager_from_config, - policy_factory_from_config, site_config_from_config, templates_from_config, - test_mailer_in_background, + database_pool_from_config, homeserver_connection_from_config, mailer_from_config, + password_manager_from_config, policy_factory_from_config, site_config_from_config, + templates_from_config, test_mailer_in_background, }, }; @@ -153,12 +152,8 @@ impl Options { let http_client = mas_http::reqwest_client(); - let homeserver_connection = SynapseConnection::new( - config.matrix.homeserver.clone(), - config.matrix.endpoint.clone(), - config.matrix.secret.clone(), - http_client.clone(), - ); + let homeserver_connection = + homeserver_connection_from_config(&config.matrix, http_client.clone()); if !self.no_worker { let mailer = mailer_from_config(&config.email, &templates)?; diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index 187d6b7cb..da16e848a 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -9,15 +9,14 @@ use std::{process::ExitCode, time::Duration}; use clap::Parser; use figment::Figment; use mas_config::{AppConfig, ConfigurationSection}; -use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; use tracing::{info, info_span}; use crate::{ lifecycle::LifecycleManager, util::{ - database_pool_from_config, mailer_from_config, site_config_from_config, - templates_from_config, test_mailer_in_background, + database_pool_from_config, homeserver_connection_from_config, mailer_from_config, + site_config_from_config, templates_from_config, test_mailer_in_background, }, }; @@ -58,12 +57,7 @@ impl Options { test_mailer_in_background(&mailer, Duration::from_secs(30)); let http_client = mas_http::reqwest_client(); - let conn = SynapseConnection::new( - config.matrix.homeserver.clone(), - config.matrix.endpoint.clone(), - config.matrix.secret.clone(), - http_client, - ); + let conn = homeserver_connection_from_config(&config.matrix, http_client); drop(config); diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 0f155afc7..27c23eeb5 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -4,17 +4,19 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use anyhow::Context; use mas_config::{ AccountConfig, BrandingConfig, CaptchaConfig, DatabaseConfig, EmailConfig, EmailSmtpMode, - EmailTransportKind, ExperimentalConfig, MatrixConfig, PasswordsConfig, PolicyConfig, - TemplatesConfig, + EmailTransportKind, ExperimentalConfig, HomeserverKind, MatrixConfig, PasswordsConfig, + PolicyConfig, TemplatesConfig, }; use mas_data_model::{SessionExpirationConfig, SiteConfig}; use mas_email::{MailTransport, Mailer}; use mas_handlers::passwords::PasswordManager; +use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection}; +use mas_matrix_synapse::SynapseConnection; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates}; @@ -346,6 +348,32 @@ pub async fn database_connection_from_config( .context("could not connect to the database") } +/// Create a clonable, type-erased [`HomeserverConnection`] from the +/// configuration +pub fn homeserver_connection_from_config( + config: &MatrixConfig, + http_client: reqwest::Client, +) -> Arc { + match config.kind { + HomeserverKind::Synapse => Arc::new(SynapseConnection::new( + config.homeserver.clone(), + config.endpoint.clone(), + config.secret.clone(), + http_client, + )), + HomeserverKind::SynapseReadOnly => { + let connection = SynapseConnection::new( + config.homeserver.clone(), + config.endpoint.clone(), + config.secret.clone(), + http_client, + ); + let readonly = ReadOnlyHomeserverConnection::new(connection); + Arc::new(readonly) + } + } +} + #[cfg(test)] mod tests { use rand::SeedableRng; diff --git a/crates/config/src/sections/matrix.rs b/crates/config/src/sections/matrix.rs index eb145fa8c..d5e35907e 100644 --- a/crates/config/src/sections/matrix.rs +++ b/crates/config/src/sections/matrix.rs @@ -23,10 +23,29 @@ fn default_endpoint() -> Url { Url::parse("http://localhost:8008/").unwrap() } +/// The kind of homeserver it is. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)] +#[serde(rename_all = "snake_case")] +pub enum HomeserverKind { + /// Homeserver is Synapse + #[default] + Synapse, + + /// Homeserver is Synapse, in read-only mode + /// + /// This is meant for testing rolling out Matrix Authentication Service with + /// no risk of writing data to the homeserver. + SynapseReadOnly, +} + /// Configuration related to the Matrix homeserver #[serde_as] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct MatrixConfig { + /// The kind of homeserver it is. + #[serde(default)] + pub kind: HomeserverKind, + /// The server name of the homeserver. #[serde(default = "default_homeserver")] pub homeserver: String, @@ -49,6 +68,7 @@ impl MatrixConfig { R: Rng + Send, { Self { + kind: HomeserverKind::default(), homeserver: default_homeserver(), secret: Alphanumeric.sample_string(&mut rng, 32), endpoint: default_endpoint(), @@ -57,6 +77,7 @@ impl MatrixConfig { pub(crate) fn test() -> Self { Self { + kind: HomeserverKind::default(), homeserver: default_homeserver(), secret: "test".to_owned(), endpoint: default_endpoint(), diff --git a/crates/config/src/sections/mod.rs b/crates/config/src/sections/mod.rs index aa773e70b..d415f646a 100644 --- a/crates/config/src/sections/mod.rs +++ b/crates/config/src/sections/mod.rs @@ -37,7 +37,7 @@ pub use self::{ BindConfig as HttpBindConfig, HttpConfig, ListenerConfig as HttpListenerConfig, Resource as HttpResource, TlsConfig as HttpTlsConfig, UnixOrTcp, }, - matrix::MatrixConfig, + matrix::{HomeserverKind, MatrixConfig}, passwords::{Algorithm as PasswordAlgorithm, PasswordsConfig}, policy::PolicyConfig, rate_limiting::RateLimitingConfig, diff --git a/crates/handlers/src/admin/mod.rs b/crates/handlers/src/admin/mod.rs index 4edc8b2ca..dcc781576 100644 --- a/crates/handlers/src/admin/mod.rs +++ b/crates/handlers/src/admin/mod.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use aide::{ axum::ApiRouter, openapi::{OAuth2Flow, OAuth2Flows, OpenApi, SecurityScheme, Server, Tag}, @@ -19,7 +21,7 @@ use hyper::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use indexmap::IndexMap; use mas_axum_utils::FancyError; use mas_http::CorsLayerExt; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_router::{ ApiDoc, ApiDocCallback, OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, Route, SimpleRoute, UrlBuilder, @@ -107,7 +109,7 @@ fn finish(t: TransformOpenApi) -> TransformOpenApi { pub fn router() -> (OpenApi, Router) where S: Clone + Send + Sync + 'static, - BoxHomeserverConnection: FromRef, + Arc: FromRef, PasswordManager: FromRef, BoxRng: FromRequestParts, CallContext: FromRequestParts, diff --git a/crates/handlers/src/admin/v1/mod.rs b/crates/handlers/src/admin/v1/mod.rs index 02c63e2a5..09d56ee2a 100644 --- a/crates/handlers/src/admin/v1/mod.rs +++ b/crates/handlers/src/admin/v1/mod.rs @@ -4,12 +4,14 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use aide::axum::{ ApiRouter, routing::{get_with, post_with}, }; use axum::extract::{FromRef, FromRequestParts}; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_storage::BoxRng; use super::call_context::CallContext; @@ -25,7 +27,7 @@ mod users; pub fn router() -> ApiRouter where S: Clone + Send + Sync + 'static, - BoxHomeserverConnection: FromRef, + Arc: FromRef, PasswordManager: FromRef, BoxRng: FromRequestParts, CallContext: FromRequestParts, diff --git a/crates/handlers/src/admin/v1/users/add.rs b/crates/handlers/src/admin/v1/users/add.rs index b4e0abd7e..d6c83db51 100644 --- a/crates/handlers/src/admin/v1/users/add.rs +++ b/crates/handlers/src/admin/v1/users/add.rs @@ -4,10 +4,12 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use aide::{NoApi, OperationIo, transform::TransformOperation}; use axum::{Json, extract::State, response::IntoResponse}; use hyper::StatusCode; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_storage::{ BoxRng, queue::{ProvisionUserJob, QueueJobRepositoryExt as _}, @@ -135,7 +137,7 @@ pub async fn handler( mut repo, clock, .. }: CallContext, NoApi(mut rng): NoApi, - State(homeserver): State, + State(homeserver): State>, Json(params): Json, ) -> Result<(StatusCode, Json>), RouteError> { if repo.user().exists(¶ms.username).await? { diff --git a/crates/handlers/src/admin/v1/users/unlock.rs b/crates/handlers/src/admin/v1/users/unlock.rs index 1626d4f58..76bb738c3 100644 --- a/crates/handlers/src/admin/v1/users/unlock.rs +++ b/crates/handlers/src/admin/v1/users/unlock.rs @@ -4,10 +4,12 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use aide::{OperationIo, transform::TransformOperation}; use axum::{Json, extract::State, response::IntoResponse}; use hyper::StatusCode; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use ulid::Ulid; use crate::{ @@ -67,7 +69,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation { #[tracing::instrument(name = "handler.admin.v1.users.unlock", skip_all, err)] pub async fn handler( CallContext { mut repo, .. }: CallContext, - State(homeserver): State, + State(homeserver): State>, id: UlidPathParam, ) -> Result>, RouteError> { let id = *id; diff --git a/crates/handlers/src/bin/api-schema.rs b/crates/handlers/src/bin/api-schema.rs index 993719564..a70856f17 100644 --- a/crates/handlers/src/bin/api-schema.rs +++ b/crates/handlers/src/bin/api-schema.rs @@ -13,7 +13,7 @@ )] #![warn(clippy::pedantic)] -use std::io::Write; +use std::{io::Write, sync::Arc}; use aide::openapi::{Server, ServerVariable}; use indexmap::IndexMap; @@ -55,7 +55,7 @@ impl_from_request_parts!(mas_storage::BoxRng); impl_from_request_parts!(mas_handlers::BoundActivityTracker); impl_from_ref!(mas_router::UrlBuilder); impl_from_ref!(mas_templates::Templates); -impl_from_ref!(mas_matrix::BoxHomeserverConnection); +impl_from_ref!(Arc); impl_from_ref!(mas_keystore::Keystore); impl_from_ref!(mas_handlers::passwords::PasswordManager); diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index f3966357b..7c1a6d97a 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use axum::{ Json, extract::{State, rejection::JsonRejection}, @@ -16,7 +18,7 @@ use mas_axum_utils::sentry::SentryEventID; use mas_data_model::{ CompatSession, CompatSsoLoginState, Device, SiteConfig, TokenType, User, UserAgent, }; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_storage::{ BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, compat::{ @@ -268,7 +270,7 @@ pub(crate) async fn post( State(password_manager): State, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, - State(homeserver): State, + State(homeserver): State>, State(site_config): State, State(limiter): State, requester: RequesterFingerprint, @@ -441,7 +443,7 @@ async fn user_password_login( limiter: &Limiter, requester: RequesterFingerprint, repo: &mut BoxRepository, - homeserver: &BoxHomeserverConnection, + homeserver: &dyn HomeserverConnection, username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index b436c33c3..f5fe6432f 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -4,7 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use anyhow::Context; use axum::{ @@ -18,7 +18,7 @@ use mas_axum_utils::{ csrf::{CsrfExt, ProtectedForm}, }; use mas_data_model::Device; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_router::{CompatLoginSsoAction, UrlBuilder}; use mas_storage::{ BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, @@ -120,7 +120,7 @@ pub async fn post( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, - State(homeserver): State, + State(homeserver): State>, cookie_jar: CookieJar, Path(id): Path, Query(params): Query, diff --git a/crates/handlers/src/graphql/mod.rs b/crates/handlers/src/graphql/mod.rs index 306f50cdd..013a37c54 100644 --- a/crates/handlers/src/graphql/mod.rs +++ b/crates/handlers/src/graphql/mod.rs @@ -69,7 +69,7 @@ pub struct ExtraRouterParameters { struct GraphQLState { pool: PgPool, - homeserver_connection: Arc>, + homeserver_connection: Arc, policy_factory: Arc, site_config: SiteConfig, password_manager: PasswordManager, @@ -99,7 +99,7 @@ impl state::State for GraphQLState { &self.site_config } - fn homeserver_connection(&self) -> &dyn HomeserverConnection { + fn homeserver_connection(&self) -> &dyn HomeserverConnection { self.homeserver_connection.as_ref() } @@ -129,7 +129,7 @@ impl state::State for GraphQLState { pub fn schema( pool: &PgPool, policy_factory: &Arc, - homeserver_connection: impl HomeserverConnection + 'static, + homeserver_connection: impl HomeserverConnection + 'static, site_config: SiteConfig, password_manager: PasswordManager, url_builder: UrlBuilder, diff --git a/crates/handlers/src/graphql/model/matrix.rs b/crates/handlers/src/graphql/model/matrix.rs index b79d9b0b4..930742285 100644 --- a/crates/handlers/src/graphql/model/matrix.rs +++ b/crates/handlers/src/graphql/model/matrix.rs @@ -26,7 +26,7 @@ impl MatrixUser { pub(crate) async fn load( conn: &C, user: &str, - ) -> Result { + ) -> Result { let mxid = conn.mxid(user); let info = conn.query_user(&mxid).await?; diff --git a/crates/handlers/src/graphql/state.rs b/crates/handlers/src/graphql/state.rs index fe5c07158..83d50b86a 100644 --- a/crates/handlers/src/graphql/state.rs +++ b/crates/handlers/src/graphql/state.rs @@ -17,7 +17,7 @@ pub trait State { async fn repository(&self) -> Result; async fn policy(&self) -> Result; fn password_manager(&self) -> PasswordManager; - fn homeserver_connection(&self) -> &dyn HomeserverConnection; + fn homeserver_connection(&self) -> &dyn HomeserverConnection; fn clock(&self) -> BoxClock; fn rng(&self) -> BoxRng; fn site_config(&self) -> &SiteConfig; diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 5ba623876..3a43fee42 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -15,7 +15,11 @@ clippy::let_with_type_underscore, )] -use std::{convert::Infallible, sync::LazyLock, time::Duration}; +use std::{ + convert::Infallible, + sync::{Arc, LazyLock}, + time::Duration, +}; use axum::{ Extension, Router, @@ -35,7 +39,7 @@ use mas_axum_utils::{FancyError, cookies::CookieJar}; use mas_data_model::SiteConfig; use mas_http::CorsLayerExt; use mas_keystore::{Encrypter, Keystore}; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_policy::Policy; use mas_router::{Route, UrlBuilder}; use mas_storage::{BoxClock, BoxRepository, BoxRng}; @@ -198,7 +202,7 @@ where Encrypter: FromRef, reqwest::Client: FromRef, SiteConfig: FromRef, - BoxHomeserverConnection: FromRef, + Arc: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, Policy: FromRequestParts, @@ -254,7 +258,7 @@ where S: Clone + Send + Sync + 'static, UrlBuilder: FromRef, SiteConfig: FromRef, - BoxHomeserverConnection: FromRef, + Arc: FromRef, PasswordManager: FromRef, Limiter: FromRef, BoundActivityTracker: FromRequestParts, @@ -322,7 +326,7 @@ where SiteConfig: FromRef, Limiter: FromRef, reqwest::Client: FromRef, - BoxHomeserverConnection: FromRef, + Arc: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, Policy: FromRequestParts, diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 5846d0d22..d76c81f1b 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use axum::{Json, extract::State, response::IntoResponse}; use axum_extra::typed_header::TypedHeader; use chrono::Duration; @@ -17,7 +19,7 @@ use mas_data_model::{ AuthorizationGrantStage, Client, Device, DeviceCodeGrantState, SiteConfig, TokenType, UserAgent, }; use mas_keystore::{Encrypter, Keystore}; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_oidc_client::types::scope::ScopeToken; use mas_policy::Policy; use mas_router::UrlBuilder; @@ -226,7 +228,7 @@ pub(crate) async fn post( State(url_builder): State, activity_tracker: BoundActivityTracker, mut repo: BoxRepository, - State(homeserver): State, + State(homeserver): State>, State(site_config): State, State(encrypter): State, policy: Policy, @@ -337,7 +339,7 @@ async fn authorization_code_grant( url_builder: &UrlBuilder, site_config: &SiteConfig, mut repo: BoxRepository, - homeserver: &BoxHomeserverConnection, + homeserver: &Arc, user_agent: Option, ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { // Check that the client is allowed to use this grant type @@ -741,7 +743,7 @@ async fn device_code_grant( url_builder: &UrlBuilder, site_config: &SiteConfig, mut repo: BoxRepository, - homeserver: &BoxHomeserverConnection, + homeserver: &Arc, user_agent: Option, ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { // Check that the client is allowed to use this grant type diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index 96073301a..60da629e2 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -31,7 +31,7 @@ use mas_config::RateLimitingConfig; use mas_data_model::SiteConfig; use mas_i18n::Translator; use mas_keystore::{Encrypter, JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; -use mas_matrix::{BoxHomeserverConnection, HomeserverConnection, MockHomeserverConnection}; +use mas_matrix::{HomeserverConnection, MockHomeserverConnection}; use mas_policy::{InstantiateError, Policy, PolicyFactory}; use mas_router::{SimpleRoute, UrlBuilder}; use mas_storage::{BoxClock, BoxRepository, BoxRng, clock::MockClock}; @@ -420,7 +420,7 @@ impl graphql::State for TestGraphQLState { self.password_manager.clone() } - fn homeserver_connection(&self) -> &dyn HomeserverConnection { + fn homeserver_connection(&self) -> &dyn HomeserverConnection { &self.homeserver_connection } @@ -513,9 +513,9 @@ impl FromRef for SiteConfig { } } -impl FromRef for BoxHomeserverConnection { +impl FromRef for Arc { fn from_ref(input: &TestState) -> Self { - Box::new(input.homeserver_connection.clone()) + input.homeserver_connection.clone() } } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 58153206f..f8631112e 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use axum::{ Form, extract::{Path, State}, @@ -19,7 +21,7 @@ use mas_axum_utils::{ }; use mas_data_model::{User, UserAgent}; use mas_jose::jwt::Jwt; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ @@ -200,7 +202,7 @@ pub(crate) async fn get( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, - State(homeserver): State, + State(homeserver): State>, cookie_jar: CookieJar, activity_tracker: BoundActivityTracker, user_agent: Option>, @@ -512,7 +514,7 @@ pub(crate) async fn post( PreferredLanguage(locale): PreferredLanguage, activity_tracker: BoundActivityTracker, State(templates): State, - State(homeserver): State, + State(homeserver): State>, State(url_builder): State, State(site_config): State, Path(link_id): Path, diff --git a/crates/handlers/src/views/login.rs b/crates/handlers/src/views/login.rs index c1fad0b7f..90f496557 100644 --- a/crates/handlers/src/views/login.rs +++ b/crates/handlers/src/views/login.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use axum::{ extract::{Form, Query, State}, response::{Html, IntoResponse, Response}, @@ -17,7 +19,7 @@ use mas_axum_utils::{ }; use mas_data_model::{BrowserSession, UserAgent, oauth2::LoginHint}; use mas_i18n::DataLocale; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_router::{UpstreamOAuth2Authorize, UrlBuilder}; use mas_storage::{ BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, @@ -56,7 +58,7 @@ pub(crate) async fn get( State(templates): State, State(url_builder): State, State(site_config): State, - State(homeserver): State, + State(homeserver): State>, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, Query(query): Query, @@ -99,7 +101,7 @@ pub(crate) async fn get( csrf_token, &mut repo, &templates, - homeserver, + &homeserver, ) .await?; @@ -116,7 +118,7 @@ pub(crate) async fn post( State(templates): State, State(url_builder): State, State(limiter): State, - State(homeserver): State, + State(homeserver): State>, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, requester: RequesterFingerprint, @@ -161,7 +163,7 @@ pub(crate) async fn post( csrf_token, &mut repo, &templates, - homeserver, + &homeserver, ) .await?; @@ -207,7 +209,7 @@ pub(crate) async fn post( csrf_token, &mut repo, &templates, - homeserver, + &homeserver, ) .await?; @@ -301,7 +303,7 @@ async fn login( fn handle_login_hint( ctx: &mut LoginContext, next: &PostAuthContext, - homeserver: &BoxHomeserverConnection, + homeserver: &dyn HomeserverConnection, ) { let form_state = ctx.form_state_mut(); @@ -326,11 +328,11 @@ async fn render( csrf_token: CsrfToken, repo: &mut impl RepositoryAccess, templates: &Templates, - homeserver: BoxHomeserverConnection, + homeserver: &dyn HomeserverConnection, ) -> Result { let next = action.load_context(repo).await?; let ctx = if let Some(next) = next { - handle_login_hint(&mut ctx, &next, &homeserver); + handle_login_hint(&mut ctx, &next, homeserver); ctx.with_post_action(next) } else { diff --git a/crates/handlers/src/views/register/password.rs b/crates/handlers/src/views/register/password.rs index fef85e8d0..470959cb5 100644 --- a/crates/handlers/src/views/register/password.rs +++ b/crates/handlers/src/views/register/password.rs @@ -4,7 +4,7 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::str::FromStr; +use std::{str::FromStr, sync::Arc}; use axum::{ extract::{Form, Query, State}, @@ -20,7 +20,7 @@ use mas_axum_utils::{ }; use mas_data_model::{CaptchaConfig, UserAgent}; use mas_i18n::DataLocale; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ @@ -128,7 +128,7 @@ pub(crate) async fn post( State(templates): State, State(url_builder): State, State(site_config): State, - State(homeserver): State, + State(homeserver): State>, State(http_client): State, (State(limiter), requester): (State, RequesterFingerprint), mut policy: Policy, diff --git a/crates/handlers/src/views/register/steps/finish.rs b/crates/handlers/src/views/register/steps/finish.rs index 770905491..3edb6fbfd 100644 --- a/crates/handlers/src/views/register/steps/finish.rs +++ b/crates/handlers/src/views/register/steps/finish.rs @@ -3,6 +3,8 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. +use std::sync::Arc; + use anyhow::Context as _; use axum::{ extract::{Path, State}, @@ -12,7 +14,7 @@ use axum_extra::TypedHeader; use chrono::Duration; use mas_axum_utils::{FancyError, SessionInfoExt as _, cookies::CookieJar}; use mas_data_model::UserAgent; -use mas_matrix::BoxHomeserverConnection; +use mas_matrix::HomeserverConnection; use mas_router::{PostAuthAction, UrlBuilder}; use mas_storage::{ BoxClock, BoxRepository, BoxRng, @@ -38,7 +40,7 @@ pub(crate) async fn get( activity_tracker: BoundActivityTracker, user_agent: Option>, State(url_builder): State, - State(homeserver): State, + State(homeserver): State>, State(templates): State, PreferredLanguage(lang): PreferredLanguage, cookie_jar: CookieJar, diff --git a/crates/matrix-synapse/src/lib.rs b/crates/matrix-synapse/src/lib.rs index 36f9e97d0..9d3e9c812 100644 --- a/crates/matrix-synapse/src/lib.rs +++ b/crates/matrix-synapse/src/lib.rs @@ -157,8 +157,6 @@ struct UsernameAvailableResponse { #[async_trait::async_trait] impl HomeserverConnection for SynapseConnection { - type Error = anyhow::Error; - fn homeserver(&self) -> &str { &self.homeserver } @@ -172,7 +170,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn query_user(&self, mxid: &str) -> Result { + async fn query_user(&self, mxid: &str) -> Result { let mxid = urlencoding::encode(mxid); let response = self @@ -207,7 +205,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn is_localpart_available(&self, localpart: &str) -> Result { + async fn is_localpart_available(&self, localpart: &str) -> Result { let localpart = urlencoding::encode(localpart); let response = self @@ -252,7 +250,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn provision_user(&self, request: &ProvisionRequest) -> Result { + async fn provision_user(&self, request: &ProvisionRequest) -> Result { let mut body = SynapseUser { external_ids: Some(vec![ExternalID { auth_provider: SYNAPSE_AUTH_PROVIDER.to_owned(), @@ -311,7 +309,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { let mxid = urlencoding::encode(mxid); let response = self @@ -348,7 +346,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { let mxid = urlencoding::encode(mxid); let device_id = urlencoding::encode(device_id); @@ -384,7 +382,11 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + async fn sync_devices( + &self, + mxid: &str, + devices: HashSet, + ) -> Result<(), anyhow::Error> { // Get the list of current devices let mxid_url = urlencoding::encode(mxid); @@ -454,7 +456,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> { let mxid = urlencoding::encode(mxid); let response = self @@ -521,7 +523,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> { + async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> { let mxid = urlencoding::encode(mxid); let response = self .put(&format!("_matrix/client/v3/profile/{mxid}/displayname")) @@ -554,7 +556,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Display), )] - async fn unset_displayname(&self, mxid: &str) -> Result<(), Self::Error> { + async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> { self.set_displayname(mxid, "").await } @@ -567,7 +569,7 @@ impl HomeserverConnection for SynapseConnection { ), err(Debug), )] - async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), Self::Error> { + async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> { let mxid = urlencoding::encode(mxid); let response = self diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index 76f32e09f..59cdb4880 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -5,16 +5,15 @@ // Please see LICENSE in the repository root for full details. mod mock; +mod readonly; use std::{collections::HashSet, sync::Arc}; use ruma_common::UserId; -pub use self::mock::HomeserverConnection as MockHomeserverConnection; - -// TODO: this should probably be another error type by default -pub type BoxHomeserverConnection = - Box>; +pub use self::{ + mock::HomeserverConnection as MockHomeserverConnection, readonly::ReadOnlyHomeserverConnection, +}; #[derive(Debug)] pub struct MatrixUser { @@ -180,9 +179,6 @@ impl ProvisionRequest { #[async_trait::async_trait] pub trait HomeserverConnection: Send + Sync { - /// The error type returned by all methods. - type Error; - /// Get the homeserver URL. fn homeserver(&self) -> &str; @@ -221,7 +217,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the user does not /// exist. - async fn query_user(&self, mxid: &str) -> Result; + async fn query_user(&self, mxid: &str) -> Result; /// Provision a user on the homeserver. /// @@ -234,7 +230,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the user could not /// be provisioned. - async fn provision_user(&self, request: &ProvisionRequest) -> Result; + async fn provision_user(&self, request: &ProvisionRequest) -> Result; /// Check whether a given username is available on the homeserver. /// @@ -245,7 +241,7 @@ pub trait HomeserverConnection: Send + Sync { /// # Errors /// /// Returns an error if the homeserver is unreachable. - async fn is_localpart_available(&self, localpart: &str) -> Result; + async fn is_localpart_available(&self, localpart: &str) -> Result; /// Create a device for a user on the homeserver. /// @@ -258,7 +254,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the device could /// not be created. - async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>; + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error>; /// Delete a device for a user on the homeserver. /// @@ -271,7 +267,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the device could /// not be deleted. - async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>; + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error>; /// Sync the list of devices of a user with the homeserver. /// @@ -284,7 +280,8 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the devices could /// not be synced. - async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error>; + async fn sync_devices(&self, mxid: &str, devices: HashSet) + -> Result<(), anyhow::Error>; /// Delete a user on the homeserver. /// @@ -297,7 +294,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the user could not /// be deleted. - async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error>; + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error>; /// Reactivate a user on the homeserver. /// @@ -309,7 +306,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the user could not /// be reactivated. - async fn reactivate_user(&self, mxid: &str) -> Result<(), Self::Error>; + async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error>; /// Set the displayname of a user on the homeserver. /// @@ -322,7 +319,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the displayname /// could not be set. - async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error>; + async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error>; /// Unset the displayname of a user on the homeserver. /// @@ -334,7 +331,7 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the displayname /// could not be unset. - async fn unset_displayname(&self, mxid: &str) -> Result<(), Self::Error>; + async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error>; /// Temporarily allow a user to reset their cross-signing keys. /// @@ -346,58 +343,60 @@ pub trait HomeserverConnection: Send + Sync { /// /// Returns an error if the homeserver is unreachable or the cross-signing /// reset could not be allowed. - async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), Self::Error>; + async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error>; } #[async_trait::async_trait] impl HomeserverConnection for &T { - type Error = T::Error; - fn homeserver(&self) -> &str { (**self).homeserver() } - async fn query_user(&self, mxid: &str) -> Result { + async fn query_user(&self, mxid: &str) -> Result { (**self).query_user(mxid).await } - async fn provision_user(&self, request: &ProvisionRequest) -> Result { + async fn provision_user(&self, request: &ProvisionRequest) -> Result { (**self).provision_user(request).await } - async fn is_localpart_available(&self, localpart: &str) -> Result { + async fn is_localpart_available(&self, localpart: &str) -> Result { (**self).is_localpart_available(localpart).await } - async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { (**self).create_device(mxid, device_id).await } - async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { (**self).delete_device(mxid, device_id).await } - async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + async fn sync_devices( + &self, + mxid: &str, + devices: HashSet, + ) -> Result<(), anyhow::Error> { (**self).sync_devices(mxid, devices).await } - async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> { (**self).delete_user(mxid, erase).await } - async fn reactivate_user(&self, mxid: &str) -> Result<(), Self::Error> { + async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error> { (**self).reactivate_user(mxid).await } - async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> { + async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> { (**self).set_displayname(mxid, displayname).await } - async fn unset_displayname(&self, mxid: &str) -> Result<(), Self::Error> { + async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> { (**self).unset_displayname(mxid).await } - async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), Self::Error> { + async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> { (**self).allow_cross_signing_reset(mxid).await } } @@ -405,53 +404,55 @@ impl HomeserverConnection for &T // Implement for Arc where T: HomeserverConnection #[async_trait::async_trait] impl HomeserverConnection for Arc { - type Error = T::Error; - fn homeserver(&self) -> &str { (**self).homeserver() } - async fn query_user(&self, mxid: &str) -> Result { + async fn query_user(&self, mxid: &str) -> Result { (**self).query_user(mxid).await } - async fn provision_user(&self, request: &ProvisionRequest) -> Result { + async fn provision_user(&self, request: &ProvisionRequest) -> Result { (**self).provision_user(request).await } - async fn is_localpart_available(&self, localpart: &str) -> Result { + async fn is_localpart_available(&self, localpart: &str) -> Result { (**self).is_localpart_available(localpart).await } - async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { (**self).create_device(mxid, device_id).await } - async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { (**self).delete_device(mxid, device_id).await } - async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + async fn sync_devices( + &self, + mxid: &str, + devices: HashSet, + ) -> Result<(), anyhow::Error> { (**self).sync_devices(mxid, devices).await } - async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> { (**self).delete_user(mxid, erase).await } - async fn reactivate_user(&self, mxid: &str) -> Result<(), Self::Error> { + async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error> { (**self).reactivate_user(mxid).await } - async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> { + async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> { (**self).set_displayname(mxid, displayname).await } - async fn unset_displayname(&self, mxid: &str) -> Result<(), Self::Error> { + async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> { (**self).unset_displayname(mxid).await } - async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), Self::Error> { + async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> { (**self).allow_cross_signing_reset(mxid).await } } diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs index 50805ee12..22b9a43d5 100644 --- a/crates/matrix/src/mock.rs +++ b/crates/matrix/src/mock.rs @@ -50,13 +50,11 @@ impl HomeserverConnection { #[async_trait] impl crate::HomeserverConnection for HomeserverConnection { - type Error = anyhow::Error; - fn homeserver(&self) -> &str { &self.homeserver } - async fn query_user(&self, mxid: &str) -> Result { + async fn query_user(&self, mxid: &str) -> Result { let users = self.users.read().await; let user = users.get(mxid).context("User not found")?; Ok(MatrixUser { @@ -66,7 +64,7 @@ impl crate::HomeserverConnection for HomeserverConnection { }) } - async fn provision_user(&self, request: &ProvisionRequest) -> Result { + async fn provision_user(&self, request: &ProvisionRequest) -> Result { let mut users = self.users.write().await; let inserted = !users.contains_key(request.mxid()); let user = users.entry(request.mxid().to_owned()).or_insert(MockUser { @@ -99,7 +97,7 @@ impl crate::HomeserverConnection for HomeserverConnection { Ok(inserted) } - async fn is_localpart_available(&self, localpart: &str) -> Result { + async fn is_localpart_available(&self, localpart: &str) -> Result { if self.reserved_localparts.read().await.contains(localpart) { return Ok(false); } @@ -109,28 +107,32 @@ impl crate::HomeserverConnection for HomeserverConnection { Ok(!users.contains_key(&mxid)) } - async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn create_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.devices.insert(device_id.to_owned()); Ok(()) } - async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error> { + async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.devices.remove(device_id); Ok(()) } - async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + async fn sync_devices( + &self, + mxid: &str, + devices: HashSet, + ) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.devices = devices; Ok(()) } - async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.devices.clear(); @@ -144,7 +146,7 @@ impl crate::HomeserverConnection for HomeserverConnection { Ok(()) } - async fn reactivate_user(&self, mxid: &str) -> Result<(), Self::Error> { + async fn reactivate_user(&self, mxid: &str) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.deactivated = false; @@ -152,21 +154,21 @@ impl crate::HomeserverConnection for HomeserverConnection { Ok(()) } - async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), Self::Error> { + async fn set_displayname(&self, mxid: &str, displayname: &str) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.displayname = Some(displayname.to_owned()); Ok(()) } - async fn unset_displayname(&self, mxid: &str) -> Result<(), Self::Error> { + async fn unset_displayname(&self, mxid: &str) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.displayname = None; Ok(()) } - async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), Self::Error> { + async fn allow_cross_signing_reset(&self, mxid: &str) -> Result<(), anyhow::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; user.cross_signing_reset_allowed = true; diff --git a/crates/matrix/src/readonly.rs b/crates/matrix/src/readonly.rs new file mode 100644 index 000000000..b51040080 --- /dev/null +++ b/crates/matrix/src/readonly.rs @@ -0,0 +1,78 @@ +// Copyright 2025 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use std::collections::HashSet; + +use crate::{HomeserverConnection, MatrixUser, ProvisionRequest}; + +/// A wrapper around a [`HomeserverConnection`] that only allows read +/// operations. +pub struct ReadOnlyHomeserverConnection { + inner: C, +} + +impl ReadOnlyHomeserverConnection { + pub fn new(inner: C) -> Self + where + C: HomeserverConnection, + { + Self { inner } + } +} + +#[async_trait::async_trait] +impl HomeserverConnection for ReadOnlyHomeserverConnection { + fn homeserver(&self) -> &str { + self.inner.homeserver() + } + + async fn query_user(&self, mxid: &str) -> Result { + self.inner.query_user(mxid).await + } + + async fn provision_user(&self, _request: &ProvisionRequest) -> Result { + anyhow::bail!("Provisioning is not supported in read-only mode"); + } + + async fn is_localpart_available(&self, localpart: &str) -> Result { + self.inner.is_localpart_available(localpart).await + } + + async fn create_device(&self, _mxid: &str, _device_id: &str) -> Result<(), anyhow::Error> { + anyhow::bail!("Device creation is not supported in read-only mode"); + } + + async fn delete_device(&self, _mxid: &str, _device_id: &str) -> Result<(), anyhow::Error> { + anyhow::bail!("Device deletion is not supported in read-only mode"); + } + + async fn sync_devices( + &self, + _mxid: &str, + _devices: HashSet, + ) -> Result<(), anyhow::Error> { + anyhow::bail!("Device synchronization is not supported in read-only mode"); + } + + async fn delete_user(&self, _mxid: &str, _erase: bool) -> Result<(), anyhow::Error> { + anyhow::bail!("User deletion is not supported in read-only mode"); + } + + async fn reactivate_user(&self, _mxid: &str) -> Result<(), anyhow::Error> { + anyhow::bail!("User reactivation is not supported in read-only mode"); + } + + async fn set_displayname(&self, _mxid: &str, _displayname: &str) -> Result<(), anyhow::Error> { + anyhow::bail!("User displayname update is not supported in read-only mode"); + } + + async fn unset_displayname(&self, _mxid: &str) -> Result<(), anyhow::Error> { + anyhow::bail!("User displayname update is not supported in read-only mode"); + } + + async fn allow_cross_signing_reset(&self, _mxid: &str) -> Result<(), anyhow::Error> { + anyhow::bail!("Allowing cross-signing reset is not supported in read-only mode"); + } +} diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index d95941b8a..41ec78fc8 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -40,7 +40,7 @@ struct State { pool: Pool, mailer: Mailer, clock: SystemClock, - homeserver: Arc>, + homeserver: Arc, url_builder: UrlBuilder, site_config: SiteConfig, } @@ -50,7 +50,7 @@ impl State { pool: Pool, clock: SystemClock, mailer: Mailer, - homeserver: impl HomeserverConnection + 'static, + homeserver: impl HomeserverConnection + 'static, url_builder: UrlBuilder, site_config: SiteConfig, ) -> Self { @@ -91,7 +91,7 @@ impl State { Ok(repo) } - pub fn matrix_connection(&self) -> &dyn HomeserverConnection { + pub fn matrix_connection(&self) -> &dyn HomeserverConnection { self.homeserver.as_ref() } @@ -112,7 +112,7 @@ impl State { pub async fn init( pool: &Pool, mailer: &Mailer, - homeserver: impl HomeserverConnection + 'static, + homeserver: impl HomeserverConnection + 'static, url_builder: UrlBuilder, site_config: &SiteConfig, cancellation_token: CancellationToken, diff --git a/docs/config.schema.json b/docs/config.schema.json index ce5c12aa3..a998f08fd 100644 --- a/docs/config.schema.json +++ b/docs/config.schema.json @@ -1612,6 +1612,15 @@ "secret" ], "properties": { + "kind": { + "description": "The kind of homeserver it is.", + "default": "synapse", + "allOf": [ + { + "$ref": "#/definitions/HomeserverKind" + } + ] + }, "homeserver": { "description": "The server name of the homeserver.", "default": "localhost:8008", @@ -1629,6 +1638,25 @@ } } }, + "HomeserverKind": { + "description": "The kind of homeserver it is.", + "oneOf": [ + { + "description": "Homeserver is Synapse", + "type": "string", + "enum": [ + "synapse" + ] + }, + { + "description": "Homeserver is Synapse, in read-only mode\n\nThis is meant for testing rolling out Matrix Authentication Service with no risk of writing data to the homeserver.", + "type": "string", + "enum": [ + "synapse_read_only" + ] + } + ] + }, "PolicyConfig": { "description": "Application secrets", "type": "object",