Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions crates/cli/src/app_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ use mas_handlers::{
};
use mas_i18n::Translator;
use mas_keystore::{Encrypter, Keystore};
use mas_matrix::BoxHomeserverConnection;
use mas_matrix_synapse::SynapseConnection;
use mas_matrix::HomeserverConnection;
use mas_policy::{Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, SystemClock};
Expand All @@ -37,7 +36,7 @@ pub struct AppState {
pub cookie_manager: CookieManager,
pub encrypter: Encrypter,
pub url_builder: UrlBuilder,
pub homeserver_connection: SynapseConnection,
pub homeserver_connection: Arc<dyn HomeserverConnection>,
pub policy_factory: Arc<PolicyFactory>,
pub graphql_schema: GraphQLSchema,
pub http_client: reqwest::Client,
Expand Down Expand Up @@ -204,9 +203,9 @@ impl FromRef<AppState> for Limiter {
}
}

impl FromRef<AppState> for BoxHomeserverConnection {
impl FromRef<AppState> for Arc<dyn HomeserverConnection> {
fn from_ref(input: &AppState) -> Self {
Box::new(input.homeserver_connection.clone())
Arc::clone(&input.homeserver_connection)
}
}

Expand Down
17 changes: 7 additions & 10 deletions crates/cli/src/commands/manage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use mas_config::{
use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User};
use mas_email::Address;
use mas_matrix::HomeserverConnection;
use mas_matrix_synapse::SynapseConnection;
use mas_storage::{
Clock, RepositoryAccess, SystemClock,
compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository},
Expand All @@ -33,7 +32,10 @@ use rand::{RngCore, SeedableRng};
use sqlx::{Acquire, types::Uuid};
use tracing::{error, info, info_span, warn};

use crate::util::{database_connection_from_config, password_manager_from_config};
use crate::util::{
database_connection_from_config, homeserver_connection_from_config,
password_manager_from_config,
};

const USER_ATTRIBUTES_HEADING: &str = "User attributes";

Expand Down Expand Up @@ -491,12 +493,7 @@ impl Options {
let matrix_config = MatrixConfig::extract(figment)?;

let password_manager = password_manager_from_config(&password_config).await?;
let homeserver = SynapseConnection::new(
matrix_config.homeserver,
matrix_config.endpoint,
matrix_config.secret,
http_client,
);
let homeserver = homeserver_connection_from_config(&matrix_config, http_client);
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
Expand Down Expand Up @@ -746,7 +743,7 @@ impl std::fmt::Display for HumanReadable<&UpstreamOAuthProvider> {
async fn check_and_normalize_username<'a>(
localpart_or_mxid: &'a str,
repo: &mut dyn RepositoryAccess<Error = DatabaseError>,
homeserver: &SynapseConnection,
homeserver: &dyn HomeserverConnection,
) -> anyhow::Result<&'a str> {
// XXX: this is a very basic MXID to localpart conversion
// Strip any leading '@'
Expand Down Expand Up @@ -828,7 +825,7 @@ impl UserCreationRequest<'_> {
}

/// Show the user creation request in a human-readable format
fn show(&self, term: &Term, homeserver: &SynapseConnection) -> std::io::Result<()> {
fn show(&self, term: &Term, homeserver: &dyn HomeserverConnection) -> std::io::Result<()> {
let value_style = Style::new().green();
let key_style = Style::new().bold();
let warning_style = Style::new().italic().red().bright();
Expand Down
15 changes: 5 additions & 10 deletions crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ use mas_config::{
};
use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache};
use mas_listener::server::Server;
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
Expand All @@ -26,9 +25,9 @@ use crate::{
app_state::AppState,
lifecycle::LifecycleManager,
util::{
database_pool_from_config, mailer_from_config, password_manager_from_config,
policy_factory_from_config, site_config_from_config, templates_from_config,
test_mailer_in_background,
database_pool_from_config, homeserver_connection_from_config, mailer_from_config,
password_manager_from_config, policy_factory_from_config, site_config_from_config,
templates_from_config, test_mailer_in_background,
},
};

Expand Down Expand Up @@ -153,12 +152,8 @@ impl Options {

let http_client = mas_http::reqwest_client();

let homeserver_connection = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client.clone(),
);
let homeserver_connection =
homeserver_connection_from_config(&config.matrix, http_client.clone());

if !self.no_worker {
let mailer = mailer_from_config(&config.email, &templates)?;
Expand Down
12 changes: 3 additions & 9 deletions crates/cli/src/commands/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ use std::{process::ExitCode, time::Duration};
use clap::Parser;
use figment::Figment;
use mas_config::{AppConfig, ConfigurationSection};
use mas_matrix_synapse::SynapseConnection;
use mas_router::UrlBuilder;
use tracing::{info, info_span};

use crate::{
lifecycle::LifecycleManager,
util::{
database_pool_from_config, mailer_from_config, site_config_from_config,
templates_from_config, test_mailer_in_background,
database_pool_from_config, homeserver_connection_from_config, mailer_from_config,
site_config_from_config, templates_from_config, test_mailer_in_background,
},
};

Expand Down Expand Up @@ -58,12 +57,7 @@ impl Options {
test_mailer_in_background(&mailer, Duration::from_secs(30));

let http_client = mas_http::reqwest_client();
let conn = SynapseConnection::new(
config.matrix.homeserver.clone(),
config.matrix.endpoint.clone(),
config.matrix.secret.clone(),
http_client,
);
let conn = homeserver_connection_from_config(&config.matrix, http_client);

drop(config);

Expand Down
34 changes: 31 additions & 3 deletions crates/cli/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::time::Duration;
use std::{sync::Arc, time::Duration};

use anyhow::Context;
use mas_config::{
AccountConfig, BrandingConfig, CaptchaConfig, DatabaseConfig, EmailConfig, EmailSmtpMode,
EmailTransportKind, ExperimentalConfig, MatrixConfig, PasswordsConfig, PolicyConfig,
TemplatesConfig,
EmailTransportKind, ExperimentalConfig, HomeserverKind, MatrixConfig, PasswordsConfig,
PolicyConfig, TemplatesConfig,
};
use mas_data_model::{SessionExpirationConfig, SiteConfig};
use mas_email::{MailTransport, Mailer};
use mas_handlers::passwords::PasswordManager;
use mas_matrix::{HomeserverConnection, ReadOnlyHomeserverConnection};
use mas_matrix_synapse::SynapseConnection;
use mas_policy::PolicyFactory;
use mas_router::UrlBuilder;
use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates};
Expand Down Expand Up @@ -346,6 +348,32 @@ pub async fn database_connection_from_config(
.context("could not connect to the database")
}

/// Create a clonable, type-erased [`HomeserverConnection`] from the
/// configuration
pub fn homeserver_connection_from_config(
config: &MatrixConfig,
http_client: reqwest::Client,
) -> Arc<dyn HomeserverConnection> {
match config.kind {
HomeserverKind::Synapse => Arc::new(SynapseConnection::new(
config.homeserver.clone(),
config.endpoint.clone(),
config.secret.clone(),
http_client,
)),
HomeserverKind::SynapseReadOnly => {
let connection = SynapseConnection::new(
config.homeserver.clone(),
config.endpoint.clone(),
config.secret.clone(),
http_client,
);
let readonly = ReadOnlyHomeserverConnection::new(connection);
Arc::new(readonly)
}
}
}

#[cfg(test)]
mod tests {
use rand::SeedableRng;
Expand Down
21 changes: 21 additions & 0 deletions crates/config/src/sections/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,29 @@ fn default_endpoint() -> Url {
Url::parse("http://localhost:8008/").unwrap()
}

/// The kind of homeserver it is.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum HomeserverKind {
/// Homeserver is Synapse
#[default]
Synapse,

/// Homeserver is Synapse, in read-only mode
///
/// This is meant for testing rolling out Matrix Authentication Service with
/// no risk of writing data to the homeserver.
SynapseReadOnly,
}

/// Configuration related to the Matrix homeserver
#[serde_as]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MatrixConfig {
/// The kind of homeserver it is.
#[serde(default)]
pub kind: HomeserverKind,

/// The server name of the homeserver.
#[serde(default = "default_homeserver")]
pub homeserver: String,
Expand All @@ -49,6 +68,7 @@ impl MatrixConfig {
R: Rng + Send,
{
Self {
kind: HomeserverKind::default(),
homeserver: default_homeserver(),
secret: Alphanumeric.sample_string(&mut rng, 32),
endpoint: default_endpoint(),
Expand All @@ -57,6 +77,7 @@ impl MatrixConfig {

pub(crate) fn test() -> Self {
Self {
kind: HomeserverKind::default(),
homeserver: default_homeserver(),
secret: "test".to_owned(),
endpoint: default_endpoint(),
Expand Down
2 changes: 1 addition & 1 deletion crates/config/src/sections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub use self::{
BindConfig as HttpBindConfig, HttpConfig, ListenerConfig as HttpListenerConfig,
Resource as HttpResource, TlsConfig as HttpTlsConfig, UnixOrTcp,
},
matrix::MatrixConfig,
matrix::{HomeserverKind, MatrixConfig},
passwords::{Algorithm as PasswordAlgorithm, PasswordsConfig},
policy::PolicyConfig,
rate_limiting::RateLimitingConfig,
Expand Down
6 changes: 4 additions & 2 deletions crates/handlers/src/admin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::sync::Arc;

use aide::{
axum::ApiRouter,
openapi::{OAuth2Flow, OAuth2Flows, OpenApi, SecurityScheme, Server, Tag},
Expand All @@ -19,7 +21,7 @@ use hyper::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use indexmap::IndexMap;
use mas_axum_utils::FancyError;
use mas_http::CorsLayerExt;
use mas_matrix::BoxHomeserverConnection;
use mas_matrix::HomeserverConnection;
use mas_router::{
ApiDoc, ApiDocCallback, OAuth2AuthorizationEndpoint, OAuth2TokenEndpoint, Route, SimpleRoute,
UrlBuilder,
Expand Down Expand Up @@ -107,7 +109,7 @@ fn finish(t: TransformOpenApi) -> TransformOpenApi {
pub fn router<S>() -> (OpenApi, Router<S>)
where
S: Clone + Send + Sync + 'static,
BoxHomeserverConnection: FromRef<S>,
Arc<dyn HomeserverConnection>: FromRef<S>,
PasswordManager: FromRef<S>,
BoxRng: FromRequestParts<S>,
CallContext: FromRequestParts<S>,
Expand Down
6 changes: 4 additions & 2 deletions crates/handlers/src/admin/v1/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::sync::Arc;

use aide::axum::{
ApiRouter,
routing::{get_with, post_with},
};
use axum::extract::{FromRef, FromRequestParts};
use mas_matrix::BoxHomeserverConnection;
use mas_matrix::HomeserverConnection;
use mas_storage::BoxRng;

use super::call_context::CallContext;
Expand All @@ -25,7 +27,7 @@ mod users;
pub fn router<S>() -> ApiRouter<S>
where
S: Clone + Send + Sync + 'static,
BoxHomeserverConnection: FromRef<S>,
Arc<dyn HomeserverConnection>: FromRef<S>,
PasswordManager: FromRef<S>,
BoxRng: FromRequestParts<S>,
CallContext: FromRequestParts<S>,
Expand Down
6 changes: 4 additions & 2 deletions crates/handlers/src/admin/v1/users/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::sync::Arc;

use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_matrix::BoxHomeserverConnection;
use mas_matrix::HomeserverConnection;
use mas_storage::{
BoxRng,
queue::{ProvisionUserJob, QueueJobRepositoryExt as _},
Expand Down Expand Up @@ -135,7 +137,7 @@ pub async fn handler(
mut repo, clock, ..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
State(homeserver): State<BoxHomeserverConnection>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
Json(params): Json<Request>,
) -> Result<(StatusCode, Json<SingleResponse<User>>), RouteError> {
if repo.user().exists(&params.username).await? {
Expand Down
6 changes: 4 additions & 2 deletions crates/handlers/src/admin/v1/users/unlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use std::sync::Arc;

use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use mas_matrix::BoxHomeserverConnection;
use mas_matrix::HomeserverConnection;
use ulid::Ulid;

use crate::{
Expand Down Expand Up @@ -67,7 +69,7 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
#[tracing::instrument(name = "handler.admin.v1.users.unlock", skip_all, err)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
State(homeserver): State<BoxHomeserverConnection>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
id: UlidPathParam,
) -> Result<Json<SingleResponse<User>>, RouteError> {
let id = *id;
Expand Down
4 changes: 2 additions & 2 deletions crates/handlers/src/bin/api-schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)]
#![warn(clippy::pedantic)]

use std::io::Write;
use std::{io::Write, sync::Arc};

use aide::openapi::{Server, ServerVariable};
use indexmap::IndexMap;
Expand Down Expand Up @@ -55,7 +55,7 @@ impl_from_request_parts!(mas_storage::BoxRng);
impl_from_request_parts!(mas_handlers::BoundActivityTracker);
impl_from_ref!(mas_router::UrlBuilder);
impl_from_ref!(mas_templates::Templates);
impl_from_ref!(mas_matrix::BoxHomeserverConnection);
impl_from_ref!(Arc<dyn mas_matrix::HomeserverConnection>);
impl_from_ref!(mas_keystore::Keystore);
impl_from_ref!(mas_handlers::passwords::PasswordManager);

Expand Down
Loading
Loading