diff --git a/Cargo.lock b/Cargo.lock index 114ec20f9..10dae7210 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2422,7 +2422,6 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", ] [[package]] @@ -2885,16 +2884,6 @@ dependencies = [ "serde", ] -[[package]] -name = "iri-string" -version = "0.7.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc0f0a572e8ffe56e2ff4f769f32ffe919282c3916799f8b68688b6030063bea" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -3189,6 +3178,7 @@ dependencies = [ "mime", "oauth2-types", "rand", + "reqwest", "sentry", "serde", "serde_json", @@ -3255,6 +3245,7 @@ dependencies = [ "prometheus", "rand", "rand_chacha", + "reqwest", "rustls", "sentry", "sentry-tower", @@ -3386,6 +3377,7 @@ dependencies = [ "psl", "rand", "rand_chacha", + "reqwest", "rustls", "schemars", "sentry", @@ -3404,6 +3396,7 @@ dependencies = [ "tracing-subscriber", "ulid", "url", + "wiremock", "zeroize", "zxcvbn", ] @@ -3412,27 +3405,15 @@ dependencies = [ name = "mas-http" version = "0.12.0" dependencies = [ - "anyhow", - "async-trait", - "bytes", "futures-util", "headers", "http", - "http-body", - "http-body-util", - "hyper", - "hyper-rustls", "hyper-util", - "mas-tower", "opentelemetry", + "opentelemetry-http", "opentelemetry-semantic-conventions", - "pin-project-lite", - "rustls", + "reqwest", "rustls-platform-verifier", - "serde", - "serde_json", - "serde_urlencoded", - "thiserror", "tokio", "tower 0.5.1", "tower-http", @@ -3609,8 +3590,10 @@ dependencies = [ "mas-axum-utils", "mas-http", "mas-matrix", + "reqwest", "serde", "serde_json", + "thiserror", "tower 0.5.1", "tracing", "url", @@ -3622,12 +3605,11 @@ name = "mas-oidc-client" version = "0.12.0" dependencies = [ "assert_matches", + "async-trait", "base64ct", "bitflags 2.6.0", - "bytes", "chrono", "form_urlencoded", - "futures-util", "headers", "http", "http-body-util", @@ -3640,14 +3622,13 @@ dependencies = [ "oauth2-types", "rand", "rand_chacha", + "reqwest", "rustls", "serde", "serde_json", "serde_urlencoded", - "serde_with", "thiserror", "tokio", - "tower 0.5.1", "tracing", "url", "wiremock", @@ -4154,11 +4135,8 @@ dependencies = [ "async-trait", "bytes", "http", - "http-body-util", - "hyper", - "hyper-util", "opentelemetry", - "tokio", + "reqwest", ] [[package]] @@ -5027,8 +5005,11 @@ checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", + "futures-channel", "futures-core", "futures-util", + "h2", "http", "http-body", "http-body-util", @@ -5052,12 +5033,12 @@ dependencies = [ "sync_wrapper 1.0.1", "tokio", "tokio-rustls", + "tokio-socks", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", "windows-registry", ] @@ -5417,12 +5398,15 @@ version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5484316556650182f03b43d4c746ce0e3e48074a21e2f51244b648b6542e1066" dependencies = [ + "httpdate", + "reqwest", "sentry-backtrace", "sentry-contexts", "sentry-core", "sentry-panic", "sentry-tower", "sentry-tracing", + "tokio", ] [[package]] @@ -6290,6 +6274,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-socks" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" +dependencies = [ + "either", + "futures-util", + "thiserror", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.16" @@ -6393,7 +6389,6 @@ dependencies = [ "pin-project-lite", "sync_wrapper 0.1.2", "tokio", - "tokio-util", "tower-layer", "tower-service", "tracing", @@ -6413,14 +6408,12 @@ dependencies = [ "http-body-util", "http-range-header", "httpdate", - "iri-string", "mime", "mime_guess", "percent-encoding", "pin-project-lite", "tokio", "tokio-util", - "tower 0.5.1", "tower-layer", "tower-service", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 3840ab231..6e1970498 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,6 +107,10 @@ features = ["derive"] version = "0.10.19" features = ["env", "yaml", "test"] +# Utilities for dealing with futures +[workspace.dependencies.futures-util] +version = "0.3.31" + # Rate-limiting [workspace.dependencies.governor] version = "0.7.0" @@ -182,6 +186,12 @@ version = "0.3.0" [workspace.dependencies.rand] version = "0.8.5" +# High-level HTTP client +[workspace.dependencies.reqwest] +version = "0.12.8" +default-features = false +features = ["http2", "rustls-tls-manual-roots", "charset", "json", "socks"] + # TLS stack [workspace.dependencies.rustls] version = "0.23.15" @@ -215,7 +225,7 @@ features = [ [workspace.dependencies.sentry] version = "0.34.0" default-features = false -features = ["backtrace", "contexts", "panic", "tower"] +features = ["backtrace", "contexts", "panic", "tower", "reqwest"] # Sentry tower layer [workspace.dependencies.sentry-tower] @@ -272,7 +282,7 @@ features = ["util"] # Tower HTTP layers [workspace.dependencies.tower-http] version = "0.6.1" -features = ["cors", "fs", "add-extension"] +features = ["cors", "fs", "add-extension", "set-header"] # Logging and tracing [workspace.dependencies.tracing] @@ -286,7 +296,7 @@ version = "0.24.0" features = ["trace", "metrics"] [workspace.dependencies.opentelemetry-http] version = "0.13.0" -features = ["hyper"] +features = ["reqwest"] [workspace.dependencies.opentelemetry-semantic-conventions] version = "0.16.0" [workspace.dependencies.tracing-opentelemetry] @@ -303,6 +313,10 @@ features = ["serde"] version = "1.1.3" features = ["serde"] +# HTTP mock server +[workspace.dependencies.wiremock] +version = "0.6.2" + # A few profile opt-level tweaks to make the test suite run faster [profile.dev.package] num-bigint-dig.opt-level = 3 diff --git a/clippy.toml b/clippy.toml index a300ae05b..ac0f49bf4 100644 --- a/clippy.toml +++ b/clippy.toml @@ -5,6 +5,8 @@ disallowed-methods = [ { path = "chrono::Utc::now", reason = "source the current time from the clock instead" }, { path = "ulid::Ulid::from_datetime", reason = "use Ulid::from_datetime_with_source instead" }, { path = "ulid::Ulid::new", reason = "use Ulid::from_datetime_with_source instead" }, + { path = "reqwest::Client::new", reason = "use mas_http::reqwest_client instead" }, + { path = "reqwest::RequestBuilder::send", reason = "use send_traced instead" }, ] disallowed-types = [ diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 844724faf..24e679314 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -18,7 +18,7 @@ axum-extra.workspace = true bytes.workspace = true chrono.workspace = true data-encoding = "2.6.0" -futures-util = "0.3.31" +futures-util.workspace = true headers.workspace = true http.workspace = true http-body.workspace = true @@ -27,6 +27,7 @@ hyper-util.workspace = true icu_locid = "1.4.0" mime = "0.3.17" rand.workspace = true +reqwest.workspace = true sentry.workspace = true serde.workspace = true serde_with = "3.11.0" @@ -41,7 +42,7 @@ ulid.workspace = true oauth2-types.workspace = true mas-data-model.workspace = true -mas-http = { workspace = true, features = ["client"] } +mas-http.workspace = true mas-iana.workspace = true mas-jose.workspace = true mas-keystore.workspace = true diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index e1f953887..51ce79978 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -19,7 +19,7 @@ use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason}; use headers::{authorization::Basic, Authorization}; use http::{Request, StatusCode}; use mas_data_model::{Client, JwksOrJwksUri}; -use mas_http::HttpServiceExt; +use mas_http::RequestBuilderExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; @@ -28,9 +28,6 @@ use oauth2_types::errors::{ClientError, ClientErrorCode}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use thiserror::Error; -use tower::{Service, ServiceExt}; - -use crate::http_client_factory::HttpClientFactory; static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; @@ -104,7 +101,7 @@ impl Credentials { #[tracing::instrument(skip_all, err)] pub async fn verify( &self, - http_client_factory: &HttpClientFactory, + http_client: &reqwest::Client, encrypter: &Encrypter, method: &OAuthClientAuthenticationMethod, client: &Client, @@ -146,7 +143,7 @@ impl Credentials { .as_ref() .ok_or(CredentialsVerificationError::InvalidClientConfig)?; - let jwks = fetch_jwks(http_client_factory, jwks) + let jwks = fetch_jwks(http_client, jwks) .await .map_err(|_| CredentialsVerificationError::JwksFetchFailed)?; @@ -181,7 +178,7 @@ impl Credentials { } async fn fetch_jwks( - http_client_factory: &HttpClientFactory, + http_client: &reqwest::Client, jwks: &JwksOrJwksUri, ) -> Result { let uri = match jwks { @@ -189,19 +186,15 @@ async fn fetch_jwks( JwksOrJwksUri::JwksUri(u) => u, }; - let request = http::Request::builder() - .uri(uri.as_str()) - .body(mas_http::EmptyBody::new()) - .unwrap(); - - let mut client = http_client_factory - .client("client.fetch_jwks") - .response_body_to_bytes() - .json_response::(); - - let response = client.ready().await?.call(request).await?; + let response = http_client + .get(uri.as_str()) + .send_traced() + .await? + .error_for_status()? + .json() + .await?; - Ok(response.into_body()) + Ok(response) } #[derive(Debug, Error)] diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs deleted file mode 100644 index 584dba499..000000000 --- a/crates/axum-utils/src/http_client_factory.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use http_body_util::Full; -use hyper_util::rt::TokioExecutor; -use mas_http::{ - make_traced_connector, BodyToBytesResponseLayer, Client, ClientLayer, ClientService, - HttpService, TracedClient, TracedConnector, -}; -use tower::{ - util::{MapErrLayer, MapRequestLayer}, - BoxError, Layer, -}; - -#[derive(Debug, Clone)] -pub struct HttpClientFactory { - traced_connector: TracedConnector, - client_layer: ClientLayer, -} - -impl Default for HttpClientFactory { - fn default() -> Self { - Self::new() - } -} - -impl HttpClientFactory { - /// Constructs a new HTTP client factory - #[must_use] - pub fn new() -> Self { - Self { - traced_connector: make_traced_connector(), - client_layer: ClientLayer::new(), - } - } - - /// Constructs a new HTTP client - pub fn client(&self, category: &'static str) -> ClientService> - where - B: axum::body::HttpBody + Send, - B::Data: Send, - { - let client = Client::builder(TokioExecutor::new()).build(self.traced_connector.clone()); - self.client_layer - .clone() - .with_category(category) - .layer(client) - } - - /// Constructs a new [`HttpService`], suitable for `mas-oidc-client` - pub fn http_service(&self, category: &'static str) -> HttpService { - let client = self.client(category); - let client = ( - MapErrLayer::new(BoxError::from), - MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)), - BodyToBytesResponseLayer, - ) - .layer(client); - - HttpService::new(client) - } -} diff --git a/crates/axum-utils/src/lib.rs b/crates/axum-utils/src/lib.rs index db26454d2..aa6fad9e6 100644 --- a/crates/axum-utils/src/lib.rs +++ b/crates/axum-utils/src/lib.rs @@ -12,7 +12,6 @@ pub mod cookies; pub mod csrf; pub mod error_wrapper; pub mod fancy_error; -pub mod http_client_factory; pub mod jwt; pub mod language_detection; pub mod sentry; diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index d5bc9dac8..45dd9f8c6 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -31,6 +31,7 @@ itertools = "0.13.0" listenfd = "1.0.1" rand.workspace = true rand_chacha = "0.3.1" +reqwest.workspace = true rustls.workspace = true serde_json.workspace = true serde_yaml = "0.9.34" @@ -71,8 +72,8 @@ sentry-tower.workspace = true mas-config.workspace = true mas-data-model.workspace = true mas-email.workspace = true -mas-handlers = { workspace = true } -mas-http = { workspace = true, features = ["client"] } +mas-handlers.workspace = true +mas-http.workspace = true mas-i18n.workspace = true mas-iana.workspace = true mas-keystore.workspace = true diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index b688b6fe3..525ec5591 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -14,7 +14,7 @@ use ipnetwork::IpNetwork; use mas_data_model::SiteConfig; use mas_handlers::{ passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, - GraphQLSchema, HttpClientFactory, Limiter, MetadataCache, RequesterFingerprint, + GraphQLSchema, Limiter, MetadataCache, RequesterFingerprint, }; use mas_i18n::Translator; use mas_keystore::{Encrypter, Keystore}; @@ -43,7 +43,7 @@ pub struct AppState { pub homeserver_connection: SynapseConnection, pub policy_factory: Arc, pub graphql_schema: GraphQLSchema, - pub http_client_factory: HttpClientFactory, + pub http_client: reqwest::Client, pub password_manager: PasswordManager, pub metadata_cache: MetadataCache, pub site_config: SiteConfig, @@ -116,13 +116,9 @@ impl AppState { let mut repo = PgRepository::from_conn(conn); - let http_service = self - .http_client_factory - .http_service("upstream_oauth2.metadata"); - self.metadata_cache .warm_up_and_run( - http_service, + &self.http_client, std::time::Duration::from_secs(60 * 15), &mut repo, ) @@ -173,9 +169,9 @@ impl FromRef for UrlBuilder { } } -impl FromRef for HttpClientFactory { +impl FromRef for reqwest::Client { fn from_ref(input: &AppState) -> Self { - input.http_client_factory.clone() + input.http_client.clone() } } diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 74ad0d3c4..8ffb50306 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -8,13 +8,7 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; -use http_body_util::BodyExt; -use hyper::{Response, Uri}; use mas_config::{ConfigurationSectionExt, PolicyConfig}; -use mas_handlers::HttpClientFactory; -use mas_http::HttpServiceExt; -use tokio::io::AsyncWriteExt; -use tower::{Service, ServiceExt}; use tracing::{info, info_span}; use crate::util::policy_factory_from_config; @@ -27,93 +21,15 @@ pub(super) struct Options { #[derive(Parser, Debug)] enum Subcommand { - /// Perform an HTTP request with the default HTTP client - Http { - /// Show response headers - #[arg(long, short = 'I')] - show_headers: bool, - - /// Parse the response as JSON - #[arg(long, short = 'j')] - json: bool, - - /// URI where to perform a GET request - url: Uri, - }, - /// Check that the policies compile Policy, } -fn print_headers(parts: &hyper::http::response::Parts) { - println!( - "{:?} {} {}", - parts.version, - parts.status.as_str(), - parts.status.canonical_reason().unwrap_or_default() - ); - - for (header, value) in &parts.headers { - println!("{header}: {value:?}"); - } - println!(); -} - impl Options { #[tracing::instrument(skip_all)] pub async fn run(self, figment: &Figment) -> anyhow::Result { use Subcommand as SC; - let http_client_factory = HttpClientFactory::new(); match self.subcommand { - SC::Http { - show_headers, - json: false, - url, - } => { - let _span = info_span!("cli.debug.http").entered(); - let mut client = http_client_factory.client("debug"); - let request = hyper::Request::builder() - .uri(url) - .body(axum::body::Body::empty())?; - - let response = client.ready().await?.call(request).await?; - let (parts, body) = response.into_parts(); - - if show_headers { - print_headers(&parts); - } - - let mut body = body.collect().await?.to_bytes(); - let mut stdout = tokio::io::stdout(); - stdout.write_all_buf(&mut body).await?; - } - - SC::Http { - show_headers, - json: true, - url, - } => { - let _span = info_span!("cli.debug.http").entered(); - let mut client = http_client_factory - .client("debug") - .response_body_to_bytes() - .json_response(); - let request = hyper::Request::builder() - .uri(url) - .body(axum::body::Body::empty())?; - - let response: Response = - client.ready().await?.call(request).await?; - let (parts, body) = response.into_parts(); - - if show_headers { - print_headers(&parts); - } - - let body = serde_json::to_string_pretty(&body)?; - println!("{body}"); - } - SC::Policy => { let _span = info_span!("cli.debug.policy").entered(); let config = PolicyConfig::extract_or_default(figment)?; diff --git a/crates/cli/src/commands/doctor.rs b/crates/cli/src/commands/doctor.rs index 0155a8033..d507b8308 100644 --- a/crates/cli/src/commands/doctor.rs +++ b/crates/cli/src/commands/doctor.rs @@ -15,9 +15,7 @@ use anyhow::Context; use clap::Parser; use figment::Figment; use mas_config::{ConfigurationSection, RootConfig}; -use mas_handlers::HttpClientFactory; -use mas_http::HttpServiceExt; -use tower::{Service, ServiceExt}; +use mas_http::RequestBuilderExt; use tracing::{error, info, info_span, warn}; use url::{Host, Url}; @@ -36,7 +34,7 @@ impl Options { let config = RootConfig::extract(figment)?; // We'll need an HTTP client - let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); let base_url = config.http.public_base.as_str(); let issuer = config.http.issuer.as_ref().map(url::Url::as_str); let issuer = issuer.unwrap_or(base_url); @@ -55,15 +53,7 @@ This means some clients will refuse to use it."# } let well_known_uri = format!("https://{matrix_domain}/.well-known/matrix/client"); - let mut client = http_client_factory - .client("doctor") - .response_body_to_bytes() - .json_response::(); - - let request = hyper::Request::builder() - .uri(&well_known_uri) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client.get(&well_known_uri).send_traced().await; let expected_well_known = serde_json::json!({ "m.homeserver": { @@ -86,15 +76,21 @@ Make sure the homeserver is reachable and the well-known document is available a ); } - let body = response.into_body(); - - if let Some(auth) = body.get("org.matrix.msc2965.authentication") { - if let Some(wk_issuer) = auth.get("issuer").and_then(|issuer| issuer.as_str()) { - if issuer == wk_issuer { - info!(r#"✅ Matrix client well-known at "{well_known_uri}" is valid"#); - } else { - warn!( - r#"⚠️ Matrix client well-known has an "org.matrix.msc2965.authentication" section, but the issuer is not the same as the homeserver. + let result = response.json::().await; + + match result { + Ok(body) => { + if let Some(auth) = body.get("org.matrix.msc2965.authentication") { + if let Some(wk_issuer) = + auth.get("issuer").and_then(|issuer| issuer.as_str()) + { + if issuer == wk_issuer { + info!( + r#"✅ Matrix client well-known at "{well_known_uri}" is valid"# + ); + } else { + warn!( + r#"⚠️ Matrix client well-known has an "org.matrix.msc2965.authentication" section, but the issuer is not the same as the homeserver. Check the well-known document at "{well_known_uri}" This can happen because MAS parses the URL its config differently from the homeserver. This means some OIDC-native clients might not work. @@ -116,18 +112,18 @@ And in the Synapse config: See {DOCS_BASE}/setup/homeserver.html "# - ); - } - } else { - error!( - r#"❌ Matrix client well-known "org.matrix.msc2965.authentication" does not have a valid "issuer" field. + ); + } + } else { + error!( + r#"❌ Matrix client well-known "org.matrix.msc2965.authentication" does not have a valid "issuer" field. Check the well-known document at "{well_known_uri}" "# - ); - } - } else { - warn!( - r#"Matrix client well-known is missing the "org.matrix.msc2965.authentication" section. + ); + } + } else { + warn!( + r#"Matrix client well-known is missing the "org.matrix.msc2965.authentication" section. Check the well-known document at "{well_known_uri}" Make sure Synapse has delegated auth enabled: @@ -143,14 +139,29 @@ If it is not Synapse handling the well-known document, update it to include the See {DOCS_BASE}/setup/homeserver.html "# - ); - } + ); + } + // Return the discovered homeserver base URL + body.get("m.homeserver") + .and_then(|hs| hs.get("base_url")) + .and_then(|base_url| base_url.as_str()) + .and_then(|base_url| Url::parse(base_url).ok()) + } + Err(e) => { + warn!( + r#"⚠️ Invalid JSON for the well-known document at "{well_known_uri}". +Make sure going to {well_known_uri:?} in a web browser returns a valid JSON document, similar to: + +{expected_well_known:#} + +See {DOCS_BASE}/setup/homeserver.html - // Return the discovered homeserver base URL - body.get("m.homeserver") - .and_then(|hs| hs.get("base_url")) - .and_then(|base_url| base_url.as_str()) - .and_then(|base_url| Url::parse(base_url).ok()) +Error details: {e} +"# + ); + None + } + } } Err(e) => { warn!( @@ -172,10 +183,10 @@ Error details: {e} // Now try to reach the homeserver let client_versions = hs_api.join("/_matrix/client/versions")?; - let request = hyper::Request::builder() - .uri(client_versions.as_str()) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client + .get(client_versions.as_str()) + .send_traced() + .await; let can_reach_cs = match result { Ok(response) => { let status = response.status(); @@ -222,18 +233,15 @@ Error details: {e} // Try the whoami API. If it replies with `M_UNKNOWN` this is because Synapse // couldn't reach MAS let whoami = hs_api.join("/_matrix/client/v3/account/whoami")?; - let request = hyper::Request::builder() - .header( - "Authorization", - "Bearer averyinvalidtokenireallyhopethisisnotvalid", - ) - .uri(whoami.as_str()) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client + .get(whoami.as_str()) + .bearer_auth("averyinvalidtokenireallyhopethisisnotvalid") + .send_traced() + .await; match result { Ok(response) => { - let (parts, body) = response.into_parts(); - let status = parts.status; + let status = response.status(); + let body = response.text().await.unwrap_or("???".into()); match status.as_u16() { 401 => info!( @@ -276,10 +284,7 @@ Error details: {e} // Try to reach the admin API on an unauthorized endpoint let server_version = hs_api.join("/_synapse/admin/v1/server_version")?; - let request = hyper::Request::builder() - .uri(server_version.as_str()) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client.get(server_version.as_str()).send_traced().await; match result { Ok(response) => { let status = response.status(); @@ -304,11 +309,11 @@ Error details: {e} // Try to reach an authenticated admin API endpoint let background_updates = hs_api.join("/_synapse/admin/v1/background_updates/status")?; - let request = hyper::Request::builder() - .uri(background_updates.as_str()) - .header("Authorization", format!("Bearer {admin_token}")) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client + .get(background_updates.as_str()) + .bearer_auth(&admin_token) + .send_traced() + .await; match result { Ok(response) => { let status = response.status(); @@ -353,17 +358,17 @@ Error details: {e} // Try to reach the legacy login API let compat_login = external_cs_api_endpoint.join("/_matrix/client/v3/login")?; let compat_login = compat_login.as_str(); - let request = hyper::Request::builder() - .uri(compat_login) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client.get(compat_login).send_traced().await; match result { Ok(response) => { let status = response.status(); if status.is_success() { // Now we need to inspect the body to figure out whether it's Synapse or MAS // which handled the request - let body = response.into_body(); + let body = response + .json::() + .await + .unwrap_or_default(); let flows = body .get("flows") .and_then(|flows| flows.as_array()) diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 07d4fa601..d8b55777b 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -16,7 +16,6 @@ use mas_config::{ }; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; use mas_email::Address; -use mas_handlers::HttpClientFactory; use mas_matrix::HomeserverConnection; use mas_matrix_synapse::SynapseConnection; use mas_storage::{ @@ -512,7 +511,7 @@ impl Options { yes, ignore_password_complexity, } => { - let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); let password_config = PasswordsConfig::extract_or_default(figment)?; let database_config = DatabaseConfig::extract_or_default(figment)?; let matrix_config = MatrixConfig::extract(figment)?; @@ -522,7 +521,7 @@ impl Options { matrix_config.homeserver, matrix_config.endpoint, matrix_config.secret, - http_client_factory, + http_client, ); let mut conn = database_connection_from_config(&database_config).await?; let txn = conn.begin().await?; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 8386b4689..0558453d8 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -13,7 +13,7 @@ use itertools::Itertools; use mas_config::{ AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config, }; -use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache}; +use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache}; use mas_listener::server::Server; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; @@ -148,13 +148,13 @@ impl Options { let templates = templates_from_config(&config.templates, &site_config, &url_builder).await?; - let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); let homeserver_connection = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), config.matrix.secret.clone(), - http_client_factory.clone(), + http_client.clone(), ); if !self.no_worker { @@ -241,7 +241,7 @@ impl Options { homeserver_connection, policy_factory, graphql_schema, - http_client_factory, + http_client, password_manager, metadata_cache, site_config, diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index 7a797fc16..3bbef12dd 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -9,7 +9,6 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; use mas_config::{AppConfig, ConfigurationSection}; -use mas_handlers::HttpClientFactory; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; use rand::{ @@ -57,12 +56,12 @@ impl Options { let mailer = mailer_from_config(&config.email, &templates)?; mailer.test_connection().await?; - let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); let conn = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), config.matrix.secret.clone(), - http_client_factory, + http_client, ); drop(config); diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index eee5a73a4..220480b93 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -16,17 +16,36 @@ use tracing_subscriber::{ filter::LevelFilter, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer, Registry, }; -use crate::sentry_transport::HyperTransportFactory; - mod app_state; mod commands; -mod sentry_transport; mod server; mod shutdown; mod sync; mod telemetry; mod util; +#[derive(Debug)] +struct SentryTransportFactory { + client: reqwest::Client, +} + +impl SentryTransportFactory { + fn new() -> Self { + Self { + client: mas_http::reqwest_client(), + } + } +} + +impl sentry::TransportFactory for SentryTransportFactory { + fn create_transport(&self, options: &sentry::ClientOptions) -> Arc { + let transport = + sentry::transports::ReqwestHttpTransport::with_client(options, self.client.clone()); + + Arc::new(transport) + } +} + #[tokio::main] async fn main() -> anyhow::Result { // We're splitting the "fallible" part of main in another function to have a @@ -79,9 +98,7 @@ async fn try_main() -> anyhow::Result { let sentry = sentry::init(( telemetry_config.sentry.dsn.as_deref(), sentry::ClientOptions { - transport: Some(Arc::new(HyperTransportFactory::new( - mas_http::make_untraced_client(), - ))), + transport: Some(Arc::new(SentryTransportFactory::new())), traces_sample_rate: 1.0, auto_session_tracking: true, session_mode: sentry::SessionMode::Request, diff --git a/crates/cli/src/sentry_transport/mod.rs b/crates/cli/src/sentry_transport/mod.rs deleted file mode 100644 index 8eb8dad35..000000000 --- a/crates/cli/src/sentry_transport/mod.rs +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Implements a transport for Sentry based on Hyper. -//! -//! This avoids the dependency on `reqwest`, which helps avoiding having two -//! HTTP and TLS stacks in the binary. -//! -//! The [`ratelimit`] and [`tokio_thread`] modules are directly copied from the -//! Sentry codebase. - -use std::{sync::Arc, time::Duration}; - -use bytes::Bytes; -use http_body_util::{BodyExt, Full}; -use hyper::{header::RETRY_AFTER, StatusCode}; -use mas_http::UntracedClient; -use sentry::{sentry_debug, ClientOptions, Transport, TransportFactory}; - -use self::tokio_thread::TransportThread; - -mod ratelimit; -mod tokio_thread; - -pub struct HyperTransport { - thread: TransportThread, -} - -pub struct HyperTransportFactory { - client: UntracedClient>, -} - -impl HyperTransportFactory { - pub fn new(client: UntracedClient>) -> Self { - Self { client } - } -} - -impl TransportFactory for HyperTransportFactory { - fn create_transport(&self, options: &ClientOptions) -> Arc { - Arc::new(HyperTransport::new(options, self.client.clone())) - } -} - -impl HyperTransport { - pub fn new(options: &ClientOptions, client: UntracedClient>) -> Self { - let dsn = options.dsn.as_ref().unwrap(); - let user_agent = options.user_agent.clone(); - let auth = dsn.to_auth(Some(&user_agent)).to_string(); - let url = dsn.envelope_api_url().to_string(); - - let thread = TransportThread::new(move |envelope, mut rl| { - let mut body = Vec::new(); - envelope.to_writer(&mut body).unwrap(); - - let request = hyper::Request::post(&url) - .header("X-Sentry-Auth", &auth) - .body(Full::new(Bytes::from(body))) - .unwrap(); - - let fut = client.request(request); - - async move { - match fut.await { - Ok(response) => { - if let Some(sentry_header) = response - .headers() - .get("x-sentry-rate-limits") - .and_then(|x| x.to_str().ok()) - { - rl.update_from_sentry_header(sentry_header); - } else if let Some(retry_after) = response - .headers() - .get(RETRY_AFTER) - .and_then(|x| x.to_str().ok()) - { - rl.update_from_retry_after(retry_after); - } else if response.status() == StatusCode::TOO_MANY_REQUESTS { - rl.update_from_429(); - } - - match response.into_body().collect().await { - Err(err) => { - sentry_debug!("Failed to read sentry response: {}", err); - } - Ok(body) => { - let body = body.to_bytes(); - let text = String::from_utf8_lossy(&body); - sentry_debug!("Get response: `{}`", text); - } - } - } - Err(err) => { - sentry_debug!("Failed to send envelope: {}", err); - } - } - - rl - } - }); - - Self { thread } - } -} - -impl Transport for HyperTransport { - fn send_envelope(&self, envelope: sentry::Envelope) { - self.thread.send(envelope); - } - - fn flush(&self, timeout: Duration) -> bool { - self.thread.flush(timeout) - } - - fn shutdown(&self, timeout: Duration) -> bool { - self.flush(timeout) - } -} diff --git a/crates/cli/src/sentry_transport/ratelimit.rs b/crates/cli/src/sentry_transport/ratelimit.rs deleted file mode 100644 index e3d0b66e1..000000000 --- a/crates/cli/src/sentry_transport/ratelimit.rs +++ /dev/null @@ -1,180 +0,0 @@ -// Taken from sentry/transports/ratelimit.rs -#![allow(clippy::all, clippy::pedantic)] - -use std::time::{Duration, SystemTime}; - -use httpdate::parse_http_date; -use sentry::{protocol::EnvelopeItem, Envelope}; - -/// A Utility that helps with rate limiting sentry requests. -#[derive(Debug, Default)] -pub struct RateLimiter { - global: Option, - error: Option, - session: Option, - transaction: Option, - attachment: Option, -} - -impl RateLimiter { - /// Create a new RateLimiter. - pub fn new() -> Self { - Self::default() - } - - /// Updates the RateLimiter with information from a `Retry-After` header. - pub fn update_from_retry_after(&mut self, header: &str) { - let new_time = if let Ok(value) = header.parse::() { - SystemTime::now() + Duration::from_secs(value.ceil() as u64) - } else if let Ok(value) = parse_http_date(header) { - value - } else { - SystemTime::now() + Duration::from_secs(60) - }; - - self.global = Some(new_time); - } - - /// Updates the RateLimiter with information from a `X-Sentry-Rate-Limits` - /// header. - pub fn update_from_sentry_header(&mut self, header: &str) { - // = (,)+ - // =