From 366fb1d8e7864fbca7da38b5b135201afcf0871d Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Sun, 20 Oct 2024 22:02:40 +0200 Subject: [PATCH 01/12] WIP: switch to reqwest --- Cargo.lock | 6 + Cargo.toml | 6 + crates/handlers/Cargo.toml | 3 + crates/handlers/src/lib.rs | 1 + crates/handlers/src/test_utils.rs | 10 ++ .../handlers/src/upstream_oauth2/authorize.rs | 6 +- crates/handlers/src/upstream_oauth2/cache.rs | 36 ++--- .../handlers/src/upstream_oauth2/callback.rs | 8 +- crates/http/Cargo.toml | 2 + crates/http/src/lib.rs | 2 + crates/http/src/reqwest.rs | 130 ++++++++++++++++++ crates/oidc-client/Cargo.toml | 1 + crates/oidc-client/src/error.rs | 64 +-------- .../src/requests/authorization_code.rs | 4 +- crates/oidc-client/src/requests/discovery.rs | 36 ++--- crates/oidc-client/src/requests/jose.rs | 20 ++- crates/oidc-client/src/requests/token.rs | 11 +- .../src/types/client_credentials.rs | 5 +- crates/oidc-client/tests/it/main.rs | 12 +- .../tests/it/requests/discovery.rs | 4 +- 20 files changed, 222 insertions(+), 145 deletions(-) create mode 100644 crates/http/src/reqwest.rs diff --git a/Cargo.lock b/Cargo.lock index 114ec20f9..6fa9c412a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3386,6 +3386,7 @@ dependencies = [ "psl", "rand", "rand_chacha", + "reqwest", "rustls", "schemars", "sentry", @@ -3425,8 +3426,10 @@ dependencies = [ "hyper-util", "mas-tower", "opentelemetry", + "opentelemetry-http", "opentelemetry-semantic-conventions", "pin-project-lite", + "reqwest", "rustls", "rustls-platform-verifier", "serde", @@ -3640,6 +3643,7 @@ dependencies = [ "oauth2-types", "rand", "rand_chacha", + "reqwest", "rustls", "serde", "serde_json", @@ -5027,8 +5031,10 @@ checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" dependencies = [ "base64 0.22.1", "bytes", + "encoding_rs", "futures-core", "futures-util", + "h2", "http", "http-body", "http-body-util", diff --git a/Cargo.toml b/Cargo.toml index 3840ab231..09fe978cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -182,6 +182,12 @@ version = "0.3.0" [workspace.dependencies.rand] version = "0.8.5" +# High-level HTTP client +[workspace.dependencies.reqwest] +version = "0.12.8" +default-features = false +features = ["http2", "rustls-tls-manual-roots", "charset", "json"] + # TLS stack [workspace.dependencies.rustls] version = "0.23.15" diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 46a4788c1..d7fb3e850 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -41,6 +41,9 @@ aide.workspace = true async-graphql.workspace = true schemars.workspace = true +# HTTP client +reqwest.workspace = true + # Emails lettre.workspace = true diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 1c90edff8..a7065a832 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -312,6 +312,7 @@ where MetadataCache: FromRef, SiteConfig: FromRef, Limiter: FromRef, + reqwest::Client: FromRef, BoxHomeserverConnection: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, diff --git a/crates/handlers/src/test_utils.rs b/crates/handlers/src/test_utils.rs index e54dd2fc4..729e3d8ea 100644 --- a/crates/handlers/src/test_utils.rs +++ b/crates/handlers/src/test_utils.rs @@ -109,6 +109,7 @@ pub(crate) struct TestState { pub limiter: Limiter, pub clock: Arc, pub rng: Arc>, + pub http_client: reqwest::Client, #[allow(dead_code)] // It is used, as it will cancel the CancellationToken when dropped cancellation_drop_guard: Arc, @@ -169,6 +170,8 @@ impl TestState { ) .await?; + let http_client = mas_http::reqwest_client(); + // TODO: add more test keys to the store let rsa = PrivateKey::load_pem(include_str!("../../keystore/tests/keys/rsa.pkcs1.pem")).unwrap(); @@ -241,6 +244,7 @@ impl TestState { limiter, clock, rng, + http_client, cancellation_drop_guard: Arc::new(shutdown_token.drop_guard()), }) } @@ -494,6 +498,12 @@ impl FromRef for Limiter { } } +impl FromRef for reqwest::Client { + fn from_ref(input: &TestState) -> Self { + input.http_client.clone() + } +} + #[async_trait] impl FromRequestParts for ActivityTracker { type Rejection = Infallible; diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 720004649..21f86a80a 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -62,10 +62,10 @@ impl IntoResponse for RouteError { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(http_client_factory): State, State(metadata_cache): State, mut repo: BoxRepository, State(url_builder): State, + State(http_client): State, cookie_jar: CookieJar, Path(provider_id): Path, Query(query): Query, @@ -77,12 +77,10 @@ pub(crate) async fn get( .filter(UpstreamOAuthProvider::enabled) .ok_or(RouteError::ProviderNotFound)?; - let http_service = http_client_factory.http_service("upstream_oauth2.authorize"); - // First, discover the provider // This is done lazyly according to provider.discovery_mode and the various // endpoint overrides - let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service); + let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_client); lazy_metadata.maybe_discover().await?; let redirect_uri = url_builder.upstream_oauth_callback(provider.id); diff --git a/crates/handlers/src/upstream_oauth2/cache.rs b/crates/handlers/src/upstream_oauth2/cache.rs index d1530fdaf..7eda12192 100644 --- a/crates/handlers/src/upstream_oauth2/cache.rs +++ b/crates/handlers/src/upstream_oauth2/cache.rs @@ -9,7 +9,6 @@ use std::{collections::HashMap, sync::Arc}; use mas_data_model::{ UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode, }; -use mas_http::HttpService; use mas_iana::oauth::PkceCodeChallengeMethod; use mas_oidc_client::error::DiscoveryError; use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess}; @@ -22,7 +21,7 @@ use url::Url; pub struct LazyProviderInfos<'a> { cache: &'a MetadataCache, provider: &'a UpstreamOAuthProvider, - http_service: &'a HttpService, + client: &'a reqwest::Client, loaded_metadata: Option>, } @@ -30,12 +29,12 @@ impl<'a> LazyProviderInfos<'a> { pub fn new( cache: &'a MetadataCache, provider: &'a UpstreamOAuthProvider, - http_service: &'a HttpService, + client: &'a reqwest::Client, ) -> Self { Self { cache, provider, - http_service, + client, loaded_metadata: None, } } @@ -64,7 +63,7 @@ impl<'a> LazyProviderInfos<'a> { let metadata = self .cache - .get(self.http_service, &self.provider.issuer, verify) + .get(self.client, &self.provider.issuer, verify) .await?; self.loaded_metadata = Some(metadata); @@ -155,7 +154,7 @@ impl MetadataCache { #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)] pub async fn warm_up_and_run( &self, - http_service: HttpService, + client: &reqwest::Client, interval: std::time::Duration, repository: &mut R, ) -> Result, R::Error> { @@ -168,18 +167,19 @@ impl MetadataCache { UpstreamOAuthProviderDiscoveryMode::Disabled => continue, }; - if let Err(e) = self.fetch(&http_service, &provider.issuer, verify).await { + if let Err(e) = self.fetch(client, &provider.issuer, verify).await { tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata"); } } // Spawn a background task to refresh the cache regularly let cache = self.clone(); + let client = client.clone(); Ok(tokio::spawn(async move { loop { // Re-fetch the known metadata at the given interval tokio::time::sleep(interval).await; - cache.refresh_all(&http_service).await; + cache.refresh_all(&client).await; } })) } @@ -187,13 +187,12 @@ impl MetadataCache { #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)] async fn fetch( &self, - http_service: &HttpService, + client: &reqwest::Client, issuer: &str, verify: bool, ) -> Result, DiscoveryError> { if verify { - let metadata = - mas_oidc_client::requests::discovery::discover(http_service, issuer).await?; + let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?; let metadata = Arc::new(metadata); self.cache @@ -204,8 +203,7 @@ impl MetadataCache { Ok(metadata) } else { let metadata = - mas_oidc_client::requests::discovery::insecure_discover(http_service, issuer) - .await?; + mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?; let metadata = Arc::new(metadata); self.insecure_cache @@ -221,7 +219,7 @@ impl MetadataCache { #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)] pub async fn get( &self, - http_service: &HttpService, + client: &reqwest::Client, issuer: &str, verify: bool, ) -> Result, DiscoveryError> { @@ -237,12 +235,12 @@ impl MetadataCache { // Drop the cache guard so that we don't deadlock when we try to fetch drop(cache); - let metadata = self.fetch(http_service, issuer, verify).await?; + let metadata = self.fetch(client, issuer, verify).await?; Ok(metadata) } #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)] - async fn refresh_all(&self, http_service: &HttpService) { + async fn refresh_all(&self, client: &reqwest::Client) { // Grab all the keys first to avoid locking the cache for too long let keys: Vec = { let cache = self.cache.read().await; @@ -250,7 +248,7 @@ impl MetadataCache { }; for issuer in keys { - if let Err(e) = self.fetch(http_service, &issuer, true).await { + if let Err(e) = self.fetch(client, &issuer, true).await { tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata"); } } @@ -262,13 +260,14 @@ impl MetadataCache { }; for issuer in keys { - if let Err(e) = self.fetch(http_service, &issuer, false).await { + if let Err(e) = self.fetch(client, &issuer, false).await { tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata"); } } } } +/* TODO: redo those tests #[cfg(test)] mod tests { #![allow(clippy::too_many_lines)] @@ -619,3 +618,4 @@ mod tests { } } } +*/ diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index ce4b89607..f8afc7849 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -134,6 +134,7 @@ pub(crate) async fn get( State(url_builder): State, State(encrypter): State, State(keystore): State, + State(client): State, cookie_jar: CookieJar, Path(provider_id): Path, Query(params): Query, @@ -186,12 +187,11 @@ pub(crate) async fn get( CodeOrError::Code { code } => code, }; - let http_service = http_client_factory.http_service("upstream_oauth2.callback"); - let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service); + let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client); // Fetch the JWKS let jwks = - mas_oidc_client::requests::jose::fetch_jwks(&http_service, lazy_metadata.jwks_uri().await?) + mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?) .await?; // Figure out the client credentials @@ -222,7 +222,7 @@ pub(crate) async fn get( let (response, id_token) = mas_oidc_client::requests::authorization_code::access_token_with_authorization_code( - &http_service, + &client, client_credentials, lazy_metadata.token_endpoint().await?, code, diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index 4480dbb6b..f9967d706 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -23,10 +23,12 @@ hyper.workspace = true hyper-util.workspace = true hyper-rustls = { workspace = true, optional = true } opentelemetry.workspace = true +opentelemetry-http.workspace = true opentelemetry-semantic-conventions.workspace = true rustls = { workspace = true, optional = true } rustls-platform-verifier = { workspace = true, optional = true } pin-project-lite = "0.2.14" +reqwest.workspace = true serde.workspace = true serde_json.workspace = true serde_urlencoded = "0.7.1" diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index afc9d4bcc..64fcf4303 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -13,6 +13,7 @@ mod client; mod ext; mod layers; +mod reqwest; mod service; #[cfg(feature = "client")] @@ -33,6 +34,7 @@ pub use self::{ json_request::{self, JsonRequest, JsonRequestLayer}, json_response::{self, JsonResponse, JsonResponseLayer}, }, + reqwest::{client as reqwest_client, RequestBuilderExt}, service::{BoxCloneSyncService, HttpService}, }; diff --git a/crates/http/src/reqwest.rs b/crates/http/src/reqwest.rs new file mode 100644 index 000000000..9ed10e5f4 --- /dev/null +++ b/crates/http/src/reqwest.rs @@ -0,0 +1,130 @@ +// Copyright 2024 New Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use std::{future::Future, time::Duration}; + +use headers::{ContentLength, HeaderMapExt as _, Host, UserAgent}; +use hyper_util::client::legacy::connect::HttpInfo; +use opentelemetry_http::HeaderInjector; +use opentelemetry_semantic_conventions::{ + attribute::{HTTP_REQUEST_BODY_SIZE, HTTP_RESPONSE_BODY_SIZE}, + trace::{ + CLIENT_ADDRESS, CLIENT_PORT, HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, + NETWORK_TRANSPORT, NETWORK_TYPE, SERVER_ADDRESS, SERVER_PORT, URL_FULL, + USER_AGENT_ORIGINAL, + }, +}; +use tracing::Instrument; +use tracing_opentelemetry::OpenTelemetrySpanExt; + +static USER_AGENT: &str = concat!("matrix-authentication-service/", env!("CARGO_PKG_VERSION")); + +/// Create a new [`reqwest::Client`] with sane parameters +/// +/// # Panics +/// +/// Panics if the client fails to build, which should never happen +#[must_use] +pub fn client() -> reqwest::Client { + // TODO: make the resolver tracing again + // TODO: can/should we limit in-flight requests? + reqwest::Client::builder() + .use_preconfigured_tls(rustls_platform_verifier::tls_config()) + .user_agent(USER_AGENT) + .timeout(Duration::from_secs(60)) + .connect_timeout(Duration::from_secs(30)) + .read_timeout(Duration::from_secs(30)) + .build() + .expect("failed to create HTTP client") +} + +async fn send_traced( + request: reqwest::RequestBuilder, +) -> Result { + // TODO: have in-flight and request metrics + let (client, request) = request.build_split(); + let mut request = request?; + + let headers = request.headers(); + let host = headers.typed_get::().map(tracing::field::display); + let user_agent = headers + .typed_get::() + .map(tracing::field::display); + let content_length = headers.typed_get().map(|ContentLength(len)| len); + + // Create a new span for the request + let span = tracing::info_span!( + "http.client.request", + "otel.kind" = "client", + "otel.status_code" = tracing::field::Empty, + { HTTP_REQUEST_METHOD } = %request.method(), + { URL_FULL } = %request.url(), + { HTTP_RESPONSE_STATUS_CODE } = tracing::field::Empty, + { SERVER_ADDRESS } = host, + { HTTP_REQUEST_BODY_SIZE } = content_length, + { HTTP_RESPONSE_BODY_SIZE } = tracing::field::Empty, + { NETWORK_TRANSPORT } = "tcp", + { NETWORK_TYPE } = tracing::field::Empty, + { SERVER_ADDRESS } = tracing::field::Empty, + { SERVER_PORT } = tracing::field::Empty, + { CLIENT_ADDRESS } = tracing::field::Empty, + { CLIENT_PORT } = tracing::field::Empty, + { USER_AGENT_ORIGINAL } = user_agent, + "rust.error" = tracing::field::Empty, + ); + let _guard = span.enter(); + + // Inject the span context into the request headers + let context = span.context(); + opentelemetry::global::get_text_map_propagator(|propagator| { + let mut injector = HeaderInjector(request.headers_mut()); + propagator.inject_context(&context, &mut injector); + }); + + match client.execute(request).in_current_span().await { + Ok(response) => { + span.record("otel.status_code", "OK"); + span.record(HTTP_RESPONSE_STATUS_CODE, response.status().as_u16()); + + if let Some(ContentLength(content_length)) = response.headers().typed_get() { + span.record(HTTP_RESPONSE_BODY_SIZE, content_length); + } + + if let Some(http_info) = response.extensions().get::() { + let local = http_info.local_addr(); + let remote = http_info.remote_addr(); + + let family = if local.is_ipv4() { "ipv4" } else { "ipv6" }; + span.record(NETWORK_TYPE, family); + span.record(CLIENT_ADDRESS, remote.ip().to_string()); + span.record(CLIENT_PORT, remote.port()); + span.record(SERVER_ADDRESS, local.ip().to_string()); + span.record(SERVER_PORT, local.port()); + } else { + tracing::warn!("No HttpInfo injected in response extensions"); + } + + Ok(response) + } + Err(err) => { + span.record("otel.status_code", "ERROR"); + span.record("rust.error", &err as &dyn std::error::Error); + Err(err) + } + } +} + +/// An extension trait implemented for [`reqwest::RequestBuilder`] to send a +/// request with a tracing span, and span context propagated. +pub trait RequestBuilderExt { + /// Send the request with a tracing span, and span context propagated. + fn send_traced(self) -> impl Future> + Send; +} + +impl RequestBuilderExt for reqwest::RequestBuilder { + fn send_traced(self) -> impl Future> + Send { + send_traced(self) + } +} diff --git a/crates/oidc-client/Cargo.toml b/crates/oidc-client/Cargo.toml index 556fda86a..04d55706a 100644 --- a/crates/oidc-client/Cargo.toml +++ b/crates/oidc-client/Cargo.toml @@ -26,6 +26,7 @@ http.workspace = true language-tags = "0.3.2" mime = "0.3.17" rand.workspace = true +reqwest.workspace = true serde.workspace = true serde_json.workspace = true serde_urlencoded = "0.7.1" diff --git a/crates/oidc-client/src/error.rs b/crates/oidc-client/src/error.rs index a0bedbb53..79260ca8a 100644 --- a/crates/oidc-client/src/error.rs +++ b/crates/oidc-client/src/error.rs @@ -68,57 +68,19 @@ pub enum DiscoveryError { #[error(transparent)] IntoUrl(#[from] url::ParseError), - /// An error occurred building the request. - #[error(transparent)] - IntoHttp(#[from] http::Error), - /// The server returned an HTTP error status code. #[error(transparent)] - Http(#[from] HttpError), - - /// An error occurred deserializing the response. - #[error(transparent)] - FromJson(#[from] serde_json::Error), + Http(#[from] reqwest::Error), /// An error occurred validating the metadata. #[error(transparent)] Validation(#[from] ProviderMetadataVerificationError), - /// An error occurred sending the request. - #[error(transparent)] - Service(BoxError), - /// Discovery is disabled for this provider. #[error("Discovery is disabled for this provider")] Disabled, } -impl From> for DiscoveryError -where - S: Into, -{ - fn from(err: json_response::Error) -> Self { - match err { - json_response::Error::Deserialize { inner } => inner.into(), - json_response::Error::Service { inner } => inner.into(), - } - } -} - -impl From>> for DiscoveryError -where - S: Into, -{ - fn from(err: catch_http_codes::Error>) -> Self { - match err { - catch_http_codes::Error::HttpError { status_code, inner } => { - Self::Http(HttpError::new(status_code, inner)) - } - catch_http_codes::Error::Service { inner } => Self::Service(inner.into()), - } - } -} - /// All possible errors when registering the client. #[derive(Debug, Error)] pub enum RegistrationError { @@ -563,30 +525,10 @@ where /// All possible errors when requesting a JWKS. #[derive(Debug, Error)] +#[error("Failed to fetch JWKS")] pub enum JwksError { - /// An error occurred building the request. - #[error(transparent)] - IntoHttp(#[from] http::Error), - - /// An error occurred deserializing the response. - #[error(transparent)] - Json(#[from] serde_json::Error), - /// An error occurred sending the request. - #[error(transparent)] - Service(BoxError), -} - -impl From> for JwksError -where - S: Into, -{ - fn from(err: json_response::Error) -> Self { - match err { - json_response::Error::Service { inner } => Self::Service(inner.into()), - json_response::Error::Deserialize { inner } => Self::Json(inner), - } - } + Http(#[from] reqwest::Error), } /// All possible errors when verifying a JWT. diff --git a/crates/oidc-client/src/requests/authorization_code.rs b/crates/oidc-client/src/requests/authorization_code.rs index e1bccf0a4..39a2dff4d 100644 --- a/crates/oidc-client/src/requests/authorization_code.rs +++ b/crates/oidc-client/src/requests/authorization_code.rs @@ -502,7 +502,7 @@ pub async fn build_par_authorization_url( #[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all, fields(token_endpoint))] pub async fn access_token_with_authorization_code( - http_service: &HttpService, + http_client: &reqwest::Client, client_credentials: ClientCredentials, token_endpoint: &Url, code: String, @@ -514,7 +514,7 @@ pub async fn access_token_with_authorization_code( tracing::debug!("Exchanging authorization code for access token..."); let token_response = request_access_token( - http_service, + http_client, client_credentials, token_endpoint, AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant { diff --git a/crates/oidc-client/src/requests/discovery.rs b/crates/oidc-client/src/requests/discovery.rs index c8761a371..c6cfe5767 100644 --- a/crates/oidc-client/src/requests/discovery.rs +++ b/crates/oidc-client/src/requests/discovery.rs @@ -8,21 +8,14 @@ //! //! [Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html -use bytes::Bytes; -use mas_http::{CatchHttpCodesLayer, JsonResponseLayer}; use oauth2_types::oidc::{ProviderMetadata, VerifiedProviderMetadata}; -use tower::{Layer, Service, ServiceExt}; use url::Url; -use crate::{ - error::DiscoveryError, - http_service::HttpService, - utils::{http_all_error_status_codes, http_error_mapper}, -}; +use crate::error::DiscoveryError; /// Fetch the provider metadata. async fn discover_inner( - http_service: &HttpService, + client: &reqwest::Client, issuer: Url, ) -> Result { tracing::debug!("Fetching provider metadata..."); @@ -39,18 +32,17 @@ async fn discover_inner( let config_url = config_url.join(".well-known/openid-configuration")?; - let config_req = http::Request::get(config_url.as_str()).body(Bytes::new())?; + let response = client + .get(config_url.as_str()) + .send() + .await? + .error_for_status()? + .json() + .await?; - let service = ( - JsonResponseLayer::::default(), - CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper), - ) - .layer(http_service.clone()); - - let response = service.ready_oneshot().await?.call(config_req).await?; tracing::debug!(?response); - Ok(response.into_body()) + Ok(response) } /// Fetch the provider metadata and validate it. @@ -60,10 +52,10 @@ async fn discover_inner( /// Returns an error if the request fails or if the data is invalid. #[tracing::instrument(skip_all, fields(issuer))] pub async fn discover( - http_service: &HttpService, + client: &reqwest::Client, issuer: &str, ) -> Result { - let provider_metadata = discover_inner(http_service, issuer.parse()?).await?; + let provider_metadata = discover_inner(client, issuer.parse()?).await?; Ok(provider_metadata.validate(issuer)?) } @@ -92,10 +84,10 @@ pub async fn discover( /// [provider metadata]: https://openid.net/specs/openid-connect-discovery-1_0.html #[tracing::instrument(skip_all, fields(issuer))] pub async fn insecure_discover( - http_service: &HttpService, + client: &reqwest::Client, issuer: &str, ) -> Result { - let provider_metadata = discover_inner(http_service, issuer.parse()?).await?; + let provider_metadata = discover_inner(client, issuer.parse()?).await?; Ok(provider_metadata.insecure_verify_metadata()?) } diff --git a/crates/oidc-client/src/requests/jose.rs b/crates/oidc-client/src/requests/jose.rs index cce6e1193..0a901db3b 100644 --- a/crates/oidc-client/src/requests/jose.rs +++ b/crates/oidc-client/src/requests/jose.rs @@ -8,9 +8,7 @@ use std::collections::HashMap; -use bytes::Bytes; use chrono::{DateTime, Utc}; -use mas_http::JsonResponseLayer; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::{ claims::{self, TimeOptions}, @@ -18,12 +16,10 @@ use mas_jose::{ jwt::Jwt, }; use serde_json::Value; -use tower::{Layer, Service, ServiceExt}; use url::Url; use crate::{ error::{IdTokenError, JwksError, JwtVerificationError}, - http_service::HttpService, types::IdToken, }; @@ -40,18 +36,20 @@ use crate::{ /// Returns an error if the request fails or if the data is invalid. #[tracing::instrument(skip_all, fields(jwks_uri))] pub async fn fetch_jwks( - http_service: &HttpService, + client: &reqwest::Client, jwks_uri: &Url, ) -> Result { tracing::debug!("Fetching JWKS..."); - let jwks_request = http::Request::get(jwks_uri.as_str()).body(Bytes::new())?; + let response: PublicJsonWebKeySet = client + .get(jwks_uri.as_str()) + .send() + .await? + .error_for_status()? + .json() + .await?; - let service = JsonResponseLayer::::default().layer(http_service.clone()); - - let response = service.ready_oneshot().await?.call(jwks_request).await?; - - Ok(response.into_body()) + Ok(response) } /// The data required to verify a JWT. diff --git a/crates/oidc-client/src/requests/token.rs b/crates/oidc-client/src/requests/token.rs index 196b2bec5..39c592eea 100644 --- a/crates/oidc-client/src/requests/token.rs +++ b/crates/oidc-client/src/requests/token.rs @@ -42,7 +42,7 @@ use crate::{ /// Returns an error if the request fails or the response is invalid. #[tracing::instrument(skip_all, fields(token_endpoint, request))] pub async fn request_access_token( - http_service: &HttpService, + http_client: &reqwest::Client, client_credentials: ClientCredentials, token_endpoint: &Url, request: AccessTokenRequest, @@ -51,17 +51,10 @@ pub async fn request_access_token( ) -> Result { tracing::debug!(?request, "Requesting access token..."); - let token_request = http::Request::post(token_endpoint.as_str()).body(request)?; + let token_request = http_client.post(token_endpoint.as_str()).form(&request); let token_request = client_credentials.apply_to_request(token_request, now, rng)?; - let service = ( - FormUrlencodedRequestLayer::default(), - JsonResponseLayer::::default(), - CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper), - ) - .layer(http_service.clone()); - let res = service.ready_oneshot().await?.call(token_request).await?; let token_response = res.into_body(); diff --git a/crates/oidc-client/src/types/client_credentials.rs b/crates/oidc-client/src/types/client_credentials.rs index 140fb44cb..09018ff1f 100644 --- a/crates/oidc-client/src/types/client_credentials.rs +++ b/crates/oidc-client/src/types/client_credentials.rs @@ -178,10 +178,11 @@ impl ClientCredentials { /// Apply these `ClientCredentials` to the given request. pub(crate) fn apply_to_request( self, - request: Request, + request: reqwest::RequestBuilder, now: DateTime, rng: &mut impl Rng, - ) -> Result>, CredentialsError> { + ) -> Result { + // TODO: get the form in params, augment it and serialize let credentials = RequestClientCredentials::try_from_credentials(self, now, rng)?; let (parts, body) = request.into_parts(); diff --git a/crates/oidc-client/tests/it/main.rs b/crates/oidc-client/tests/it/main.rs index 1c9db7eaf..53efb261c 100644 --- a/crates/oidc-client/tests/it/main.rs +++ b/crates/oidc-client/tests/it/main.rs @@ -55,20 +55,14 @@ fn now() -> DateTime { Utc::now() } -async fn init_test() -> (HttpService, MockServer, Url) { +async fn init_test() -> (reqwest::Client, MockServer, Url) { let _ = rustls::crypto::aws_lc_rs::default_provider().install_default(); - let http_service = ( - MapErrLayer::new(BoxError::from), - MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)), - BodyToBytesResponseLayer, - ) - .layer(mas_http::make_untraced_client()); - let http_service = BoxCloneSyncService::new(http_service); + let client = mas_http::reqwest_client(); let mock_server = MockServer::start().await; let issuer = Url::parse(&mock_server.uri()).expect("Couldn't parse URL"); - (http_service, mock_server, issuer) + (client, mock_server, issuer) } /// Generate a keystore with a single key for the given algorithm. diff --git a/crates/oidc-client/tests/it/requests/discovery.rs b/crates/oidc-client/tests/it/requests/discovery.rs index 1cefce892..d7cd6cf13 100644 --- a/crates/oidc-client/tests/it/requests/discovery.rs +++ b/crates/oidc-client/tests/it/requests/discovery.rs @@ -44,9 +44,7 @@ async fn pass_discover() { .mount(&mock_server) .await; - let provider_metadata = insecure_discover(&http_service, issuer.as_str()) - .await - .unwrap(); + let provider_metadata = insecure_discover(&client, issuer.as_str()).await.unwrap(); assert_eq!(provider_metadata.issuer(), issuer.as_str()); } From 4c1271a3eacccbbb26e9ac4e8f0e8b38200556c0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 24 Oct 2024 16:41:14 +0200 Subject: [PATCH 02/12] Replace HTTP client in oidc-client with reqwest --- Cargo.lock | 3 +- Cargo.toml | 4 +- clippy.toml | 2 + crates/axum-utils/Cargo.toml | 2 +- crates/cli/Cargo.toml | 5 +- crates/cli/src/app_state.rs | 13 +- crates/cli/src/commands/server.rs | 3 + .../handlers/src/upstream_oauth2/authorize.rs | 4 +- .../handlers/src/upstream_oauth2/callback.rs | 5 +- crates/http/Cargo.toml | 20 +- crates/http/src/layers/mod.rs | 4 +- crates/http/src/lib.rs | 6 +- crates/iana-codegen/Cargo.toml | 4 +- crates/iana-codegen/src/main.rs | 4 + crates/iana-codegen/src/traits.rs | 4 + crates/oidc-client/Cargo.toml | 2 - crates/oidc-client/src/error.rs | 326 +----------------- crates/oidc-client/src/lib.rs | 1 - .../src/requests/account_management.rs | 121 ------- .../src/requests/authorization_code.rs | 121 +------ .../src/requests/client_credentials.rs | 6 +- crates/oidc-client/src/requests/discovery.rs | 3 +- .../oidc-client/src/requests/introspection.rs | 145 -------- crates/oidc-client/src/requests/jose.rs | 3 +- crates/oidc-client/src/requests/mod.rs | 4 - .../oidc-client/src/requests/refresh_token.rs | 5 +- .../oidc-client/src/requests/registration.rs | 91 ----- crates/oidc-client/src/requests/revocation.rs | 82 ----- crates/oidc-client/src/requests/token.rs | 24 +- crates/oidc-client/src/requests/userinfo.rs | 38 +- .../src/types/client_credentials.rs | 316 +++++------------ crates/oidc-client/src/utils/mod.rs | 31 -- crates/oidc-client/tests/it/main.rs | 16 +- .../tests/it/requests/account_management.rs | 127 ------- .../tests/it/requests/authorization_code.rs | 127 +------ .../tests/it/requests/discovery.rs | 8 +- .../tests/it/requests/introspection.rs | 100 ------ crates/oidc-client/tests/it/requests/mod.rs | 4 - .../tests/it/requests/registration.rs | 250 -------------- .../tests/it/requests/revocation.rs | 74 ---- .../tests/it/types/client_credentials.rs | 36 +- 41 files changed, 203 insertions(+), 1941 deletions(-) delete mode 100644 crates/oidc-client/src/requests/account_management.rs delete mode 100644 crates/oidc-client/src/requests/introspection.rs delete mode 100644 crates/oidc-client/src/requests/registration.rs delete mode 100644 crates/oidc-client/src/requests/revocation.rs delete mode 100644 crates/oidc-client/src/utils/mod.rs delete mode 100644 crates/oidc-client/tests/it/requests/account_management.rs delete mode 100644 crates/oidc-client/tests/it/requests/introspection.rs delete mode 100644 crates/oidc-client/tests/it/requests/registration.rs delete mode 100644 crates/oidc-client/tests/it/requests/revocation.rs diff --git a/Cargo.lock b/Cargo.lock index 6fa9c412a..7373d128a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2422,7 +2422,6 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", ] [[package]] @@ -3255,6 +3254,7 @@ dependencies = [ "prometheus", "rand", "rand_chacha", + "reqwest", "rustls", "sentry", "sentry-tower", @@ -5063,7 +5063,6 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", "windows-registry", ] diff --git a/Cargo.toml b/Cargo.toml index 09fe978cf..68a859253 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -273,12 +273,12 @@ features = ["rt"] # Tower services [workspace.dependencies.tower] version = "0.5.1" -features = ["util"] +features = ["util", "limit"] # Tower HTTP layers [workspace.dependencies.tower-http] version = "0.6.1" -features = ["cors", "fs", "add-extension"] +features = ["cors", "fs", "add-extension", "set-header", "follow-redirect"] # Logging and tracing [workspace.dependencies.tracing] diff --git a/clippy.toml b/clippy.toml index a300ae05b..ac0f49bf4 100644 --- a/clippy.toml +++ b/clippy.toml @@ -5,6 +5,8 @@ disallowed-methods = [ { path = "chrono::Utc::now", reason = "source the current time from the clock instead" }, { path = "ulid::Ulid::from_datetime", reason = "use Ulid::from_datetime_with_source instead" }, { path = "ulid::Ulid::new", reason = "use Ulid::from_datetime_with_source instead" }, + { path = "reqwest::Client::new", reason = "use mas_http::reqwest_client instead" }, + { path = "reqwest::RequestBuilder::send", reason = "use send_traced instead" }, ] disallowed-types = [ diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 844724faf..8d16c3f22 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -41,7 +41,7 @@ ulid.workspace = true oauth2-types.workspace = true mas-data-model.workspace = true -mas-http = { workspace = true, features = ["client"] } +mas-http.workspace = true mas-iana.workspace = true mas-jose.workspace = true mas-keystore.workspace = true diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index d5bc9dac8..45dd9f8c6 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -31,6 +31,7 @@ itertools = "0.13.0" listenfd = "1.0.1" rand.workspace = true rand_chacha = "0.3.1" +reqwest.workspace = true rustls.workspace = true serde_json.workspace = true serde_yaml = "0.9.34" @@ -71,8 +72,8 @@ sentry-tower.workspace = true mas-config.workspace = true mas-data-model.workspace = true mas-email.workspace = true -mas-handlers = { workspace = true } -mas-http = { workspace = true, features = ["client"] } +mas-handlers.workspace = true +mas-http.workspace = true mas-i18n.workspace = true mas-iana.workspace = true mas-keystore.workspace = true diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index b688b6fe3..ac511ab01 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -44,6 +44,7 @@ pub struct AppState { pub policy_factory: Arc, pub graphql_schema: GraphQLSchema, pub http_client_factory: HttpClientFactory, + pub http_client: reqwest::Client, pub password_manager: PasswordManager, pub metadata_cache: MetadataCache, pub site_config: SiteConfig, @@ -116,13 +117,9 @@ impl AppState { let mut repo = PgRepository::from_conn(conn); - let http_service = self - .http_client_factory - .http_service("upstream_oauth2.metadata"); - self.metadata_cache .warm_up_and_run( - http_service, + &self.http_client, std::time::Duration::from_secs(60 * 15), &mut repo, ) @@ -179,6 +176,12 @@ impl FromRef for HttpClientFactory { } } +impl FromRef for reqwest::Client { + fn from_ref(input: &AppState) -> Self { + input.http_client.clone() + } +} + impl FromRef for PasswordManager { fn from_ref(input: &AppState) -> Self { input.password_manager.clone() diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 8386b4689..d32b99099 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -150,6 +150,8 @@ impl Options { let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); + let homeserver_connection = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), @@ -242,6 +244,7 @@ impl Options { policy_factory, graphql_schema, http_client_factory, + http_client, password_manager, metadata_cache, site_config, diff --git a/crates/handlers/src/upstream_oauth2/authorize.rs b/crates/handlers/src/upstream_oauth2/authorize.rs index 21f86a80a..61e7e9ac1 100644 --- a/crates/handlers/src/upstream_oauth2/authorize.rs +++ b/crates/handlers/src/upstream_oauth2/authorize.rs @@ -9,9 +9,7 @@ use axum::{ response::{IntoResponse, Redirect}, }; use hyper::StatusCode; -use mas_axum_utils::{ - cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID, -}; +use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID}; use mas_data_model::UpstreamOAuthProvider; use mas_oidc_client::requests::authorization_code::AuthorizationRequestData; use mas_router::UrlBuilder; diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index f8afc7849..6b070e6dd 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -9,9 +9,7 @@ use axum::{ response::IntoResponse, }; use hyper::StatusCode; -use mas_axum_utils::{ - cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID, -}; +use mas_axum_utils::{cookies::CookieJar, sentry::SentryEventID}; use mas_data_model::UpstreamOAuthProvider; use mas_keystore::{Encrypter, Keystore}; use mas_oidc_client::requests::{ @@ -128,7 +126,6 @@ impl IntoResponse for RouteError { pub(crate) async fn get( mut rng: BoxRng, clock: BoxClock, - State(http_client_factory): State, State(metadata_cache): State, mut repo: BoxRepository, State(url_builder): State, diff --git a/crates/http/Cargo.toml b/crates/http/Cargo.toml index f9967d706..75c2b9eb8 100644 --- a/crates/http/Cargo.toml +++ b/crates/http/Cargo.toml @@ -21,12 +21,12 @@ http-body.workspace = true http-body-util.workspace = true hyper.workspace = true hyper-util.workspace = true -hyper-rustls = { workspace = true, optional = true } +hyper-rustls.workspace = true opentelemetry.workspace = true opentelemetry-http.workspace = true opentelemetry-semantic-conventions.workspace = true -rustls = { workspace = true, optional = true } -rustls-platform-verifier = { workspace = true, optional = true } +rustls.workspace = true +rustls-platform-verifier.workspace = true pin-project-lite = "0.2.14" reqwest.workspace = true serde.workspace = true @@ -38,20 +38,8 @@ tower-http.workspace = true tracing.workspace = true tracing-opentelemetry.workspace = true -mas-tower = { workspace = true, optional = true } +mas-tower.workspace = true [dev-dependencies] anyhow.workspace = true tokio.workspace = true - -[features] -client = [ - "dep:mas-tower", - "dep:rustls", - "dep:hyper-rustls", - "dep:rustls-platform-verifier", - "tower/limit", - "tower-http/timeout", - "tower-http/follow-redirect", - "tower-http/set-header", -] diff --git a/crates/http/src/layers/mod.rs b/crates/http/src/layers/mod.rs index cbf67db31..74fbf4fa1 100644 --- a/crates/http/src/layers/mod.rs +++ b/crates/http/src/layers/mod.rs @@ -7,9 +7,7 @@ pub mod body_to_bytes_response; pub mod bytes_to_body_request; pub mod catch_http_codes; +pub mod client; pub mod form_urlencoded_request; pub mod json_request; pub mod json_response; - -#[cfg(feature = "client")] -pub(crate) mod client; diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 64fcf4303..7b7ea661c 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -9,27 +9,23 @@ #![deny(rustdoc::missing_crate_level_docs)] #![allow(clippy::module_name_repetitions)] -#[cfg(feature = "client")] mod client; mod ext; mod layers; mod reqwest; mod service; -#[cfg(feature = "client")] pub use self::{ client::{ make_traced_connector, make_untraced_client, Client, TracedClient, TracedConnector, UntracedClient, UntracedConnector, }, - layers::client::{ClientLayer, ClientService}, -}; -pub use self::{ ext::{set_propagator, CorsLayerExt, ServiceExt as HttpServiceExt}, layers::{ body_to_bytes_response::{self, BodyToBytesResponse, BodyToBytesResponseLayer}, bytes_to_body_request::{self, BytesToBodyRequest, BytesToBodyRequestLayer}, catch_http_codes::{self, CatchHttpCodes, CatchHttpCodesLayer}, + client::{ClientLayer, ClientService}, form_urlencoded_request::{self, FormUrlencodedRequest, FormUrlencodedRequestLayer}, json_request::{self, JsonRequest, JsonRequestLayer}, json_response::{self, JsonResponse, JsonResponseLayer}, diff --git a/crates/iana-codegen/Cargo.toml b/crates/iana-codegen/Cargo.toml index 8f09616b9..6e81d2901 100644 --- a/crates/iana-codegen/Cargo.toml +++ b/crates/iana-codegen/Cargo.toml @@ -18,9 +18,7 @@ camino.workspace = true convert_case = "0.6.0" csv = "1.3.0" futures-util = "0.3.31" -reqwest = { version = "0.12.8", default-features = false, features = [ - "rustls-tls", -] } +reqwest.workspace = true serde.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/crates/iana-codegen/src/main.rs b/crates/iana-codegen/src/main.rs index 253c2ca52..2a779c44d 100644 --- a/crates/iana-codegen/src/main.rs +++ b/crates/iana-codegen/src/main.rs @@ -190,6 +190,10 @@ async fn main() -> anyhow::Result<()> { .pretty() .init(); + #[expect( + clippy::disallowed_methods, + reason = "reqwest::Client::new should be disallowed by clippy, but for the codegen it's fine" + )] let client = Client::new(); let iana_crate_root = Utf8Path::new("crates/iana/"); diff --git a/crates/iana-codegen/src/traits.rs b/crates/iana-codegen/src/traits.rs index 2e6790c3a..39fb01a5c 100644 --- a/crates/iana-codegen/src/traits.rs +++ b/crates/iana-codegen/src/traits.rs @@ -66,6 +66,10 @@ pub trait EnumEntry: DeserializeOwned + Send + Sync { async fn fetch(client: &Client) -> anyhow::Result> { tracing::info!("Fetching CSV"); + #[expect( + clippy::disallowed_methods, + reason = "we don't use send_traced in the codegen" + )] let response = client .get(Self::URL) .header("User-Agent", "mas-iana-codegen/0.1") diff --git a/crates/oidc-client/Cargo.toml b/crates/oidc-client/Cargo.toml index 04d55706a..2dfb80a49 100644 --- a/crates/oidc-client/Cargo.toml +++ b/crates/oidc-client/Cargo.toml @@ -51,5 +51,3 @@ tokio.workspace = true wiremock = "0.6.2" http-body-util.workspace = true rustls.workspace = true - -mas-http = { workspace = true, features = ["client"] } diff --git a/crates/oidc-client/src/error.rs b/crates/oidc-client/src/error.rs index 79260ca8a..247b0aa2b 100644 --- a/crates/oidc-client/src/error.rs +++ b/crates/oidc-client/src/error.rs @@ -6,11 +6,9 @@ //! The error types used in this crate. -use std::{str::Utf8Error, sync::Arc}; - use headers::authorization::InvalidBearerToken; -use http::{header::ToStrError, StatusCode}; -use mas_http::{catch_http_codes, form_urlencoded_request, json_request, json_response}; +use http::StatusCode; +use mas_http::{catch_http_codes, form_urlencoded_request, json_response}; use mas_jose::{ claims::ClaimError, jwa::InvalidAlgorithm, @@ -33,9 +31,6 @@ pub enum Error { /// An error occurred fetching the provider JWKS. Jwks(#[from] JwksError), - /// An error occurred during client registration. - Registration(#[from] RegistrationError), - /// An error occurred building the authorization URL. Authorization(#[from] AuthorizationError), @@ -48,17 +43,11 @@ pub enum Error { /// An error occurred refreshing an access token. TokenRefresh(#[from] TokenRefreshError), - /// An error occurred revoking a token. - TokenRevoke(#[from] TokenRevokeError), - /// An error occurred requesting user info. UserInfo(#[from] UserInfoError), /// An error occurred introspecting a token. Introspection(#[from] IntrospectionError), - - /// An error occurred building the account management URL. - AccountManagement(#[from] AccountManagementError), } /// All possible errors when fetching provider metadata. @@ -81,135 +70,6 @@ pub enum DiscoveryError { Disabled, } -/// All possible errors when registering the client. -#[derive(Debug, Error)] -pub enum RegistrationError { - /// An error occurred building the request. - #[error(transparent)] - IntoHttp(#[from] http::Error), - - /// An error occurred serializing the request or deserializing the response. - #[error(transparent)] - Json(#[from] serde_json::Error), - - /// The server returned an HTTP error status code. - #[error(transparent)] - Http(#[from] HttpError), - - /// No client secret was received although one was expected because of the - /// authentication method. - #[error("missing client secret in response")] - MissingClientSecret, - - /// An error occurred sending the request. - #[error(transparent)] - Service(BoxError), -} - -impl From> for RegistrationError -where - S: Into, -{ - fn from(err: json_request::Error) -> Self { - match err { - json_request::Error::Serialize { inner } => inner.into(), - json_request::Error::Service { inner } => inner.into(), - } - } -} - -impl From> for RegistrationError -where - S: Into, -{ - fn from(err: json_response::Error) -> Self { - match err { - json_response::Error::Deserialize { inner } => inner.into(), - json_response::Error::Service { inner } => inner.into(), - } - } -} - -impl From>> for RegistrationError -where - S: Into, -{ - fn from(err: catch_http_codes::Error>) -> Self { - match err { - catch_http_codes::Error::HttpError { status_code, inner } => { - HttpError::new(status_code, inner).into() - } - catch_http_codes::Error::Service { inner } => Self::Service(inner.into()), - } - } -} - -/// All possible errors when making a pushed authorization request. -#[derive(Debug, Error)] -pub enum PushedAuthorizationError { - /// An error occurred serializing the request. - #[error(transparent)] - UrlEncoded(#[from] serde_urlencoded::ser::Error), - - /// An error occurred building the request. - #[error(transparent)] - IntoHttp(#[from] http::Error), - - /// An error occurred adding the client credentials to the request. - #[error(transparent)] - Credentials(#[from] CredentialsError), - - /// The server returned an HTTP error status code. - #[error(transparent)] - Http(#[from] HttpError), - - /// An error occurred deserializing the response. - #[error(transparent)] - Json(#[from] serde_json::Error), - - /// An error occurred sending the request. - #[error(transparent)] - Service(BoxError), -} - -impl From> for PushedAuthorizationError -where - S: Into, -{ - fn from(err: form_urlencoded_request::Error) -> Self { - match err { - form_urlencoded_request::Error::Serialize { inner } => inner.into(), - form_urlencoded_request::Error::Service { inner } => inner.into(), - } - } -} - -impl From> for PushedAuthorizationError -where - S: Into, -{ - fn from(err: json_response::Error) -> Self { - match err { - json_response::Error::Deserialize { inner } => inner.into(), - json_response::Error::Service { inner } => inner.into(), - } - } -} - -impl From>> for PushedAuthorizationError -where - S: Into, -{ - fn from(err: catch_http_codes::Error>) -> Self { - match err { - catch_http_codes::Error::HttpError { status_code, inner } => { - HttpError::new(status_code, inner).into() - } - catch_http_codes::Error::Service { inner } => Self::Service(inner.into()), - } - } -} - /// All possible errors when authorizing the client. #[derive(Debug, Error)] pub enum AuthorizationError { @@ -220,76 +80,18 @@ pub enum AuthorizationError { /// An error occurred serializing the request. #[error(transparent)] UrlEncoded(#[from] serde_urlencoded::ser::Error), - - /// An error occurred making the PAR request. - #[error(transparent)] - PushedAuthorization(#[from] PushedAuthorizationError), } /// All possible errors when requesting an access token. #[derive(Debug, Error)] pub enum TokenRequestError { - /// An error occurred building the request. + /// The HTTP client returned an error. #[error(transparent)] - IntoHttp(#[from] http::Error), + Http(#[from] reqwest::Error), - /// An error occurred adding the client credentials to the request. + /// Error while injecting the client credentials into the request. #[error(transparent)] Credentials(#[from] CredentialsError), - - /// An error occurred serializing the request. - #[error(transparent)] - UrlEncoded(#[from] serde_urlencoded::ser::Error), - - /// The server returned an HTTP error status code. - #[error(transparent)] - Http(#[from] HttpError), - - /// An error occurred deserializing the response. - #[error(transparent)] - Json(#[from] serde_json::Error), - - /// An error occurred sending the request. - #[error(transparent)] - Service(BoxError), -} - -impl From> for TokenRequestError -where - S: Into, -{ - fn from(err: form_urlencoded_request::Error) -> Self { - match err { - form_urlencoded_request::Error::Serialize { inner } => inner.into(), - form_urlencoded_request::Error::Service { inner } => inner.into(), - } - } -} - -impl From> for TokenRequestError -where - S: Into, -{ - fn from(err: json_response::Error) -> Self { - match err { - json_response::Error::Deserialize { inner } => inner.into(), - json_response::Error::Service { inner } => inner.into(), - } - } -} - -impl From>> for TokenRequestError -where - S: Into, -{ - fn from(err: catch_http_codes::Error>) -> Self { - match err { - catch_http_codes::Error::HttpError { status_code, inner } => { - HttpError::new(status_code, inner).into() - } - catch_http_codes::Error::Service { inner } => Self::Service(inner.into()), - } - } } /// All possible errors when exchanging a code for an access token. @@ -316,95 +118,13 @@ pub enum TokenRefreshError { IdToken(#[from] IdTokenError), } -/// All possible errors when revoking a token. -#[derive(Debug, Error)] -pub enum TokenRevokeError { - /// An error occurred building the request. - #[error(transparent)] - IntoHttp(#[from] http::Error), - - /// An error occurred adding the client credentials to the request. - #[error(transparent)] - Credentials(#[from] CredentialsError), - - /// An error occurred serializing the request. - #[error(transparent)] - UrlEncoded(#[from] serde_urlencoded::ser::Error), - - /// An error occurred deserializing the error response. - #[error(transparent)] - Json(#[from] serde_json::Error), - - /// The server returned an HTTP error status code. - #[error(transparent)] - Http(#[from] HttpError), - - /// An error occurred sending the request. - #[error(transparent)] - Service(BoxError), -} - -impl From> for TokenRevokeError -where - S: Into, -{ - fn from(err: form_urlencoded_request::Error) -> Self { - match err { - form_urlencoded_request::Error::Serialize { inner } => inner.into(), - form_urlencoded_request::Error::Service { inner } => inner.into(), - } - } -} - -impl From>> for TokenRevokeError -where - S: Into, -{ - fn from(err: catch_http_codes::Error>) -> Self { - match err { - catch_http_codes::Error::HttpError { status_code, inner } => { - HttpError::new(status_code, inner).into() - } - catch_http_codes::Error::Service { inner } => Self::Service(inner.into()), - } - } -} - /// All possible errors when requesting user info. #[derive(Debug, Error)] pub enum UserInfoError { - /// An error occurred getting the provider metadata. - #[error(transparent)] - Discovery(#[from] Arc), - - /// The provider doesn't support requesting user info. - #[error("missing UserInfo support")] - MissingUserInfoSupport, - - /// No token is available to get info from. - #[error("missing token")] - MissingToken, - - /// No client metadata is available. - #[error("missing client metadata")] - MissingClientMetadata, - - /// The access token is invalid. - #[error(transparent)] - Token(#[from] InvalidBearerToken), - - /// An error occurred building the request. - #[error(transparent)] - IntoHttp(#[from] http::Error), - /// The content-type header is missing from the response. #[error("missing response content-type")] MissingResponseContentType, - /// The content-type header could not be decoded. - #[error("could not decoded response content-type: {0}")] - DecodeResponseContentType(#[from] ToStrError), - /// The content-type is not valid. #[error("invalid response content-type")] InvalidResponseContentTypeValue, @@ -418,39 +138,13 @@ pub enum UserInfoError { got: String, }, - /// An error occurred reading the response. - #[error(transparent)] - FromUtf8(#[from] Utf8Error), - - /// An error occurred deserializing the JSON or error response. - #[error(transparent)] - Json(#[from] serde_json::Error), - /// An error occurred verifying the Id Token. #[error(transparent)] IdToken(#[from] IdTokenError), - /// The server returned an HTTP error status code. - #[error(transparent)] - Http(#[from] HttpError), - /// An error occurred sending the request. #[error(transparent)] - Service(BoxError), -} - -impl From>> for UserInfoError -where - S: Into, -{ - fn from(err: catch_http_codes::Error>) -> Self { - match err { - catch_http_codes::Error::HttpError { status_code, inner } => { - HttpError::new(status_code, inner).into() - } - catch_http_codes::Error::Service { inner } => Self::Service(inner.into()), - } - } + Http(#[from] reqwest::Error), } /// All possible errors when introspecting a token. @@ -644,11 +338,3 @@ pub enum CredentialsError { #[error(transparent)] Custom(BoxError), } - -/// All errors that can occur when building the account management URL. -#[derive(Debug, Error)] -pub enum AccountManagementError { - /// An error occurred serializing the parameters. - #[error(transparent)] - UrlEncoded(#[from] serde_urlencoded::ser::Error), -} diff --git a/crates/oidc-client/src/lib.rs b/crates/oidc-client/src/lib.rs index f038ff012..8ed31de08 100644 --- a/crates/oidc-client/src/lib.rs +++ b/crates/oidc-client/src/lib.rs @@ -52,7 +52,6 @@ pub mod error; pub mod http_service; pub mod requests; pub mod types; -mod utils; use std::fmt; diff --git a/crates/oidc-client/src/requests/account_management.rs b/crates/oidc-client/src/requests/account_management.rs deleted file mode 100644 index 2a0158c43..000000000 --- a/crates/oidc-client/src/requests/account_management.rs +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Methods related to the account management URL. -//! -//! This is a Matrix extension introduced in [MSC2965](https://github.com/matrix-org/matrix-spec-proposals/pull/2965). - -use serde::Serialize; -use serde_with::skip_serializing_none; -use url::Url; - -use crate::error::AccountManagementError; - -/// An account management action that a user can take, including a device ID for -/// the actions that support it. -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -#[serde(tag = "action")] -#[non_exhaustive] -pub enum AccountManagementActionFull { - /// `org.matrix.profile` - /// - /// The user wishes to view their profile (name, avatar, contact details). - #[serde(rename = "org.matrix.profile")] - Profile, - - /// `org.matrix.sessions_list` - /// - /// The user wishes to view a list of their sessions. - #[serde(rename = "org.matrix.sessions_list")] - SessionsList, - - /// `org.matrix.session_view` - /// - /// The user wishes to view the details of a specific session. - #[serde(rename = "org.matrix.session_view")] - SessionView { - /// The ID of the session to view the details of. - device_id: String, - }, - - /// `org.matrix.session_end` - /// - /// The user wishes to end/log out of a specific session. - #[serde(rename = "org.matrix.session_end")] - SessionEnd { - /// The ID of the session to end. - device_id: String, - }, - - /// `org.matrix.account_deactivate` - /// - /// The user wishes to deactivate their account. - #[serde(rename = "org.matrix.account_deactivate")] - AccountDeactivate, - - /// `org.matrix.cross_signing_reset` - /// - /// The user wishes to reset their cross-signing keys. - #[serde(rename = "org.matrix.cross_signing_reset")] - CrossSigningReset, -} - -#[skip_serializing_none] -#[derive(Debug, Clone, Serialize)] -struct AccountManagementData { - #[serde(flatten)] - action: Option, - id_token_hint: Option, -} - -/// Build the URL for accessing the account management capabilities. -/// -/// # Arguments -/// -/// * `account_management_uri` - The URL to access the issuer's account -/// management capabilities. -/// -/// * `action` - The action that the user wishes to take. -/// -/// * `id_token_hint` - An ID Token that was previously issued to the client, -/// used as a hint for which user is requesting to manage their account. -/// -/// # Returns -/// -/// A URL to be opened in a web browser where the end-user will be able to -/// access the account management capabilities of the issuer. -/// -/// # Errors -/// -/// Returns an error if serializing the URL fails. -pub fn build_account_management_url( - mut account_management_uri: Url, - action: Option, - id_token_hint: Option, -) -> Result { - let data = AccountManagementData { - action, - id_token_hint, - }; - let extra_query = serde_urlencoded::to_string(data)?; - - if !extra_query.is_empty() { - // Add our parameters to the query, because the URL might already have one. - let mut full_query = account_management_uri - .query() - .map(ToOwned::to_owned) - .unwrap_or_default(); - - if !full_query.is_empty() { - full_query.push('&'); - } - full_query.push_str(&extra_query); - - account_management_uri.set_query(Some(&full_query)); - } - - Ok(account_management_uri) -} diff --git a/crates/oidc-client/src/requests/authorization_code.rs b/crates/oidc-client/src/requests/authorization_code.rs index 39a2dff4d..9295efdcd 100644 --- a/crates/oidc-client/src/requests/authorization_code.rs +++ b/crates/oidc-client/src/requests/authorization_code.rs @@ -12,9 +12,7 @@ use std::{collections::HashSet, num::NonZeroU32}; use base64ct::{Base64UrlUnpadded, Encoding}; use chrono::{DateTime, Utc}; -use http::header::CONTENT_TYPE; use language_tags::LanguageTag; -use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer}; use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod}; use mas_jose::claims::{self, TokenHash}; use oauth2_types::{ @@ -22,7 +20,7 @@ use oauth2_types::{ prelude::CodeChallengeMethodExt, requests::{ AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest, - Display, Prompt, PushedAuthorizationResponse, + Display, Prompt, }, scope::Scope, }; @@ -32,22 +30,17 @@ use rand::{ }; use serde::Serialize; use serde_with::skip_serializing_none; -use tower::{Layer, Service, ServiceExt}; use url::Url; use super::jose::JwtVerificationData; use crate::{ - error::{ - AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError, - }, - http_service::HttpService, + error::{AuthorizationError, IdTokenError, TokenAuthorizationCodeError}, requests::{jose::verify_id_token, token::request_access_token}, types::{ client_credentials::ClientCredentials, scope::{ScopeExt, ScopeToken}, IdToken, }, - utils::{http_all_error_status_codes, http_error_mapper}, }; /// The data necessary to build an authorization request. @@ -320,7 +313,6 @@ fn build_authorization_request( /// /// [`VerifiedClientMetadata`]: oauth2_types::registration::VerifiedClientMetadata /// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode -#[allow(clippy::too_many_lines)] pub fn build_authorization_url( authorization_endpoint: Url, authorization_data: AuthorizationRequestData, @@ -353,115 +345,6 @@ pub fn build_authorization_url( Ok((authorization_url, validation_data)) } -/// Make a [Pushed Authorization Request] and build the URL for authenticating -/// at the Authorization endpoint. -/// -/// # Arguments -/// -/// * `http_service` - The service to use for making HTTP requests. -/// -/// * `client_credentials` - The credentials obtained when registering the -/// client. -/// -/// * `par_endpoint` - The URL of the issuer's Pushed Authorization Request -/// endpoint. -/// -/// * `authorization_endpoint` - The URL of the issuer's Authorization endpoint. -/// -/// * `authorization_data` - The data necessary to build the authorization -/// request. -/// -/// * `now` - The current time. -/// -/// * `rng` - A random number generator. -/// -/// # Returns -/// -/// A URL to be opened in a web browser where the end-user will be able to -/// authorize the given scope, and the [`AuthorizationValidationData`] to -/// validate this request. -/// -/// The redirect URI will receive parameters in its query: -/// -/// * A successful response will receive a `code` and a `state`. -/// -/// * If the authorization fails, it should receive an `error` parameter with a -/// [`ClientErrorCode`] and optionally an `error_description`. -/// -/// # Errors -/// -/// Returns an error if the request fails, the response is invalid or building -/// the URL fails. -/// -/// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/ -/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode -#[tracing::instrument(skip_all, fields(par_endpoint))] -pub async fn build_par_authorization_url( - http_service: &HttpService, - client_credentials: ClientCredentials, - par_endpoint: &Url, - authorization_endpoint: Url, - authorization_data: AuthorizationRequestData, - now: DateTime, - rng: &mut impl Rng, -) -> Result<(Url, AuthorizationValidationData), AuthorizationError> { - tracing::debug!( - scope = ?authorization_data.scope, - "Authorizing with a PAR..." - ); - - let client_id = client_credentials.client_id().to_owned(); - - let (authorization_request, validation_data) = - build_authorization_request(authorization_data, rng)?; - - let par_request = http::Request::post(par_endpoint.as_str()) - .header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref()) - .body(authorization_request) - .map_err(PushedAuthorizationError::from)?; - - let par_request = client_credentials - .apply_to_request(par_request, now, rng) - .map_err(PushedAuthorizationError::from)?; - - let service = ( - FormUrlencodedRequestLayer::default(), - JsonResponseLayer::::default(), - CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper), - ) - .layer(http_service.clone()); - - let par_response = service - .ready_oneshot() - .await - .map_err(PushedAuthorizationError::from)? - .call(par_request) - .await - .map_err(PushedAuthorizationError::from)? - .into_body(); - - let authorization_query = serde_urlencoded::to_string([ - ("request_uri", par_response.request_uri.as_str()), - ("client_id", &client_id), - ])?; - - let mut authorization_url = authorization_endpoint; - - // Add our parameters to the query, because the URL might already have one. - let mut full_query = authorization_url - .query() - .map(ToOwned::to_owned) - .unwrap_or_default(); - if !full_query.is_empty() { - full_query.push('&'); - } - full_query.push_str(&authorization_query); - - authorization_url.set_query(Some(&full_query)); - - Ok((authorization_url, validation_data)) -} - /// Exchange an authorization code for an access token. /// /// This should be used as the first step for logging in, and to request a diff --git a/crates/oidc-client/src/requests/client_credentials.rs b/crates/oidc-client/src/requests/client_credentials.rs index f87779682..ace203c5f 100644 --- a/crates/oidc-client/src/requests/client_credentials.rs +++ b/crates/oidc-client/src/requests/client_credentials.rs @@ -17,7 +17,7 @@ use rand::Rng; use url::Url; use crate::{ - error::TokenRequestError, http_service::HttpService, requests::token::request_access_token, + error::TokenRequestError, requests::token::request_access_token, types::client_credentials::ClientCredentials, }; @@ -46,7 +46,7 @@ use crate::{ /// Returns an error if the request fails or the response is invalid. #[tracing::instrument(skip_all, fields(token_endpoint))] pub async fn access_token_with_client_credentials( - http_service: &HttpService, + http_client: &reqwest::Client, client_credentials: ClientCredentials, token_endpoint: &Url, scope: Option, @@ -56,7 +56,7 @@ pub async fn access_token_with_client_credentials( tracing::debug!("Requesting access token with client credentials..."); request_access_token( - http_service, + http_client, client_credentials, token_endpoint, AccessTokenRequest::ClientCredentials(ClientCredentialsGrant { scope }), diff --git a/crates/oidc-client/src/requests/discovery.rs b/crates/oidc-client/src/requests/discovery.rs index c6cfe5767..2e58cf37e 100644 --- a/crates/oidc-client/src/requests/discovery.rs +++ b/crates/oidc-client/src/requests/discovery.rs @@ -8,6 +8,7 @@ //! //! [Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html +use mas_http::RequestBuilderExt; use oauth2_types::oidc::{ProviderMetadata, VerifiedProviderMetadata}; use url::Url; @@ -34,7 +35,7 @@ async fn discover_inner( let response = client .get(config_url.as_str()) - .send() + .send_traced() .await? .error_for_status()? .json() diff --git a/crates/oidc-client/src/requests/introspection.rs b/crates/oidc-client/src/requests/introspection.rs deleted file mode 100644 index e82f2bcb1..000000000 --- a/crates/oidc-client/src/requests/introspection.rs +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Requests for [Token Introspection]. -//! -//! [Token Introspection]: https://www.rfc-editor.org/rfc/rfc7662 - -use chrono::{DateTime, Utc}; -use headers::{Authorization, HeaderMapExt}; -use http::Request; -use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer}; -use mas_iana::oauth::OAuthTokenTypeHint; -use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse}; -use rand::Rng; -use serde::Serialize; -use tower::{Layer, Service, ServiceExt}; -use url::Url; - -use crate::{ - error::IntrospectionError, - http_service::HttpService, - types::client_credentials::{ClientCredentials, RequestWithClientCredentials}, - utils::{http_all_error_status_codes, http_error_mapper}, -}; - -/// The method used to authenticate at the introspection endpoint. -pub enum IntrospectionAuthentication<'a> { - /// Using client authentication. - Credentials(ClientCredentials), - - /// Using a bearer token. - BearerToken(&'a str), -} - -impl<'a> IntrospectionAuthentication<'a> { - /// Constructs an `IntrospectionAuthentication` from the given client - /// credentials. - #[must_use] - pub fn with_client_credentials(credentials: ClientCredentials) -> Self { - Self::Credentials(credentials) - } - - /// Constructs an `IntrospectionAuthentication` from the given bearer token. - #[must_use] - pub fn with_bearer_token(token: &'a str) -> Self { - Self::BearerToken(token) - } - - fn apply_to_request( - self, - request: Request, - now: DateTime, - rng: &mut impl Rng, - ) -> Result>, IntrospectionError> { - let res = match self { - IntrospectionAuthentication::Credentials(client_credentials) => { - client_credentials.apply_to_request(request, now, rng)? - } - IntrospectionAuthentication::BearerToken(access_token) => { - let (mut parts, body) = request.into_parts(); - - parts - .headers - .typed_insert(Authorization::bearer(access_token)?); - - let body = RequestWithClientCredentials { - body, - credentials: None, - }; - - http::Request::from_parts(parts, body) - } - }; - - Ok(res) - } -} - -impl<'a> From for IntrospectionAuthentication<'a> { - fn from(credentials: ClientCredentials) -> Self { - Self::with_client_credentials(credentials) - } -} - -/// Obtain information about a token. -/// -/// # Arguments -/// -/// * `http_service` - The service to use for making HTTP requests. -/// -/// * `authentication` - The method used to authenticate the request. -/// -/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint. -/// -/// * `token` - The token to introspect. -/// -/// * `token_type_hint` - Hint about the type of the token. -/// -/// * `now` - The current time. -/// -/// * `rng` - A random number generator. -/// -/// # Errors -/// -/// Returns an error if the request fails or the response is invalid. -#[tracing::instrument(skip_all, fields(introspection_endpoint))] -pub async fn introspect_token( - http_service: &HttpService, - authentication: IntrospectionAuthentication<'_>, - introspection_endpoint: &Url, - token: String, - token_type_hint: Option, - now: DateTime, - rng: &mut impl Rng, -) -> Result { - tracing::debug!("Introspecting token…"); - - let introspection_request = IntrospectionRequest { - token, - token_type_hint, - }; - let introspection_request = - http::Request::post(introspection_endpoint.as_str()).body(introspection_request)?; - - let introspection_request = authentication.apply_to_request(introspection_request, now, rng)?; - - let service = ( - FormUrlencodedRequestLayer::default(), - JsonResponseLayer::::default(), - CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper), - ) - .layer(http_service.clone()); - - let introspection_response = service - .ready_oneshot() - .await? - .call(introspection_request) - .await? - .into_body(); - - Ok(introspection_response) -} diff --git a/crates/oidc-client/src/requests/jose.rs b/crates/oidc-client/src/requests/jose.rs index 0a901db3b..5f630b607 100644 --- a/crates/oidc-client/src/requests/jose.rs +++ b/crates/oidc-client/src/requests/jose.rs @@ -9,6 +9,7 @@ use std::collections::HashMap; use chrono::{DateTime, Utc}; +use mas_http::RequestBuilderExt; use mas_iana::jose::JsonWebSignatureAlg; use mas_jose::{ claims::{self, TimeOptions}, @@ -43,7 +44,7 @@ pub async fn fetch_jwks( let response: PublicJsonWebKeySet = client .get(jwks_uri.as_str()) - .send() + .send_traced() .await? .error_for_status()? .json() diff --git a/crates/oidc-client/src/requests/mod.rs b/crates/oidc-client/src/requests/mod.rs index 5bc1c71ed..bb1d50020 100644 --- a/crates/oidc-client/src/requests/mod.rs +++ b/crates/oidc-client/src/requests/mod.rs @@ -6,15 +6,11 @@ //! Methods to interact with OpenID Connect and OAuth2.0 endpoints. -pub mod account_management; pub mod authorization_code; pub mod client_credentials; pub mod discovery; -pub mod introspection; pub mod jose; pub mod refresh_token; -pub mod registration; -pub mod revocation; pub mod rp_initiated_logout; pub mod token; pub mod userinfo; diff --git a/crates/oidc-client/src/requests/refresh_token.rs b/crates/oidc-client/src/requests/refresh_token.rs index 036f94a6b..68368cd1a 100644 --- a/crates/oidc-client/src/requests/refresh_token.rs +++ b/crates/oidc-client/src/requests/refresh_token.rs @@ -20,7 +20,6 @@ use url::Url; use super::jose::JwtVerificationData; use crate::{ error::{IdTokenError, TokenRefreshError}, - http_service::HttpService, requests::{jose::verify_id_token, token::request_access_token}, types::{client_credentials::ClientCredentials, IdToken}, }; @@ -68,7 +67,7 @@ use crate::{ #[allow(clippy::too_many_arguments)] #[tracing::instrument(skip_all, fields(token_endpoint))] pub async fn refresh_access_token( - http_service: &HttpService, + http_client: &reqwest::Client, client_credentials: ClientCredentials, token_endpoint: &Url, refresh_token: String, @@ -81,7 +80,7 @@ pub async fn refresh_access_token( tracing::debug!("Refreshing access token…"); let token_response = request_access_token( - http_service, + http_client, client_credentials, token_endpoint, AccessTokenRequest::RefreshToken(RefreshTokenGrant { diff --git a/crates/oidc-client/src/requests/registration.rs b/crates/oidc-client/src/requests/registration.rs deleted file mode 100644 index 3f22f65f6..000000000 --- a/crates/oidc-client/src/requests/registration.rs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Requests for [Dynamic Registration]. -//! -//! [Dynamic Registration]: https://openid.net/specs/openid-connect-registration-1_0.html - -use mas_http::{CatchHttpCodesLayer, JsonRequestLayer, JsonResponseLayer}; -use mas_iana::oauth::OAuthClientAuthenticationMethod; -use oauth2_types::registration::{ClientRegistrationResponse, VerifiedClientMetadata}; -use serde::Serialize; -use serde_with::skip_serializing_none; -use tower::{Layer, Service, ServiceExt}; -use url::Url; - -use crate::{ - error::RegistrationError, - http_service::HttpService, - utils::{http_all_error_status_codes, http_error_mapper}, -}; - -#[skip_serializing_none] -#[derive(Serialize)] -struct RegistrationRequest { - #[serde(flatten)] - client_metadata: VerifiedClientMetadata, - software_statement: Option, -} - -/// Register a client with an OpenID Provider. -/// -/// # Arguments -/// -/// * `http_service` - The service to use for making HTTP requests. -/// -/// * `registration_endpoint` - The URL of the issuer's Registration endpoint. -/// -/// * `client_metadata` - The metadata to register with the issuer. -/// -/// * `software_statement` - A JWT that asserts metadata values about the client -/// software that should be signed. -/// -/// # Errors -/// -/// Returns an error if the request fails or the response is invalid. -#[tracing::instrument(skip_all, fields(registration_endpoint))] -pub async fn register_client( - http_service: &HttpService, - registration_endpoint: &Url, - client_metadata: VerifiedClientMetadata, - software_statement: Option, -) -> Result { - tracing::debug!("Registering client..."); - - let should_receive_secret = matches!( - client_metadata.token_endpoint_auth_method(), - OAuthClientAuthenticationMethod::ClientSecretPost - | OAuthClientAuthenticationMethod::ClientSecretBasic - | OAuthClientAuthenticationMethod::ClientSecretJwt - ); - - let body = RegistrationRequest { - client_metadata, - software_statement, - }; - - let registration_req = http::Request::post(registration_endpoint.as_str()).body(body)?; - - let service = ( - JsonRequestLayer::default(), - JsonResponseLayer::::default(), - CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper), - ) - .layer(http_service.clone()); - - let response = service - .ready_oneshot() - .await? - .call(registration_req) - .await? - .into_body(); - - if should_receive_secret && response.client_secret.is_none() { - return Err(RegistrationError::MissingClientSecret); - } - - Ok(response) -} diff --git a/crates/oidc-client/src/requests/revocation.rs b/crates/oidc-client/src/requests/revocation.rs deleted file mode 100644 index 290e7c711..000000000 --- a/crates/oidc-client/src/requests/revocation.rs +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Requests for [Token Revocation]. -//! -//! [Token Revocation]: https://www.rfc-editor.org/rfc/rfc7009.html - -use chrono::{DateTime, Utc}; -use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer}; -use mas_iana::oauth::OAuthTokenTypeHint; -use oauth2_types::requests::IntrospectionRequest; -use rand::Rng; -use tower::{Layer, Service, ServiceExt}; -use url::Url; - -use crate::{ - error::TokenRevokeError, - http_service::HttpService, - types::client_credentials::ClientCredentials, - utils::{http_all_error_status_codes, http_error_mapper}, -}; - -/// Revoke a token. -/// -/// # Arguments -/// -/// * `http_service` - The service to use for making HTTP requests. -/// -/// * `client_credentials` - The credentials obtained when registering the -/// client. -/// -/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint. -/// -/// * `token` - The token to revoke. -/// -/// * `token_type_hint` - Hint about the type of the token. -/// -/// * `now` - The current time. -/// -/// * `rng` - A random number generator. -/// -/// # Errors -/// -/// Returns an error if the request fails or the response is invalid. -#[tracing::instrument(skip_all, fields(revocation_endpoint))] -pub async fn revoke_token( - http_service: &HttpService, - client_credentials: ClientCredentials, - revocation_endpoint: &Url, - token: String, - token_type_hint: Option, - now: DateTime, - rng: &mut impl Rng, -) -> Result<(), TokenRevokeError> { - tracing::debug!("Revoking token…"); - - let request = IntrospectionRequest { - token, - token_type_hint, - }; - - let revocation_request = http::Request::post(revocation_endpoint.as_str()).body(request)?; - - let revocation_request = client_credentials.apply_to_request(revocation_request, now, rng)?; - - let service = ( - FormUrlencodedRequestLayer::default(), - CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper), - ) - .layer(http_service.clone()); - - service - .ready_oneshot() - .await? - .call(revocation_request) - .await?; - - Ok(()) -} diff --git a/crates/oidc-client/src/requests/token.rs b/crates/oidc-client/src/requests/token.rs index 39c592eea..40832bee0 100644 --- a/crates/oidc-client/src/requests/token.rs +++ b/crates/oidc-client/src/requests/token.rs @@ -7,18 +7,12 @@ //! Requests for the Token endpoint. use chrono::{DateTime, Utc}; -use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer}; +use mas_http::RequestBuilderExt; use oauth2_types::requests::{AccessTokenRequest, AccessTokenResponse}; use rand::Rng; -use tower::{Layer, Service, ServiceExt}; use url::Url; -use crate::{ - error::TokenRequestError, - http_service::HttpService, - types::client_credentials::ClientCredentials, - utils::{http_all_error_status_codes, http_error_mapper}, -}; +use crate::{error::TokenRequestError, types::client_credentials::ClientCredentials}; /// Request an access token. /// @@ -51,13 +45,15 @@ pub async fn request_access_token( ) -> Result { tracing::debug!(?request, "Requesting access token..."); - let token_request = http_client.post(token_endpoint.as_str()).form(&request); + let token_request = http_client.post(token_endpoint.as_str()); - let token_request = client_credentials.apply_to_request(token_request, now, rng)?; - - let res = service.ready_oneshot().await?.call(token_request).await?; - - let token_response = res.into_body(); + let token_response = client_credentials + .authenticated_form(token_request, &request, now, rng)? + .send_traced() + .await? + .error_for_status()? + .json() + .await?; Ok(token_response) } diff --git a/crates/oidc-client/src/requests/userinfo.rs b/crates/oidc-client/src/requests/userinfo.rs index 053c6c99b..9718e537b 100644 --- a/crates/oidc-client/src/requests/userinfo.rs +++ b/crates/oidc-client/src/requests/userinfo.rs @@ -10,23 +10,19 @@ use std::collections::HashMap; -use bytes::Bytes; -use headers::{Authorization, ContentType, HeaderMapExt, HeaderValue}; +use headers::{ContentType, HeaderMapExt, HeaderValue}; use http::header::ACCEPT; -use mas_http::CatchHttpCodesLayer; +use mas_http::RequestBuilderExt; use mas_jose::claims; use mime::Mime; use serde_json::Value; -use tower::{Layer, Service, ServiceExt}; use url::Url; use super::jose::JwtVerificationData; use crate::{ error::{IdTokenError, UserInfoError}, - http_service::HttpService, requests::jose::verify_signed_jwt, types::IdToken, - utils::{http_all_error_status_codes, http_error_mapper}, }; /// Obtain information about an authenticated end-user. @@ -59,7 +55,7 @@ use crate::{ /// [`Claim`]: mas_jose::claims::Claim #[tracing::instrument(skip_all, fields(userinfo_endpoint))] pub async fn fetch_userinfo( - http_service: &HttpService, + http_client: &reqwest::Client, userinfo_endpoint: &Url, access_token: &str, jwt_verification_data: Option>, @@ -67,29 +63,18 @@ pub async fn fetch_userinfo( ) -> Result, UserInfoError> { tracing::debug!("Obtaining user info…"); - let mut userinfo_request = http::Request::get(userinfo_endpoint.as_str()); - let expected_content_type = if jwt_verification_data.is_some() { "application/jwt" } else { mime::APPLICATION_JSON.as_ref() }; - if let Some(headers) = userinfo_request.headers_mut() { - headers.typed_insert(Authorization::bearer(access_token)?); - headers.insert(ACCEPT, HeaderValue::from_static(expected_content_type)); - } - - let userinfo_request = userinfo_request.body(Bytes::new())?; + let userinfo_request = http_client + .get(userinfo_endpoint.as_str()) + .bearer_auth(access_token) + .header(ACCEPT, HeaderValue::from_static(expected_content_type)); - let service = CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper) - .layer(http_service.clone()); - - let userinfo_response = service - .ready_oneshot() - .await? - .call(userinfo_request) - .await?; + let userinfo_response = userinfo_request.send_traced().await?.error_for_status()?; let content_type: Mime = userinfo_response .headers() @@ -105,15 +90,14 @@ pub async fn fetch_userinfo( }); } - let response_body = std::str::from_utf8(userinfo_response.body())?; - let mut claims = if let Some(verification_data) = jwt_verification_data { - verify_signed_jwt(response_body, verification_data) + let response_body = userinfo_response.text().await?; + verify_signed_jwt(&response_body, verification_data) .map_err(IdTokenError::from)? .into_parts() .1 } else { - serde_json::from_str(response_body)? + userinfo_response.json().await? }; let mut auth_claims = auth_id_token.payload().clone(); diff --git a/crates/oidc-client/src/types/client_credentials.rs b/crates/oidc-client/src/types/client_credentials.rs index 09018ff1f..6fd6cb452 100644 --- a/crates/oidc-client/src/types/client_credentials.rs +++ b/crates/oidc-client/src/types/client_credentials.rs @@ -10,8 +10,6 @@ use std::{collections::HashMap, fmt, sync::Arc}; use base64ct::{Base64UrlUnpadded, Encoding}; use chrono::{DateTime, Duration, Utc}; -use headers::{Authorization, HeaderMapExt}; -use http::Request; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; #[cfg(feature = "keystore")] use mas_jose::constraints::Constrainable; @@ -25,7 +23,6 @@ use mas_keystore::Keystore; use rand::Rng; use serde::Serialize; use serde_json::Value; -use serde_with::skip_serializing_none; use tower::BoxError; use url::Url; @@ -175,131 +172,48 @@ impl ClientCredentials { } } - /// Apply these `ClientCredentials` to the given request. - pub(crate) fn apply_to_request( - self, + /// Apply these [`ClientCredentials`] to the given request with the given + /// form. + pub(crate) fn authenticated_form( + &self, request: reqwest::RequestBuilder, + form: &T, now: DateTime, rng: &mut impl Rng, ) -> Result { - // TODO: get the form in params, augment it and serialize - let credentials = RequestClientCredentials::try_from_credentials(self, now, rng)?; - - let (parts, body) = request.into_parts(); - let mut body = RequestWithClientCredentials { - body, - credentials: None, - }; - - let request = match credentials { - RequestClientCredentials::Body(credentials) => { - body.credentials = Some(credentials); - Request::from_parts(parts, body) - } - RequestClientCredentials::Header(credentials) => { - let HeaderClientCredentials { - client_id, - client_secret, - } = credentials; - - let mut request = Request::from_parts(parts, body); - - // Encode the values with `application/x-www-form-urlencoded`. - let client_id = - form_urlencoded::byte_serialize(client_id.as_bytes()).collect::(); - let client_secret = - form_urlencoded::byte_serialize(client_secret.as_bytes()).collect::(); - - let auth = Authorization::basic(&client_id, &client_secret); - request.headers_mut().typed_insert(auth); - - request - } - }; - - Ok(request) - } -} - -impl fmt::Debug for ClientCredentials { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::None { client_id } => f - .debug_struct("None") - .field("client_id", client_id) - .finish(), - Self::ClientSecretBasic { client_id, .. } => f - .debug_struct("ClientSecretBasic") - .field("client_id", client_id) - .finish_non_exhaustive(), - Self::ClientSecretPost { client_id, .. } => f - .debug_struct("ClientSecretPost") - .field("client_id", client_id) - .finish_non_exhaustive(), - Self::ClientSecretJwt { - client_id, - signing_algorithm, - token_endpoint, - .. - } => f - .debug_struct("ClientSecretJwt") - .field("client_id", client_id) - .field("signing_algorithm", signing_algorithm) - .field("token_endpoint", token_endpoint) - .finish_non_exhaustive(), - Self::PrivateKeyJwt { - client_id, - signing_algorithm, - token_endpoint, - .. - } => f - .debug_struct("PrivateKeyJwt") - .field("client_id", client_id) - .field("signing_algorithm", signing_algorithm) - .field("token_endpoint", token_endpoint) - .finish_non_exhaustive(), - } - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] -#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")] -pub(crate) struct JwtBearerClientAssertionType; - -enum RequestClientCredentials { - Body(BodyClientCredentials), - Header(HeaderClientCredentials), -} - -impl RequestClientCredentials { - fn try_from_credentials( - credentials: ClientCredentials, - now: DateTime, - rng: &mut impl Rng, - ) -> Result { - let res = match credentials { - ClientCredentials::None { client_id } => Self::Body(BodyClientCredentials { + let request = match self { + ClientCredentials::None { client_id } => request.form(&RequestWithClientCredentials { + body: form, client_id, client_secret: None, client_assertion: None, client_assertion_type: None, }), + ClientCredentials::ClientSecretBasic { client_id, client_secret, - } => Self::Header(HeaderClientCredentials { - client_id, - client_secret, - }), + } => request.basic_auth(client_id, Some(client_secret)).form( + &RequestWithClientCredentials { + body: form, + client_id, + client_secret: None, + client_assertion: None, + client_assertion_type: None, + }, + ), + ClientCredentials::ClientSecretPost { client_id, client_secret, - } => Self::Body(BodyClientCredentials { + } => request.form(&RequestWithClientCredentials { + body: form, client_id, client_secret: Some(client_secret), client_assertion: None, client_assertion_type: None, }), + ClientCredentials::ClientSecretJwt { client_id, client_secret, @@ -308,18 +222,23 @@ impl RequestClientCredentials { } => { let claims = prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?; - let key = SymmetricKey::new_for_alg(client_secret.into(), &signing_algorithm)?; - let header = JsonWebSignatureHeader::new(signing_algorithm); + let key = SymmetricKey::new_for_alg( + client_secret.as_bytes().to_vec(), + signing_algorithm, + )?; + let header = JsonWebSignatureHeader::new(signing_algorithm.clone()); let jwt = Jwt::sign(header, claims, &key)?; - Self::Body(BodyClientCredentials { + request.form(&RequestWithClientCredentials { + body: form, client_id, client_secret: None, - client_assertion: Some(jwt.to_string()), + client_assertion: Some(jwt.as_str()), client_assertion_type: Some(JwtBearerClientAssertionType), }) } + ClientCredentials::PrivateKeyJwt { client_id, jwt_signing_method, @@ -333,13 +252,13 @@ impl RequestClientCredentials { #[cfg(feature = "keystore")] JwtSigningMethod::Keystore(keystore) => { let key = keystore - .signing_key_for_algorithm(&signing_algorithm) + .signing_key_for_algorithm(signing_algorithm) .ok_or(CredentialsError::NoPrivateKeyFound)?; let signer = key .params() - .signing_key_for_alg(&signing_algorithm) + .signing_key_for_alg(signing_algorithm) .map_err(|_| CredentialsError::JwtWrongAlgorithm)?; - let mut header = JsonWebSignatureHeader::new(signing_algorithm); + let mut header = JsonWebSignatureHeader::new(signing_algorithm.clone()); if let Some(kid) = key.kid() { header = header.with_kid(kid); @@ -348,39 +267,69 @@ impl RequestClientCredentials { Jwt::sign(header, claims, &signer)?.to_string() } JwtSigningMethod::Custom(jwt_signing_fn) => { - jwt_signing_fn(claims, signing_algorithm) + jwt_signing_fn(claims, signing_algorithm.clone()) .map_err(CredentialsError::Custom)? } }; - Self::Body(BodyClientCredentials { + request.form(&RequestWithClientCredentials { + body: form, client_id, client_secret: None, - client_assertion: Some(client_assertion), + client_assertion: Some(&client_assertion), client_assertion_type: Some(JwtBearerClientAssertionType), }) } }; - Ok(res) + Ok(request) } } -#[allow(clippy::struct_field_names)] // All the fields start with `client_` -#[skip_serializing_none] -#[derive(Debug, Clone, PartialEq, Eq, Serialize)] -pub(crate) struct BodyClientCredentials { - client_id: String, - client_secret: Option, - client_assertion: Option, - client_assertion_type: Option, +impl fmt::Debug for ClientCredentials { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::None { client_id } => f + .debug_struct("None") + .field("client_id", client_id) + .finish(), + Self::ClientSecretBasic { client_id, .. } => f + .debug_struct("ClientSecretBasic") + .field("client_id", client_id) + .finish_non_exhaustive(), + Self::ClientSecretPost { client_id, .. } => f + .debug_struct("ClientSecretPost") + .field("client_id", client_id) + .finish_non_exhaustive(), + Self::ClientSecretJwt { + client_id, + signing_algorithm, + token_endpoint, + .. + } => f + .debug_struct("ClientSecretJwt") + .field("client_id", client_id) + .field("signing_algorithm", signing_algorithm) + .field("token_endpoint", token_endpoint) + .finish_non_exhaustive(), + Self::PrivateKeyJwt { + client_id, + signing_algorithm, + token_endpoint, + .. + } => f + .debug_struct("PrivateKeyJwt") + .field("client_id", client_id) + .field("signing_algorithm", signing_algorithm) + .field("token_endpoint", token_endpoint) + .finish_non_exhaustive(), + } + } } -#[derive(Debug, Clone)] -struct HeaderClientCredentials { - client_id: String, - client_secret: String, -} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")] +struct JwtBearerClientAssertionType; fn prepare_claims( iss: String, @@ -409,14 +358,20 @@ fn prepare_claims( /// A request with client credentials added to it. #[derive(Clone, Serialize)] -#[skip_serializing_none] -pub struct RequestWithClientCredentials { +struct RequestWithClientCredentials<'a, T> { #[serde(flatten)] - pub(crate) body: T, - #[serde(flatten)] - pub(crate) credentials: Option, + body: T, + + client_id: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + client_secret: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + client_assertion: Option<&'a str>, + #[serde(skip_serializing_if = "Option::is_none")] + client_assertion_type: Option, } +/* #[cfg(test)] mod test { use assert_matches::assert_matches; @@ -442,91 +397,6 @@ mod test { Utc::now() } - #[test] - fn serialize_credentials() { - assert_eq!( - serde_urlencoded::to_string(BodyClientCredentials { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_assertion: None, - client_assertion_type: None, - }) - .unwrap(), - "client_id=abcd%24%2B%2B" - ); - assert_eq!( - serde_urlencoded::to_string(BodyClientCredentials { - client_id: CLIENT_ID.to_owned(), - client_secret: Some(CLIENT_SECRET.to_owned()), - client_assertion: None, - client_assertion_type: None, - }) - .unwrap(), - "client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F" - ); - assert_eq!( - serde_urlencoded::to_string(BodyClientCredentials { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_assertion: Some(CLIENT_SECRET.to_owned()), - client_assertion_type: Some(JwtBearerClientAssertionType) - }) - .unwrap(), - "client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" - ); - } - - #[test] - fn serialize_request_with_credentials() { - let req = RequestWithClientCredentials { - body: Body { body: REQUEST_BODY }, - credentials: None, - }; - assert_eq!(serde_urlencoded::to_string(req).unwrap(), "body=some_body"); - - let req = RequestWithClientCredentials { - body: Body { body: REQUEST_BODY }, - credentials: Some(BodyClientCredentials { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_assertion: None, - client_assertion_type: None, - }), - }; - assert_eq!( - serde_urlencoded::to_string(req).unwrap(), - "body=some_body&client_id=abcd%24%2B%2B" - ); - - let req = RequestWithClientCredentials { - body: Body { body: REQUEST_BODY }, - credentials: Some(BodyClientCredentials { - client_id: CLIENT_ID.to_owned(), - client_secret: Some(CLIENT_SECRET.to_owned()), - client_assertion: None, - client_assertion_type: None, - }), - }; - assert_eq!( - serde_urlencoded::to_string(req).unwrap(), - "body=some_body&client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F" - ); - - let req = RequestWithClientCredentials { - body: Body { body: REQUEST_BODY }, - credentials: Some(BodyClientCredentials { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_assertion: Some(CLIENT_SECRET.to_owned()), - client_assertion_type: Some(JwtBearerClientAssertionType), - }), - }; - assert_eq!( - serde_urlencoded::to_string(req).unwrap(), - "body=some_body&client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer" - ); - } - #[tokio::test] async fn build_request_none() { let credentials = ClientCredentials::None { @@ -677,3 +547,5 @@ mod test { credentials.client_assertion_type.unwrap(); } } + +*/ diff --git a/crates/oidc-client/src/utils/mod.rs b/crates/oidc-client/src/utils/mod.rs deleted file mode 100644 index e2b701e1f..000000000 --- a/crates/oidc-client/src/utils/mod.rs +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::ops::RangeBounds; - -use bytes::Buf; -use http::{Response, StatusCode}; - -use crate::error::ErrorBody; - -pub fn http_error_mapper(response: Response) -> Option -where - T: Buf, -{ - let body = response.into_body(); - serde_json::from_reader(body.reader()).ok() -} - -pub fn http_all_error_status_codes() -> impl RangeBounds { - let Ok(client_errors_start_code) = StatusCode::from_u16(400) else { - unreachable!() - }; - let Ok(server_errors_end_code) = StatusCode::from_u16(599) else { - unreachable!() - }; - - client_errors_start_code..=server_errors_end_code -} diff --git a/crates/oidc-client/tests/it/main.rs b/crates/oidc-client/tests/it/main.rs index 53efb261c..dbc137d70 100644 --- a/crates/oidc-client/tests/it/main.rs +++ b/crates/oidc-client/tests/it/main.rs @@ -7,8 +7,6 @@ use std::collections::HashMap; use chrono::{DateTime, Duration, Utc}; -use http_body_util::Full; -use mas_http::{BodyToBytesResponseLayer, BoxCloneSyncService}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; use mas_jose::{ claims::{self, hash_token}, @@ -17,21 +15,14 @@ use mas_jose::{ jwt::{JsonWebSignatureHeader, Jwt}, }; use mas_keystore::{JsonWebKey, JsonWebKeySet, Keystore, PrivateKey}; -use mas_oidc_client::{ - http_service::HttpService, - types::{ - client_credentials::{ClientCredentials, JwtSigningFn, JwtSigningMethod}, - IdToken, - }, +use mas_oidc_client::types::{ + client_credentials::{ClientCredentials, JwtSigningFn, JwtSigningMethod}, + IdToken, }; use rand::{ distributions::{Alphanumeric, DistString}, SeedableRng, }; -use tower::{ - util::{MapErrLayer, MapRequestLayer}, - BoxError, Layer, -}; use url::Url; use wiremock::MockServer; @@ -41,7 +32,6 @@ mod types; const REDIRECT_URI: &str = "http://localhost/"; const CLIENT_ID: &str = "client!+ID"; const CLIENT_SECRET: &str = "SECRET?%Gclient"; -const REQUEST_URI: &str = "REQUESTur1"; const AUTHORIZATION_CODE: &str = "authC0D3"; const CODE_VERIFIER: &str = "cODEv3R1f1ER"; const NONCE: &str = "No0o0o0once"; diff --git a/crates/oidc-client/tests/it/requests/account_management.rs b/crates/oidc-client/tests/it/requests/account_management.rs deleted file mode 100644 index 83eb78bb8..000000000 --- a/crates/oidc-client/tests/it/requests/account_management.rs +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::collections::HashMap; - -use mas_oidc_client::requests::account_management::{ - build_account_management_url, AccountManagementActionFull, -}; -use url::Url; - -#[test] -fn build_url() { - let account_management_uri = Url::parse("http://localhost/account_management/").unwrap(); - - // No params - let url = build_account_management_url(account_management_uri.clone(), None, None).unwrap(); - - assert_eq!(url.query(), None); - - // Action without device ID. - let url = build_account_management_url( - account_management_uri.clone(), - Some(AccountManagementActionFull::Profile), - None, - ) - .unwrap(); - - let query_pairs = url.query_pairs().collect::>(); - assert_eq!(query_pairs.len(), 1); - assert_eq!(query_pairs.get("action").unwrap(), "org.matrix.profile"); - - // Action with device ID. - let url = build_account_management_url( - account_management_uri.clone(), - Some(AccountManagementActionFull::SessionEnd { - device_id: "mydevice".to_owned(), - }), - None, - ) - .unwrap(); - - let query_pairs = url.query_pairs().collect::>(); - assert_eq!(query_pairs.len(), 2); - assert_eq!(query_pairs.get("action").unwrap(), "org.matrix.session_end"); - assert_eq!(query_pairs.get("device_id").unwrap(), "mydevice"); - - // ID Token hint. - let url = build_account_management_url( - account_management_uri.clone(), - None, - Some("anidtokenthat.might.looksomethinglikethis".to_owned()), - ) - .unwrap(); - - let query_pairs = url.query_pairs().collect::>(); - assert_eq!(query_pairs.len(), 1); - assert_eq!( - query_pairs.get("id_token_hint").unwrap(), - "anidtokenthat.might.looksomethinglikethis" - ); - - // Action without device ID and ID Token hint. - let url = build_account_management_url( - account_management_uri.clone(), - Some(AccountManagementActionFull::AccountDeactivate), - Some("anotheridtokenthat.might.looksomethinglikethis".to_owned()), - ) - .unwrap(); - - let query_pairs = url.query_pairs().collect::>(); - assert_eq!(query_pairs.len(), 2); - assert_eq!( - query_pairs.get("action").unwrap(), - "org.matrix.account_deactivate" - ); - assert_eq!( - query_pairs.get("id_token_hint").unwrap(), - "anotheridtokenthat.might.looksomethinglikethis" - ); - - // Action with device ID and ID Token hint. - let url = build_account_management_url( - account_management_uri, - Some(AccountManagementActionFull::SessionView { - device_id: "myseconddevice".to_owned(), - }), - Some("athirdidtokenthat.might.looksomethinglikethis".to_owned()), - ) - .unwrap(); - - let query_pairs = url.query_pairs().collect::>(); - assert_eq!(query_pairs.len(), 3); - assert_eq!( - query_pairs.get("action").unwrap(), - "org.matrix.session_view" - ); - assert_eq!(query_pairs.get("device_id").unwrap(), "myseconddevice"); - assert_eq!( - query_pairs.get("id_token_hint").unwrap(), - "athirdidtokenthat.might.looksomethinglikethis" - ); - - // Account management URI with a query already. - let account_management_uri_with_query = - Url::parse("http://localhost/account_management?param=value").unwrap(); - - let url = build_account_management_url( - account_management_uri_with_query, - Some(AccountManagementActionFull::SessionsList), - Some("afinalidtokenthat.might.looksomethinglikethis".to_owned()), - ) - .unwrap(); - - let query_pairs = url.query_pairs().collect::>(); - assert_eq!(query_pairs.len(), 3); - assert_eq!( - query_pairs.get("action").unwrap(), - "org.matrix.sessions_list" - ); - assert_eq!( - query_pairs.get("id_token_hint").unwrap(), - "afinalidtokenthat.might.looksomethinglikethis" - ); -} diff --git a/crates/oidc-client/tests/it/requests/authorization_code.rs b/crates/oidc-client/tests/it/requests/authorization_code.rs index a22d9f2bc..e212b66be 100644 --- a/crates/oidc-client/tests/it/requests/authorization_code.rs +++ b/crates/oidc-client/tests/it/requests/authorization_code.rs @@ -4,34 +4,26 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -use std::{ - collections::HashMap, - num::NonZeroU32, - sync::{Arc, Mutex}, -}; +use std::{collections::HashMap, num::NonZeroU32}; use assert_matches::assert_matches; -use chrono::Duration; use mas_iana::oauth::{ OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod, }; use mas_jose::{claims::ClaimError, jwk::PublicJsonWebKeySet}; use mas_oidc_client::{ - error::{ - AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError, - }, + error::{IdTokenError, TokenAuthorizationCodeError}, requests::{ authorization_code::{ access_token_with_authorization_code, build_authorization_url, - build_par_authorization_url, AuthorizationRequestData, AuthorizationValidationData, + AuthorizationRequestData, AuthorizationValidationData, }, jose::JwtVerificationData, }, types::scope::{ScopeExt, ScopeToken}, }; -use oauth2_types::requests::{AccessTokenResponse, Display, Prompt, PushedAuthorizationResponse}; +use oauth2_types::requests::{AccessTokenResponse, Display, Prompt}; use rand::SeedableRng; -use tokio::sync::oneshot; use url::Url; use wiremock::{ matchers::{method, path}, @@ -40,7 +32,7 @@ use wiremock::{ use crate::{ client_credentials, id_token, init_test, now, ACCESS_TOKEN, AUTHORIZATION_CODE, CLIENT_ID, - CODE_VERIFIER, ID_TOKEN_SIGNING_ALG, NONCE, REDIRECT_URI, REQUEST_URI, + CODE_VERIFIER, ID_TOKEN_SIGNING_ALG, NONCE, REDIRECT_URI, }; #[test] @@ -137,115 +129,6 @@ fn pass_full_authorization_url() { assert_eq!(query_pairs.get("code_challenge_method"), None); } -#[tokio::test] -async fn pass_pushed_authorization_request() { - let (http_service, mock_server, issuer) = init_test().await; - let client_credentials = - client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None); - let authorization_endpoint = issuer.join("authorize").unwrap(); - let par_endpoint = issuer.join("par").unwrap(); - let redirect_uri = Url::parse(REDIRECT_URI).unwrap(); - let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); - - let (sender, receiver) = oneshot::channel(); - let sender_mutex = Arc::new(Mutex::new(Some(sender))); - - Mock::given(method("POST")) - .and(path("/par")) - .and(move |req: &Request| { - let body = form_urlencoded::parse(&req.body) - .into_owned() - .collect::>(); - if let Some(sender) = sender_mutex.lock().unwrap().take() { - sender.send(body).unwrap(); - true - } else { - false - } - }) - .respond_with( - ResponseTemplate::new(200).set_body_json(PushedAuthorizationResponse { - request_uri: REQUEST_URI.to_owned(), - expires_in: Duration::microseconds(30 * 1000 * 1000), - }), - ) - .mount(&mock_server) - .await; - - let (url, validation_data) = build_par_authorization_url( - &http_service, - client_credentials, - &par_endpoint, - authorization_endpoint, - AuthorizationRequestData::new( - CLIENT_ID.to_owned(), - [ScopeToken::Openid].into_iter().collect(), - redirect_uri, - ) - .with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]), - now(), - &mut rng, - ) - .await - .unwrap(); - - assert_eq!(validation_data.state, "OrJ8xbWovSpJUTKz"); - assert_eq!( - validation_data.code_challenge_verifier.unwrap(), - "TSgZ_hr3TJPjhq4aDp34K_8ksjLwaa1xDcPiRGBcjhM" - ); - - let request_pairs = receiver.await.unwrap(); - - assert_eq!(url.path(), "/authorize"); - let query_pairs = url.query_pairs().collect::>(); - assert_eq!(query_pairs.get("request_uri").unwrap(), REQUEST_URI,); - assert_eq!(query_pairs.get("client_id").unwrap(), CLIENT_ID); - - assert_eq!(request_pairs.get("scope").unwrap(), "openid"); - assert_eq!(request_pairs.get("response_type").unwrap(), "code"); - assert_eq!(request_pairs.get("client_id").unwrap(), CLIENT_ID); - assert_eq!(request_pairs.get("redirect_uri").unwrap(), REDIRECT_URI); - assert_eq!(*request_pairs.get("state").unwrap(), validation_data.state); - assert_eq!(request_pairs.get("nonce").unwrap(), "ox0PigY5l9xl5uTL"); - let code_challenge = request_pairs.get("code_challenge").unwrap(); - assert!(code_challenge.len() >= 43); - assert_eq!(request_pairs.get("code_challenge_method").unwrap(), "S256"); -} - -#[tokio::test] -async fn fail_pushed_authorization_request_404() { - let (http_service, _, issuer) = init_test().await; - let client_credentials = - client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None); - let authorization_endpoint = issuer.join("authorize").unwrap(); - let par_endpoint = issuer.join("par").unwrap(); - let redirect_uri = Url::parse(REDIRECT_URI).unwrap(); - let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); - - let error = build_par_authorization_url( - &http_service, - client_credentials, - &par_endpoint, - authorization_endpoint, - AuthorizationRequestData::new( - CLIENT_ID.to_owned(), - [ScopeToken::Openid].into_iter().collect(), - redirect_uri, - ) - .with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]), - now(), - &mut rng, - ) - .await - .unwrap_err(); - - assert_matches!( - error, - AuthorizationError::PushedAuthorization(PushedAuthorizationError::Http(_)) - ); -} - /// Check if the given request to the token endpoint is valid. fn is_valid_token_endpoint_request(req: &Request) -> bool { let body = form_urlencoded::parse(&req.body).collect::>(); diff --git a/crates/oidc-client/tests/it/requests/discovery.rs b/crates/oidc-client/tests/it/requests/discovery.rs index d7cd6cf13..9eb012f82 100644 --- a/crates/oidc-client/tests/it/requests/discovery.rs +++ b/crates/oidc-client/tests/it/requests/discovery.rs @@ -36,7 +36,7 @@ fn provider_metadata(issuer: &Url) -> ProviderMetadata { #[tokio::test] async fn pass_discover() { - let (http_service, mock_server, issuer) = init_test().await; + let (http_client, mock_server, issuer) = init_test().await; Mock::given(method("GET")) .and(path("/.well-known/openid-configuration")) @@ -44,7 +44,9 @@ async fn pass_discover() { .mount(&mock_server) .await; - let provider_metadata = insecure_discover(&client, issuer.as_str()).await.unwrap(); + let provider_metadata = insecure_discover(&http_client, issuer.as_str()) + .await + .unwrap(); assert_eq!(provider_metadata.issuer(), issuer.as_str()); } @@ -70,7 +72,7 @@ async fn fail_discover_not_json() { let error = discover(&http_service, issuer.as_str()).await.unwrap_err(); - assert_matches!(error, DiscoveryError::FromJson(_)); + assert_matches!(error, DiscoveryError::Http(_)); } #[tokio::test] diff --git a/crates/oidc-client/tests/it/requests/introspection.rs b/crates/oidc-client/tests/it/requests/introspection.rs deleted file mode 100644 index 8b5ea70fd..000000000 --- a/crates/oidc-client/tests/it/requests/introspection.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::collections::HashMap; - -use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; -use mas_oidc_client::{ - requests::introspection::introspect_token, - types::scope::{ScopeExt, ScopeToken}, -}; -use oauth2_types::requests::IntrospectionResponse; -use rand::SeedableRng; -use wiremock::{ - matchers::{method, path}, - Mock, Request, ResponseTemplate, -}; - -use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, SUBJECT_IDENTIFIER}; - -#[tokio::test] -async fn pass_introspect_token() { - let (http_service, mock_server, issuer) = init_test().await; - let client_credentials = - client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None); - let introspection_endpoint = issuer.join("introspect").unwrap(); - let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); - - Mock::given(method("POST")) - .and(path("/introspect")) - .and(|req: &Request| { - let query_pairs = form_urlencoded::parse(&req.body).collect::>(); - - if query_pairs - .get("token") - .filter(|s| *s == ACCESS_TOKEN) - .is_none() - { - println!("Wrong or missing token"); - return false; - } - if query_pairs - .get("token_type_hint") - .filter(|s| *s == "access_token") - .is_none() - { - println!("Wrong or missing token type hint"); - return false; - } - if query_pairs - .get("client_id") - .filter(|s| *s == CLIENT_ID) - .is_none() - { - println!("Wrong or missing client ID"); - return false; - } - - true - }) - .respond_with( - ResponseTemplate::new(200).set_body_json(IntrospectionResponse { - active: true, - scope: Some([ScopeToken::Profile].into_iter().collect()), - client_id: Some(CLIENT_ID.to_owned()), - username: None, - token_type: Some(OAuthTokenTypeHint::AccessToken), - exp: None, - iat: None, - nbf: None, - sub: Some(SUBJECT_IDENTIFIER.to_owned()), - aud: Some(CLIENT_ID.to_owned()), - iss: Some(issuer.to_string()), - jti: None, - }), - ) - .mount(&mock_server) - .await; - - let response = introspect_token( - &http_service, - client_credentials.into(), - &introspection_endpoint, - ACCESS_TOKEN.to_owned(), - Some(OAuthTokenTypeHint::AccessToken), - now(), - &mut rng, - ) - .await - .unwrap(); - - assert!(response.active); - assert_eq!(response.aud.unwrap(), CLIENT_ID); - assert!(response.scope.unwrap().contains_token(&ScopeToken::Profile)); - assert_eq!(response.client_id.unwrap(), CLIENT_ID); - assert_eq!(response.iss.unwrap(), issuer.as_str()); - assert_eq!(response.sub.unwrap(), SUBJECT_IDENTIFIER); -} diff --git a/crates/oidc-client/tests/it/requests/mod.rs b/crates/oidc-client/tests/it/requests/mod.rs index ffd50286a..e2ff1ee10 100644 --- a/crates/oidc-client/tests/it/requests/mod.rs +++ b/crates/oidc-client/tests/it/requests/mod.rs @@ -4,14 +4,10 @@ // SPDX-License-Identifier: AGPL-3.0-only // Please see LICENSE in the repository root for full details. -mod account_management; mod authorization_code; mod client_credentials; mod discovery; -mod introspection; mod jose; mod refresh_token; -mod registration; -mod revocation; mod rp_initiated_logout; mod userinfo; diff --git a/crates/oidc-client/tests/it/requests/registration.rs b/crates/oidc-client/tests/it/requests/registration.rs deleted file mode 100644 index d15604830..000000000 --- a/crates/oidc-client/tests/it/requests/registration.rs +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use assert_matches::assert_matches; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; -use mas_jose::jwk::PublicJsonWebKeySet; -use mas_oidc_client::{error::RegistrationError, requests::registration::register_client}; -use oauth2_types::{ - oidc::ApplicationType, - registration::{ClientMetadata, ClientRegistrationResponse, VerifiedClientMetadata}, -}; -use serde_json::json; -use url::Url; -use wiremock::{ - matchers::{body_partial_json, method, path}, - Mock, Request, ResponseTemplate, -}; - -use crate::{init_test, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI}; - -/// Generate valid client metadata for the given authentication method. -fn client_metadata(auth_method: OAuthClientAuthenticationMethod) -> VerifiedClientMetadata { - let (signing_alg, jwks) = match &auth_method { - OAuthClientAuthenticationMethod::ClientSecretJwt => { - (Some(JsonWebSignatureAlg::Hs256), None) - } - OAuthClientAuthenticationMethod::PrivateKeyJwt => ( - Some(JsonWebSignatureAlg::Es256), - Some(PublicJsonWebKeySet::default()), - ), - _ => (None, None), - }; - - ClientMetadata { - redirect_uris: Some(vec![Url::parse(REDIRECT_URI).expect("Couldn't parse URL")]), - application_type: Some(ApplicationType::Native), - token_endpoint_auth_method: Some(auth_method), - token_endpoint_auth_signing_alg: signing_alg, - jwks, - ..Default::default() - } - .validate() - .unwrap() -} - -#[tokio::test] -async fn pass_register_client_none() { - let (http_service, mock_server, issuer) = init_test().await; - let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None); - let registration_endpoint = issuer.join("register").unwrap(); - - Mock::given(method("POST")) - .and(path("/register")) - .and(body_partial_json(json!({ - "redirect_uris": [REDIRECT_URI], - "token_endpoint_auth_method": "none", - }))) - .respond_with( - ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_id_issued_at: None, - client_secret_expires_at: None, - }), - ) - .mount(&mock_server) - .await; - - let response = register_client(&http_service, ®istration_endpoint, client_metadata, None) - .await - .unwrap(); - - assert_eq!(response.client_id, CLIENT_ID); - assert_eq!(response.client_secret, None); -} - -#[tokio::test] -async fn pass_register_client_client_secret_basic() { - let (http_service, mock_server, issuer) = init_test().await; - let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic); - let registration_endpoint = issuer.join("register").unwrap(); - - Mock::given(method("POST")) - .and(path("/register")) - .and(body_partial_json(json!({ - "redirect_uris": [REDIRECT_URI], - "token_endpoint_auth_method": "client_secret_basic", - }))) - .respond_with( - ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse { - client_id: CLIENT_ID.to_owned(), - client_secret: Some(CLIENT_SECRET.to_owned()), - client_id_issued_at: None, - client_secret_expires_at: None, - }), - ) - .mount(&mock_server) - .await; - - let response = register_client(&http_service, ®istration_endpoint, client_metadata, None) - .await - .unwrap(); - - assert_eq!(response.client_id, CLIENT_ID); - assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET); -} - -#[tokio::test] -async fn pass_register_client_client_secret_post() { - let (http_service, mock_server, issuer) = init_test().await; - let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretPost); - let registration_endpoint = issuer.join("register").unwrap(); - - Mock::given(method("POST")) - .and(path("/register")) - .and(body_partial_json(json!({ - "redirect_uris": [REDIRECT_URI], - "token_endpoint_auth_method": "client_secret_post", - }))) - .respond_with( - ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse { - client_id: CLIENT_ID.to_owned(), - client_secret: Some(CLIENT_SECRET.to_owned()), - client_id_issued_at: None, - client_secret_expires_at: None, - }), - ) - .mount(&mock_server) - .await; - - let response = register_client(&http_service, ®istration_endpoint, client_metadata, None) - .await - .unwrap(); - - assert_eq!(response.client_id, CLIENT_ID); - assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET); -} - -#[tokio::test] -async fn pass_register_client_client_secret_jwt() { - let (http_service, mock_server, issuer) = init_test().await; - let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretJwt); - let registration_endpoint = issuer.join("register").unwrap(); - - Mock::given(method("POST")) - .and(path("/register")) - .and(body_partial_json(json!({ - "redirect_uris": [REDIRECT_URI], - "token_endpoint_auth_method": "client_secret_jwt", - "token_endpoint_auth_signing_alg": "HS256", - }))) - .respond_with( - ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse { - client_id: CLIENT_ID.to_owned(), - client_secret: Some(CLIENT_SECRET.to_owned()), - client_id_issued_at: None, - client_secret_expires_at: None, - }), - ) - .mount(&mock_server) - .await; - - let response = register_client(&http_service, ®istration_endpoint, client_metadata, None) - .await - .unwrap(); - - assert_eq!(response.client_id, CLIENT_ID); - assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET); -} - -#[tokio::test] -async fn pass_register_client_private_key_jwt() { - let (http_service, mock_server, issuer) = init_test().await; - let client_metadata = client_metadata(OAuthClientAuthenticationMethod::PrivateKeyJwt); - let registration_endpoint = issuer.join("register").unwrap(); - - Mock::given(method("POST")) - .and(path("/register")) - .and(|req: &Request| { - let Ok(metadata) = req.body_json::() else { - return false; - }; - - *metadata.token_endpoint_auth_method() == OAuthClientAuthenticationMethod::PrivateKeyJwt - && metadata.token_endpoint_auth_signing_alg == Some(JsonWebSignatureAlg::Es256) - && metadata.jwks.is_some() - }) - .respond_with( - ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_id_issued_at: None, - client_secret_expires_at: None, - }), - ) - .mount(&mock_server) - .await; - - let response = register_client(&http_service, ®istration_endpoint, client_metadata, None) - .await - .unwrap(); - - assert_eq!(response.client_id, CLIENT_ID); - assert_eq!(response.client_secret, None); -} - -#[tokio::test] -async fn fail_register_client_404() { - let (http_service, _, issuer) = init_test().await; - let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None); - let registration_endpoint = issuer.join("register").unwrap(); - - let error = register_client(&http_service, ®istration_endpoint, client_metadata, None) - .await - .unwrap_err(); - - assert_matches!(error, RegistrationError::Http(_)); -} - -#[tokio::test] -async fn fail_register_client_missing_secret() { - let (http_service, mock_server, issuer) = init_test().await; - let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic); - let registration_endpoint = issuer.join("register").unwrap(); - - Mock::given(method("POST")) - .and(path("/register")) - .and(body_partial_json(json!({ - "token_endpoint_auth_method": "client_secret_basic", - }))) - .respond_with( - ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse { - client_id: CLIENT_ID.to_owned(), - client_secret: None, - client_id_issued_at: None, - client_secret_expires_at: None, - }), - ) - .mount(&mock_server) - .await; - - let error = register_client(&http_service, ®istration_endpoint, client_metadata, None) - .await - .unwrap_err(); - - assert_matches!(error, RegistrationError::MissingClientSecret); -} diff --git a/crates/oidc-client/tests/it/requests/revocation.rs b/crates/oidc-client/tests/it/requests/revocation.rs deleted file mode 100644 index bcea2642f..000000000 --- a/crates/oidc-client/tests/it/requests/revocation.rs +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 Kévin Commaille. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::collections::HashMap; - -use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; -use mas_oidc_client::requests::revocation::revoke_token; -use rand::SeedableRng; -use wiremock::{ - matchers::{method, path}, - Mock, Request, ResponseTemplate, -}; - -use crate::{client_credentials, init_test, ACCESS_TOKEN, CLIENT_ID}; - -#[tokio::test] -async fn pass_revoke_token() { - let (http_service, mock_server, issuer) = init_test().await; - let client_credentials = - client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None); - let revocation_endpoint = issuer.join("revoke").unwrap(); - let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); - - Mock::given(method("POST")) - .and(path("/revoke")) - .and(|req: &Request| { - let query_pairs = form_urlencoded::parse(&req.body).collect::>(); - - if query_pairs - .get("token") - .filter(|s| *s == ACCESS_TOKEN) - .is_none() - { - println!("Wrong or missing refresh token"); - return false; - } - if query_pairs - .get("token_type_hint") - .filter(|s| *s == "access_token") - .is_none() - { - println!("Wrong or missing token type hint"); - return false; - } - if query_pairs - .get("client_id") - .filter(|s| *s == CLIENT_ID) - .is_none() - { - println!("Wrong or missing client ID"); - return false; - } - - true - }) - .respond_with(ResponseTemplate::new(200)) - .mount(&mock_server) - .await; - - revoke_token( - &http_service, - client_credentials, - &revocation_endpoint, - ACCESS_TOKEN.to_owned(), - Some(OAuthTokenTypeHint::AccessToken), - crate::now(), - &mut rng, - ) - .await - .unwrap(); -} diff --git a/crates/oidc-client/tests/it/types/client_credentials.rs b/crates/oidc-client/tests/it/types/client_credentials.rs index d6d8413a1..9e52fd66b 100644 --- a/crates/oidc-client/tests/it/types/client_credentials.rs +++ b/crates/oidc-client/tests/it/types/client_credentials.rs @@ -8,6 +8,7 @@ use std::collections::HashMap; use assert_matches::assert_matches; use base64ct::Encoding; +use http::header::AUTHORIZATION; use mas_iana::oauth::{OAuthAccessTokenType, OAuthClientAuthenticationMethod}; use mas_jose::{ claims::{self, TimeOptions}, @@ -31,7 +32,7 @@ use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, CLIENT_ #[tokio::test] async fn pass_none() { - let (http_service, mock_server, issuer) = init_test().await; + let (http_client, mock_server, issuer) = init_test().await; let client_credentials = client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None); let token_endpoint = issuer.join("token").unwrap(); @@ -67,7 +68,7 @@ async fn pass_none() { .await; access_token_with_client_credentials( - &http_service, + &http_client, client_credentials, &token_endpoint, None, @@ -80,7 +81,7 @@ async fn pass_none() { #[tokio::test] async fn pass_client_secret_basic() { - let (http_service, mock_server, issuer) = init_test().await; + let (http_client, mock_server, issuer) = init_test().await; let client_credentials = client_credentials( &OAuthClientAuthenticationMethod::ClientSecretBasic, &issuer, @@ -94,10 +95,15 @@ async fn pass_client_secret_basic() { let enc_user_pass = base64ct::Base64::encode_string(format!("{username}:{password}").as_bytes()); let authorization_header = format!("Basic {enc_user_pass}"); + eprintln!("{authorization_header}"); Mock::given(method("POST")) .and(path("/token")) - .and(header("authorization", authorization_header.as_str())) + .and(|req: &Request| { + println!("{req:?}"); + true + }) + .and(header(AUTHORIZATION, authorization_header.as_str())) .respond_with( ResponseTemplate::new(200).set_body_json(AccessTokenResponse { access_token: ACCESS_TOKEN.to_owned(), @@ -112,7 +118,7 @@ async fn pass_client_secret_basic() { .await; access_token_with_client_credentials( - &http_service, + &http_client, client_credentials, &token_endpoint, None, @@ -125,7 +131,7 @@ async fn pass_client_secret_basic() { #[tokio::test] async fn pass_client_secret_post() { - let (http_service, mock_server, issuer) = init_test().await; + let (http_client, mock_server, issuer) = init_test().await; let client_credentials = client_credentials( &OAuthClientAuthenticationMethod::ClientSecretPost, &issuer, @@ -172,7 +178,7 @@ async fn pass_client_secret_post() { .await; access_token_with_client_credentials( - &http_service, + &http_client, client_credentials, &token_endpoint, None, @@ -185,7 +191,7 @@ async fn pass_client_secret_post() { #[tokio::test] async fn pass_client_secret_jwt() { - let (http_service, mock_server, issuer) = init_test().await; + let (http_client, mock_server, issuer) = init_test().await; let client_credentials = client_credentials( &OAuthClientAuthenticationMethod::ClientSecretJwt, &issuer, @@ -253,7 +259,7 @@ async fn pass_client_secret_jwt() { .await; access_token_with_client_credentials( - &http_service, + &http_client, client_credentials, &token_endpoint, None, @@ -266,7 +272,7 @@ async fn pass_client_secret_jwt() { #[tokio::test] async fn pass_private_key_jwt_with_keystore() { - let (http_service, mock_server, issuer) = init_test().await; + let (http_client, mock_server, issuer) = init_test().await; let client_credentials = client_credentials( &OAuthClientAuthenticationMethod::PrivateKeyJwt, &issuer, @@ -341,7 +347,7 @@ async fn pass_private_key_jwt_with_keystore() { .await; access_token_with_client_credentials( - &http_service, + &http_client, client_credentials, &token_endpoint, None, @@ -354,7 +360,7 @@ async fn pass_private_key_jwt_with_keystore() { #[tokio::test] async fn pass_private_key_jwt_with_custom_signing() { - let (http_service, mock_server, issuer) = init_test().await; + let (http_client, mock_server, issuer) = init_test().await; let client_credentials = client_credentials( &OAuthClientAuthenticationMethod::PrivateKeyJwt, &issuer, @@ -410,7 +416,7 @@ async fn pass_private_key_jwt_with_custom_signing() { .await; access_token_with_client_credentials( - &http_service, + &http_client, client_credentials, &token_endpoint, None, @@ -423,7 +429,7 @@ async fn pass_private_key_jwt_with_custom_signing() { #[tokio::test] async fn fail_private_key_jwt_with_custom_signing() { - let (http_service, _, issuer) = init_test().await; + let (http_client, _, issuer) = init_test().await; let client_credentials = client_credentials( &OAuthClientAuthenticationMethod::PrivateKeyJwt, &issuer, @@ -433,7 +439,7 @@ async fn fail_private_key_jwt_with_custom_signing() { let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42); let error = access_token_with_client_credentials( - &http_service, + &http_client, client_credentials, &token_endpoint, None, From 420b8da6b057c62cc64c3eb11f3077ee2c3430b1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 24 Oct 2024 18:56:49 +0200 Subject: [PATCH 03/12] Replace all the manual HTTP clients with reqwest --- Cargo.lock | 42 +- Cargo.toml | 8 +- crates/axum-utils/Cargo.toml | 1 + crates/axum-utils/src/client_authorization.rs | 30 +- crates/axum-utils/src/http_client_factory.rs | 65 --- crates/axum-utils/src/lib.rs | 1 - crates/cli/src/app_state.rs | 9 +- crates/cli/src/commands/debug.rs | 84 ---- crates/cli/src/commands/doctor.rs | 139 +++--- crates/cli/src/commands/manage.rs | 5 +- crates/cli/src/commands/server.rs | 7 +- crates/cli/src/commands/worker.rs | 5 +- crates/cli/src/main.rs | 29 +- crates/cli/src/sentry_transport/mod.rs | 121 ------ crates/cli/src/sentry_transport/ratelimit.rs | 180 -------- .../cli/src/sentry_transport/tokio_thread.rs | 119 ----- crates/cli/src/telemetry.rs | 9 +- crates/handlers/src/captcha.rs | 74 ++-- crates/handlers/src/lib.rs | 7 +- .../handlers/src/oauth2/device/authorize.rs | 10 +- crates/handlers/src/oauth2/introspection.rs | 5 +- crates/handlers/src/oauth2/revoke.rs | 5 +- crates/handlers/src/oauth2/token.rs | 5 +- crates/handlers/src/test_utils.rs | 11 - crates/handlers/src/views/register.rs | 5 +- crates/http/Cargo.toml | 24 +- crates/http/src/client.rs | 88 ---- crates/http/src/ext.rs | 73 +--- .../http/src/layers/body_to_bytes_response.rs | 97 ----- .../http/src/layers/bytes_to_body_request.rs | 58 --- crates/http/src/layers/catch_http_codes.rs | 139 ------ crates/http/src/layers/client.rs | 249 ----------- .../src/layers/form_urlencoded_request.rs | 111 ----- crates/http/src/layers/json_request.rs | 111 ----- crates/http/src/layers/json_response.rs | 112 ----- crates/http/src/layers/mod.rs | 13 - crates/http/src/lib.rs | 21 +- crates/http/src/service.rs | 96 ----- crates/http/tests/client_layers.rs | 148 ------- crates/matrix-synapse/Cargo.toml | 2 + crates/matrix-synapse/src/error.rs | 71 +-- crates/matrix-synapse/src/lib.rs | 406 +++++++----------- crates/oidc-client/src/error.rs | 110 +---- crates/oidc-client/src/http_service.rs | 10 - crates/oidc-client/src/lib.rs | 8 +- 45 files changed, 381 insertions(+), 2542 deletions(-) delete mode 100644 crates/axum-utils/src/http_client_factory.rs delete mode 100644 crates/cli/src/sentry_transport/mod.rs delete mode 100644 crates/cli/src/sentry_transport/ratelimit.rs delete mode 100644 crates/cli/src/sentry_transport/tokio_thread.rs delete mode 100644 crates/http/src/client.rs delete mode 100644 crates/http/src/layers/body_to_bytes_response.rs delete mode 100644 crates/http/src/layers/bytes_to_body_request.rs delete mode 100644 crates/http/src/layers/catch_http_codes.rs delete mode 100644 crates/http/src/layers/client.rs delete mode 100644 crates/http/src/layers/form_urlencoded_request.rs delete mode 100644 crates/http/src/layers/json_request.rs delete mode 100644 crates/http/src/layers/json_response.rs delete mode 100644 crates/http/src/layers/mod.rs delete mode 100644 crates/http/src/service.rs delete mode 100644 crates/http/tests/client_layers.rs delete mode 100644 crates/oidc-client/src/http_service.rs diff --git a/Cargo.lock b/Cargo.lock index 7373d128a..b46b75b79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2884,16 +2884,6 @@ dependencies = [ "serde", ] -[[package]] -name = "iri-string" -version = "0.7.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc0f0a572e8ffe56e2ff4f769f32ffe919282c3916799f8b68688b6030063bea" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -3188,6 +3178,7 @@ dependencies = [ "mime", "oauth2-types", "rand", + "reqwest", "sentry", "serde", "serde_json", @@ -3413,31 +3404,14 @@ dependencies = [ name = "mas-http" version = "0.12.0" dependencies = [ - "anyhow", - "async-trait", - "bytes", - "futures-util", "headers", "http", - "http-body", - "http-body-util", - "hyper", - "hyper-rustls", "hyper-util", - "mas-tower", "opentelemetry", "opentelemetry-http", "opentelemetry-semantic-conventions", - "pin-project-lite", "reqwest", - "rustls", "rustls-platform-verifier", - "serde", - "serde_json", - "serde_urlencoded", - "thiserror", - "tokio", - "tower 0.5.1", "tower-http", "tracing", "tracing-opentelemetry", @@ -3612,8 +3586,10 @@ dependencies = [ "mas-axum-utils", "mas-http", "mas-matrix", + "reqwest", "serde", "serde_json", + "thiserror", "tower 0.5.1", "tracing", "url", @@ -4158,11 +4134,8 @@ dependencies = [ "async-trait", "bytes", "http", - "http-body-util", - "hyper", - "hyper-util", "opentelemetry", - "tokio", + "reqwest", ] [[package]] @@ -5032,6 +5005,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2", @@ -5422,12 +5396,15 @@ version = "0.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5484316556650182f03b43d4c746ce0e3e48074a21e2f51244b648b6542e1066" dependencies = [ + "httpdate", + "reqwest", "sentry-backtrace", "sentry-contexts", "sentry-core", "sentry-panic", "sentry-tower", "sentry-tracing", + "tokio", ] [[package]] @@ -6398,7 +6375,6 @@ dependencies = [ "pin-project-lite", "sync_wrapper 0.1.2", "tokio", - "tokio-util", "tower-layer", "tower-service", "tracing", @@ -6418,14 +6394,12 @@ dependencies = [ "http-body-util", "http-range-header", "httpdate", - "iri-string", "mime", "mime_guess", "percent-encoding", "pin-project-lite", "tokio", "tokio-util", - "tower 0.5.1", "tower-layer", "tower-service", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 68a859253..f1177edfc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -221,7 +221,7 @@ features = [ [workspace.dependencies.sentry] version = "0.34.0" default-features = false -features = ["backtrace", "contexts", "panic", "tower"] +features = ["backtrace", "contexts", "panic", "tower", "reqwest"] # Sentry tower layer [workspace.dependencies.sentry-tower] @@ -273,12 +273,12 @@ features = ["rt"] # Tower services [workspace.dependencies.tower] version = "0.5.1" -features = ["util", "limit"] +features = ["util"] # Tower HTTP layers [workspace.dependencies.tower-http] version = "0.6.1" -features = ["cors", "fs", "add-extension", "set-header", "follow-redirect"] +features = ["cors", "fs", "add-extension", "set-header"] # Logging and tracing [workspace.dependencies.tracing] @@ -292,7 +292,7 @@ version = "0.24.0" features = ["trace", "metrics"] [workspace.dependencies.opentelemetry-http] version = "0.13.0" -features = ["hyper"] +features = ["reqwest"] [workspace.dependencies.opentelemetry-semantic-conventions] version = "0.16.0" [workspace.dependencies.tracing-opentelemetry] diff --git a/crates/axum-utils/Cargo.toml b/crates/axum-utils/Cargo.toml index 8d16c3f22..e23e065b6 100644 --- a/crates/axum-utils/Cargo.toml +++ b/crates/axum-utils/Cargo.toml @@ -27,6 +27,7 @@ hyper-util.workspace = true icu_locid = "1.4.0" mime = "0.3.17" rand.workspace = true +reqwest.workspace = true sentry.workspace = true serde.workspace = true serde_with = "3.11.0" diff --git a/crates/axum-utils/src/client_authorization.rs b/crates/axum-utils/src/client_authorization.rs index e1f953887..c5a2dbb48 100644 --- a/crates/axum-utils/src/client_authorization.rs +++ b/crates/axum-utils/src/client_authorization.rs @@ -19,7 +19,7 @@ use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason}; use headers::{authorization::Basic, Authorization}; use http::{Request, StatusCode}; use mas_data_model::{Client, JwksOrJwksUri}; -use mas_http::HttpServiceExt; +use mas_http::RequestBuilderExt; use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt}; use mas_keystore::Encrypter; @@ -28,9 +28,6 @@ use oauth2_types::errors::{ClientError, ClientErrorCode}; use serde::{de::DeserializeOwned, Deserialize}; use serde_json::Value; use thiserror::Error; -use tower::{Service, ServiceExt}; - -use crate::http_client_factory::HttpClientFactory; static JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; @@ -104,7 +101,7 @@ impl Credentials { #[tracing::instrument(skip_all, err)] pub async fn verify( &self, - http_client_factory: &HttpClientFactory, + http_client: &reqwest::Client, encrypter: &Encrypter, method: &OAuthClientAuthenticationMethod, client: &Client, @@ -146,7 +143,7 @@ impl Credentials { .as_ref() .ok_or(CredentialsVerificationError::InvalidClientConfig)?; - let jwks = fetch_jwks(http_client_factory, jwks) + let jwks = fetch_jwks(http_client, jwks) .await .map_err(|_| CredentialsVerificationError::JwksFetchFailed)?; @@ -181,7 +178,7 @@ impl Credentials { } async fn fetch_jwks( - http_client_factory: &HttpClientFactory, + http_client: &reqwest::Client, jwks: &JwksOrJwksUri, ) -> Result { let uri = match jwks { @@ -189,19 +186,14 @@ async fn fetch_jwks( JwksOrJwksUri::JwksUri(u) => u, }; - let request = http::Request::builder() - .uri(uri.as_str()) - .body(mas_http::EmptyBody::new()) - .unwrap(); - - let mut client = http_client_factory - .client("client.fetch_jwks") - .response_body_to_bytes() - .json_response::(); - - let response = client.ready().await?.call(request).await?; + let response = http_client + .get(uri.as_str()) + .send_traced() + .await? + .json() + .await?; - Ok(response.into_body()) + Ok(response) } #[derive(Debug, Error)] diff --git a/crates/axum-utils/src/http_client_factory.rs b/crates/axum-utils/src/http_client_factory.rs deleted file mode 100644 index 584dba499..000000000 --- a/crates/axum-utils/src/http_client_factory.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2022-2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use http_body_util::Full; -use hyper_util::rt::TokioExecutor; -use mas_http::{ - make_traced_connector, BodyToBytesResponseLayer, Client, ClientLayer, ClientService, - HttpService, TracedClient, TracedConnector, -}; -use tower::{ - util::{MapErrLayer, MapRequestLayer}, - BoxError, Layer, -}; - -#[derive(Debug, Clone)] -pub struct HttpClientFactory { - traced_connector: TracedConnector, - client_layer: ClientLayer, -} - -impl Default for HttpClientFactory { - fn default() -> Self { - Self::new() - } -} - -impl HttpClientFactory { - /// Constructs a new HTTP client factory - #[must_use] - pub fn new() -> Self { - Self { - traced_connector: make_traced_connector(), - client_layer: ClientLayer::new(), - } - } - - /// Constructs a new HTTP client - pub fn client(&self, category: &'static str) -> ClientService> - where - B: axum::body::HttpBody + Send, - B::Data: Send, - { - let client = Client::builder(TokioExecutor::new()).build(self.traced_connector.clone()); - self.client_layer - .clone() - .with_category(category) - .layer(client) - } - - /// Constructs a new [`HttpService`], suitable for `mas-oidc-client` - pub fn http_service(&self, category: &'static str) -> HttpService { - let client = self.client(category); - let client = ( - MapErrLayer::new(BoxError::from), - MapRequestLayer::new(|req: http::Request<_>| req.map(Full::new)), - BodyToBytesResponseLayer, - ) - .layer(client); - - HttpService::new(client) - } -} diff --git a/crates/axum-utils/src/lib.rs b/crates/axum-utils/src/lib.rs index db26454d2..aa6fad9e6 100644 --- a/crates/axum-utils/src/lib.rs +++ b/crates/axum-utils/src/lib.rs @@ -12,7 +12,6 @@ pub mod cookies; pub mod csrf; pub mod error_wrapper; pub mod fancy_error; -pub mod http_client_factory; pub mod jwt; pub mod language_detection; pub mod sentry; diff --git a/crates/cli/src/app_state.rs b/crates/cli/src/app_state.rs index ac511ab01..525ec5591 100644 --- a/crates/cli/src/app_state.rs +++ b/crates/cli/src/app_state.rs @@ -14,7 +14,7 @@ use ipnetwork::IpNetwork; use mas_data_model::SiteConfig; use mas_handlers::{ passwords::PasswordManager, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, - GraphQLSchema, HttpClientFactory, Limiter, MetadataCache, RequesterFingerprint, + GraphQLSchema, Limiter, MetadataCache, RequesterFingerprint, }; use mas_i18n::Translator; use mas_keystore::{Encrypter, Keystore}; @@ -43,7 +43,6 @@ pub struct AppState { pub homeserver_connection: SynapseConnection, pub policy_factory: Arc, pub graphql_schema: GraphQLSchema, - pub http_client_factory: HttpClientFactory, pub http_client: reqwest::Client, pub password_manager: PasswordManager, pub metadata_cache: MetadataCache, @@ -170,12 +169,6 @@ impl FromRef for UrlBuilder { } } -impl FromRef for HttpClientFactory { - fn from_ref(input: &AppState) -> Self { - input.http_client_factory.clone() - } -} - impl FromRef for reqwest::Client { fn from_ref(input: &AppState) -> Self { input.http_client.clone() diff --git a/crates/cli/src/commands/debug.rs b/crates/cli/src/commands/debug.rs index 74ad0d3c4..8ffb50306 100644 --- a/crates/cli/src/commands/debug.rs +++ b/crates/cli/src/commands/debug.rs @@ -8,13 +8,7 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; -use http_body_util::BodyExt; -use hyper::{Response, Uri}; use mas_config::{ConfigurationSectionExt, PolicyConfig}; -use mas_handlers::HttpClientFactory; -use mas_http::HttpServiceExt; -use tokio::io::AsyncWriteExt; -use tower::{Service, ServiceExt}; use tracing::{info, info_span}; use crate::util::policy_factory_from_config; @@ -27,93 +21,15 @@ pub(super) struct Options { #[derive(Parser, Debug)] enum Subcommand { - /// Perform an HTTP request with the default HTTP client - Http { - /// Show response headers - #[arg(long, short = 'I')] - show_headers: bool, - - /// Parse the response as JSON - #[arg(long, short = 'j')] - json: bool, - - /// URI where to perform a GET request - url: Uri, - }, - /// Check that the policies compile Policy, } -fn print_headers(parts: &hyper::http::response::Parts) { - println!( - "{:?} {} {}", - parts.version, - parts.status.as_str(), - parts.status.canonical_reason().unwrap_or_default() - ); - - for (header, value) in &parts.headers { - println!("{header}: {value:?}"); - } - println!(); -} - impl Options { #[tracing::instrument(skip_all)] pub async fn run(self, figment: &Figment) -> anyhow::Result { use Subcommand as SC; - let http_client_factory = HttpClientFactory::new(); match self.subcommand { - SC::Http { - show_headers, - json: false, - url, - } => { - let _span = info_span!("cli.debug.http").entered(); - let mut client = http_client_factory.client("debug"); - let request = hyper::Request::builder() - .uri(url) - .body(axum::body::Body::empty())?; - - let response = client.ready().await?.call(request).await?; - let (parts, body) = response.into_parts(); - - if show_headers { - print_headers(&parts); - } - - let mut body = body.collect().await?.to_bytes(); - let mut stdout = tokio::io::stdout(); - stdout.write_all_buf(&mut body).await?; - } - - SC::Http { - show_headers, - json: true, - url, - } => { - let _span = info_span!("cli.debug.http").entered(); - let mut client = http_client_factory - .client("debug") - .response_body_to_bytes() - .json_response(); - let request = hyper::Request::builder() - .uri(url) - .body(axum::body::Body::empty())?; - - let response: Response = - client.ready().await?.call(request).await?; - let (parts, body) = response.into_parts(); - - if show_headers { - print_headers(&parts); - } - - let body = serde_json::to_string_pretty(&body)?; - println!("{body}"); - } - SC::Policy => { let _span = info_span!("cli.debug.policy").entered(); let config = PolicyConfig::extract_or_default(figment)?; diff --git a/crates/cli/src/commands/doctor.rs b/crates/cli/src/commands/doctor.rs index 0155a8033..d507b8308 100644 --- a/crates/cli/src/commands/doctor.rs +++ b/crates/cli/src/commands/doctor.rs @@ -15,9 +15,7 @@ use anyhow::Context; use clap::Parser; use figment::Figment; use mas_config::{ConfigurationSection, RootConfig}; -use mas_handlers::HttpClientFactory; -use mas_http::HttpServiceExt; -use tower::{Service, ServiceExt}; +use mas_http::RequestBuilderExt; use tracing::{error, info, info_span, warn}; use url::{Host, Url}; @@ -36,7 +34,7 @@ impl Options { let config = RootConfig::extract(figment)?; // We'll need an HTTP client - let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); let base_url = config.http.public_base.as_str(); let issuer = config.http.issuer.as_ref().map(url::Url::as_str); let issuer = issuer.unwrap_or(base_url); @@ -55,15 +53,7 @@ This means some clients will refuse to use it."# } let well_known_uri = format!("https://{matrix_domain}/.well-known/matrix/client"); - let mut client = http_client_factory - .client("doctor") - .response_body_to_bytes() - .json_response::(); - - let request = hyper::Request::builder() - .uri(&well_known_uri) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client.get(&well_known_uri).send_traced().await; let expected_well_known = serde_json::json!({ "m.homeserver": { @@ -86,15 +76,21 @@ Make sure the homeserver is reachable and the well-known document is available a ); } - let body = response.into_body(); - - if let Some(auth) = body.get("org.matrix.msc2965.authentication") { - if let Some(wk_issuer) = auth.get("issuer").and_then(|issuer| issuer.as_str()) { - if issuer == wk_issuer { - info!(r#"✅ Matrix client well-known at "{well_known_uri}" is valid"#); - } else { - warn!( - r#"⚠️ Matrix client well-known has an "org.matrix.msc2965.authentication" section, but the issuer is not the same as the homeserver. + let result = response.json::().await; + + match result { + Ok(body) => { + if let Some(auth) = body.get("org.matrix.msc2965.authentication") { + if let Some(wk_issuer) = + auth.get("issuer").and_then(|issuer| issuer.as_str()) + { + if issuer == wk_issuer { + info!( + r#"✅ Matrix client well-known at "{well_known_uri}" is valid"# + ); + } else { + warn!( + r#"⚠️ Matrix client well-known has an "org.matrix.msc2965.authentication" section, but the issuer is not the same as the homeserver. Check the well-known document at "{well_known_uri}" This can happen because MAS parses the URL its config differently from the homeserver. This means some OIDC-native clients might not work. @@ -116,18 +112,18 @@ And in the Synapse config: See {DOCS_BASE}/setup/homeserver.html "# - ); - } - } else { - error!( - r#"❌ Matrix client well-known "org.matrix.msc2965.authentication" does not have a valid "issuer" field. + ); + } + } else { + error!( + r#"❌ Matrix client well-known "org.matrix.msc2965.authentication" does not have a valid "issuer" field. Check the well-known document at "{well_known_uri}" "# - ); - } - } else { - warn!( - r#"Matrix client well-known is missing the "org.matrix.msc2965.authentication" section. + ); + } + } else { + warn!( + r#"Matrix client well-known is missing the "org.matrix.msc2965.authentication" section. Check the well-known document at "{well_known_uri}" Make sure Synapse has delegated auth enabled: @@ -143,14 +139,29 @@ If it is not Synapse handling the well-known document, update it to include the See {DOCS_BASE}/setup/homeserver.html "# - ); - } + ); + } + // Return the discovered homeserver base URL + body.get("m.homeserver") + .and_then(|hs| hs.get("base_url")) + .and_then(|base_url| base_url.as_str()) + .and_then(|base_url| Url::parse(base_url).ok()) + } + Err(e) => { + warn!( + r#"⚠️ Invalid JSON for the well-known document at "{well_known_uri}". +Make sure going to {well_known_uri:?} in a web browser returns a valid JSON document, similar to: + +{expected_well_known:#} + +See {DOCS_BASE}/setup/homeserver.html - // Return the discovered homeserver base URL - body.get("m.homeserver") - .and_then(|hs| hs.get("base_url")) - .and_then(|base_url| base_url.as_str()) - .and_then(|base_url| Url::parse(base_url).ok()) +Error details: {e} +"# + ); + None + } + } } Err(e) => { warn!( @@ -172,10 +183,10 @@ Error details: {e} // Now try to reach the homeserver let client_versions = hs_api.join("/_matrix/client/versions")?; - let request = hyper::Request::builder() - .uri(client_versions.as_str()) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client + .get(client_versions.as_str()) + .send_traced() + .await; let can_reach_cs = match result { Ok(response) => { let status = response.status(); @@ -222,18 +233,15 @@ Error details: {e} // Try the whoami API. If it replies with `M_UNKNOWN` this is because Synapse // couldn't reach MAS let whoami = hs_api.join("/_matrix/client/v3/account/whoami")?; - let request = hyper::Request::builder() - .header( - "Authorization", - "Bearer averyinvalidtokenireallyhopethisisnotvalid", - ) - .uri(whoami.as_str()) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client + .get(whoami.as_str()) + .bearer_auth("averyinvalidtokenireallyhopethisisnotvalid") + .send_traced() + .await; match result { Ok(response) => { - let (parts, body) = response.into_parts(); - let status = parts.status; + let status = response.status(); + let body = response.text().await.unwrap_or("???".into()); match status.as_u16() { 401 => info!( @@ -276,10 +284,7 @@ Error details: {e} // Try to reach the admin API on an unauthorized endpoint let server_version = hs_api.join("/_synapse/admin/v1/server_version")?; - let request = hyper::Request::builder() - .uri(server_version.as_str()) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client.get(server_version.as_str()).send_traced().await; match result { Ok(response) => { let status = response.status(); @@ -304,11 +309,11 @@ Error details: {e} // Try to reach an authenticated admin API endpoint let background_updates = hs_api.join("/_synapse/admin/v1/background_updates/status")?; - let request = hyper::Request::builder() - .uri(background_updates.as_str()) - .header("Authorization", format!("Bearer {admin_token}")) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client + .get(background_updates.as_str()) + .bearer_auth(&admin_token) + .send_traced() + .await; match result { Ok(response) => { let status = response.status(); @@ -353,17 +358,17 @@ Error details: {e} // Try to reach the legacy login API let compat_login = external_cs_api_endpoint.join("/_matrix/client/v3/login")?; let compat_login = compat_login.as_str(); - let request = hyper::Request::builder() - .uri(compat_login) - .body(axum::body::Body::empty())?; - let result = client.ready().await?.call(request).await; + let result = http_client.get(compat_login).send_traced().await; match result { Ok(response) => { let status = response.status(); if status.is_success() { // Now we need to inspect the body to figure out whether it's Synapse or MAS // which handled the request - let body = response.into_body(); + let body = response + .json::() + .await + .unwrap_or_default(); let flows = body .get("flows") .and_then(|flows| flows.as_array()) diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 07d4fa601..d8b55777b 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -16,7 +16,6 @@ use mas_config::{ }; use mas_data_model::{Device, TokenType, Ulid, UpstreamOAuthProvider, User}; use mas_email::Address; -use mas_handlers::HttpClientFactory; use mas_matrix::HomeserverConnection; use mas_matrix_synapse::SynapseConnection; use mas_storage::{ @@ -512,7 +511,7 @@ impl Options { yes, ignore_password_complexity, } => { - let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); let password_config = PasswordsConfig::extract_or_default(figment)?; let database_config = DatabaseConfig::extract_or_default(figment)?; let matrix_config = MatrixConfig::extract(figment)?; @@ -522,7 +521,7 @@ impl Options { matrix_config.homeserver, matrix_config.endpoint, matrix_config.secret, - http_client_factory, + http_client, ); let mut conn = database_connection_from_config(&database_config).await?; let txn = conn.begin().await?; diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index d32b99099..0558453d8 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -13,7 +13,7 @@ use itertools::Itertools; use mas_config::{ AppConfig, ClientsConfig, ConfigurationSection, ConfigurationSectionExt, UpstreamOAuth2Config, }; -use mas_handlers::{ActivityTracker, CookieManager, HttpClientFactory, Limiter, MetadataCache}; +use mas_handlers::{ActivityTracker, CookieManager, Limiter, MetadataCache}; use mas_listener::server::Server; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; @@ -148,15 +148,13 @@ impl Options { let templates = templates_from_config(&config.templates, &site_config, &url_builder).await?; - let http_client_factory = HttpClientFactory::new(); - let http_client = mas_http::reqwest_client(); let homeserver_connection = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), config.matrix.secret.clone(), - http_client_factory.clone(), + http_client.clone(), ); if !self.no_worker { @@ -243,7 +241,6 @@ impl Options { homeserver_connection, policy_factory, graphql_schema, - http_client_factory, http_client, password_manager, metadata_cache, diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index 7a797fc16..3bbef12dd 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -9,7 +9,6 @@ use std::process::ExitCode; use clap::Parser; use figment::Figment; use mas_config::{AppConfig, ConfigurationSection}; -use mas_handlers::HttpClientFactory; use mas_matrix_synapse::SynapseConnection; use mas_router::UrlBuilder; use rand::{ @@ -57,12 +56,12 @@ impl Options { let mailer = mailer_from_config(&config.email, &templates)?; mailer.test_connection().await?; - let http_client_factory = HttpClientFactory::new(); + let http_client = mas_http::reqwest_client(); let conn = SynapseConnection::new( config.matrix.homeserver.clone(), config.matrix.endpoint.clone(), config.matrix.secret.clone(), - http_client_factory, + http_client, ); drop(config); diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index eee5a73a4..220480b93 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -16,17 +16,36 @@ use tracing_subscriber::{ filter::LevelFilter, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer, Registry, }; -use crate::sentry_transport::HyperTransportFactory; - mod app_state; mod commands; -mod sentry_transport; mod server; mod shutdown; mod sync; mod telemetry; mod util; +#[derive(Debug)] +struct SentryTransportFactory { + client: reqwest::Client, +} + +impl SentryTransportFactory { + fn new() -> Self { + Self { + client: mas_http::reqwest_client(), + } + } +} + +impl sentry::TransportFactory for SentryTransportFactory { + fn create_transport(&self, options: &sentry::ClientOptions) -> Arc { + let transport = + sentry::transports::ReqwestHttpTransport::with_client(options, self.client.clone()); + + Arc::new(transport) + } +} + #[tokio::main] async fn main() -> anyhow::Result { // We're splitting the "fallible" part of main in another function to have a @@ -79,9 +98,7 @@ async fn try_main() -> anyhow::Result { let sentry = sentry::init(( telemetry_config.sentry.dsn.as_deref(), sentry::ClientOptions { - transport: Some(Arc::new(HyperTransportFactory::new( - mas_http::make_untraced_client(), - ))), + transport: Some(Arc::new(SentryTransportFactory::new())), traces_sample_rate: 1.0, auto_session_tracking: true, session_mode: sentry::SessionMode::Request, diff --git a/crates/cli/src/sentry_transport/mod.rs b/crates/cli/src/sentry_transport/mod.rs deleted file mode 100644 index 8eb8dad35..000000000 --- a/crates/cli/src/sentry_transport/mod.rs +++ /dev/null @@ -1,121 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -//! Implements a transport for Sentry based on Hyper. -//! -//! This avoids the dependency on `reqwest`, which helps avoiding having two -//! HTTP and TLS stacks in the binary. -//! -//! The [`ratelimit`] and [`tokio_thread`] modules are directly copied from the -//! Sentry codebase. - -use std::{sync::Arc, time::Duration}; - -use bytes::Bytes; -use http_body_util::{BodyExt, Full}; -use hyper::{header::RETRY_AFTER, StatusCode}; -use mas_http::UntracedClient; -use sentry::{sentry_debug, ClientOptions, Transport, TransportFactory}; - -use self::tokio_thread::TransportThread; - -mod ratelimit; -mod tokio_thread; - -pub struct HyperTransport { - thread: TransportThread, -} - -pub struct HyperTransportFactory { - client: UntracedClient>, -} - -impl HyperTransportFactory { - pub fn new(client: UntracedClient>) -> Self { - Self { client } - } -} - -impl TransportFactory for HyperTransportFactory { - fn create_transport(&self, options: &ClientOptions) -> Arc { - Arc::new(HyperTransport::new(options, self.client.clone())) - } -} - -impl HyperTransport { - pub fn new(options: &ClientOptions, client: UntracedClient>) -> Self { - let dsn = options.dsn.as_ref().unwrap(); - let user_agent = options.user_agent.clone(); - let auth = dsn.to_auth(Some(&user_agent)).to_string(); - let url = dsn.envelope_api_url().to_string(); - - let thread = TransportThread::new(move |envelope, mut rl| { - let mut body = Vec::new(); - envelope.to_writer(&mut body).unwrap(); - - let request = hyper::Request::post(&url) - .header("X-Sentry-Auth", &auth) - .body(Full::new(Bytes::from(body))) - .unwrap(); - - let fut = client.request(request); - - async move { - match fut.await { - Ok(response) => { - if let Some(sentry_header) = response - .headers() - .get("x-sentry-rate-limits") - .and_then(|x| x.to_str().ok()) - { - rl.update_from_sentry_header(sentry_header); - } else if let Some(retry_after) = response - .headers() - .get(RETRY_AFTER) - .and_then(|x| x.to_str().ok()) - { - rl.update_from_retry_after(retry_after); - } else if response.status() == StatusCode::TOO_MANY_REQUESTS { - rl.update_from_429(); - } - - match response.into_body().collect().await { - Err(err) => { - sentry_debug!("Failed to read sentry response: {}", err); - } - Ok(body) => { - let body = body.to_bytes(); - let text = String::from_utf8_lossy(&body); - sentry_debug!("Get response: `{}`", text); - } - } - } - Err(err) => { - sentry_debug!("Failed to send envelope: {}", err); - } - } - - rl - } - }); - - Self { thread } - } -} - -impl Transport for HyperTransport { - fn send_envelope(&self, envelope: sentry::Envelope) { - self.thread.send(envelope); - } - - fn flush(&self, timeout: Duration) -> bool { - self.thread.flush(timeout) - } - - fn shutdown(&self, timeout: Duration) -> bool { - self.flush(timeout) - } -} diff --git a/crates/cli/src/sentry_transport/ratelimit.rs b/crates/cli/src/sentry_transport/ratelimit.rs deleted file mode 100644 index e3d0b66e1..000000000 --- a/crates/cli/src/sentry_transport/ratelimit.rs +++ /dev/null @@ -1,180 +0,0 @@ -// Taken from sentry/transports/ratelimit.rs -#![allow(clippy::all, clippy::pedantic)] - -use std::time::{Duration, SystemTime}; - -use httpdate::parse_http_date; -use sentry::{protocol::EnvelopeItem, Envelope}; - -/// A Utility that helps with rate limiting sentry requests. -#[derive(Debug, Default)] -pub struct RateLimiter { - global: Option, - error: Option, - session: Option, - transaction: Option, - attachment: Option, -} - -impl RateLimiter { - /// Create a new RateLimiter. - pub fn new() -> Self { - Self::default() - } - - /// Updates the RateLimiter with information from a `Retry-After` header. - pub fn update_from_retry_after(&mut self, header: &str) { - let new_time = if let Ok(value) = header.parse::() { - SystemTime::now() + Duration::from_secs(value.ceil() as u64) - } else if let Ok(value) = parse_http_date(header) { - value - } else { - SystemTime::now() + Duration::from_secs(60) - }; - - self.global = Some(new_time); - } - - /// Updates the RateLimiter with information from a `X-Sentry-Rate-Limits` - /// header. - pub fn update_from_sentry_header(&mut self, header: &str) { - // = (,)+ - // =