Skip to content

Commit 366fb1d

Browse files
committed
WIP: switch to reqwest
1 parent 26d945d commit 366fb1d

File tree

20 files changed

+222
-145
lines changed

20 files changed

+222
-145
lines changed

Cargo.lock

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ version = "0.3.0"
182182
[workspace.dependencies.rand]
183183
version = "0.8.5"
184184

185+
# High-level HTTP client
186+
[workspace.dependencies.reqwest]
187+
version = "0.12.8"
188+
default-features = false
189+
features = ["http2", "rustls-tls-manual-roots", "charset", "json"]
190+
185191
# TLS stack
186192
[workspace.dependencies.rustls]
187193
version = "0.23.15"

crates/handlers/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ aide.workspace = true
4141
async-graphql.workspace = true
4242
schemars.workspace = true
4343

44+
# HTTP client
45+
reqwest.workspace = true
46+
4447
# Emails
4548
lettre.workspace = true
4649

crates/handlers/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ where
312312
MetadataCache: FromRef<S>,
313313
SiteConfig: FromRef<S>,
314314
Limiter: FromRef<S>,
315+
reqwest::Client: FromRef<S>,
315316
BoxHomeserverConnection: FromRef<S>,
316317
BoxClock: FromRequestParts<S>,
317318
BoxRng: FromRequestParts<S>,

crates/handlers/src/test_utils.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ pub(crate) struct TestState {
109109
pub limiter: Limiter,
110110
pub clock: Arc<MockClock>,
111111
pub rng: Arc<Mutex<ChaChaRng>>,
112+
pub http_client: reqwest::Client,
112113

113114
#[allow(dead_code)] // It is used, as it will cancel the CancellationToken when dropped
114115
cancellation_drop_guard: Arc<DropGuard>,
@@ -169,6 +170,8 @@ impl TestState {
169170
)
170171
.await?;
171172

173+
let http_client = mas_http::reqwest_client();
174+
172175
// TODO: add more test keys to the store
173176
let rsa =
174177
PrivateKey::load_pem(include_str!("../../keystore/tests/keys/rsa.pkcs1.pem")).unwrap();
@@ -241,6 +244,7 @@ impl TestState {
241244
limiter,
242245
clock,
243246
rng,
247+
http_client,
244248
cancellation_drop_guard: Arc::new(shutdown_token.drop_guard()),
245249
})
246250
}
@@ -494,6 +498,12 @@ impl FromRef<TestState> for Limiter {
494498
}
495499
}
496500

501+
impl FromRef<TestState> for reqwest::Client {
502+
fn from_ref(input: &TestState) -> Self {
503+
input.http_client.clone()
504+
}
505+
}
506+
497507
#[async_trait]
498508
impl FromRequestParts<TestState> for ActivityTracker {
499509
type Rejection = Infallible;

crates/handlers/src/upstream_oauth2/authorize.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ impl IntoResponse for RouteError {
6262
pub(crate) async fn get(
6363
mut rng: BoxRng,
6464
clock: BoxClock,
65-
State(http_client_factory): State<HttpClientFactory>,
6665
State(metadata_cache): State<MetadataCache>,
6766
mut repo: BoxRepository,
6867
State(url_builder): State<UrlBuilder>,
68+
State(http_client): State<reqwest::Client>,
6969
cookie_jar: CookieJar,
7070
Path(provider_id): Path<Ulid>,
7171
Query(query): Query<OptionalPostAuthAction>,
@@ -77,12 +77,10 @@ pub(crate) async fn get(
7777
.filter(UpstreamOAuthProvider::enabled)
7878
.ok_or(RouteError::ProviderNotFound)?;
7979

80-
let http_service = http_client_factory.http_service("upstream_oauth2.authorize");
81-
8280
// First, discover the provider
8381
// This is done lazyly according to provider.discovery_mode and the various
8482
// endpoint overrides
85-
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service);
83+
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_client);
8684
lazy_metadata.maybe_discover().await?;
8785

8886
let redirect_uri = url_builder.upstream_oauth_callback(provider.id);

crates/handlers/src/upstream_oauth2/cache.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use std::{collections::HashMap, sync::Arc};
99
use mas_data_model::{
1010
UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
1111
};
12-
use mas_http::HttpService;
1312
use mas_iana::oauth::PkceCodeChallengeMethod;
1413
use mas_oidc_client::error::DiscoveryError;
1514
use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, RepositoryAccess};
@@ -22,20 +21,20 @@ use url::Url;
2221
pub struct LazyProviderInfos<'a> {
2322
cache: &'a MetadataCache,
2423
provider: &'a UpstreamOAuthProvider,
25-
http_service: &'a HttpService,
24+
client: &'a reqwest::Client,
2625
loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
2726
}
2827

2928
impl<'a> LazyProviderInfos<'a> {
3029
pub fn new(
3130
cache: &'a MetadataCache,
3231
provider: &'a UpstreamOAuthProvider,
33-
http_service: &'a HttpService,
32+
client: &'a reqwest::Client,
3433
) -> Self {
3534
Self {
3635
cache,
3736
provider,
38-
http_service,
37+
client,
3938
loaded_metadata: None,
4039
}
4140
}
@@ -64,7 +63,7 @@ impl<'a> LazyProviderInfos<'a> {
6463

6564
let metadata = self
6665
.cache
67-
.get(self.http_service, &self.provider.issuer, verify)
66+
.get(self.client, &self.provider.issuer, verify)
6867
.await?;
6968

7069
self.loaded_metadata = Some(metadata);
@@ -155,7 +154,7 @@ impl MetadataCache {
155154
#[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all, err)]
156155
pub async fn warm_up_and_run<R: RepositoryAccess>(
157156
&self,
158-
http_service: HttpService,
157+
client: &reqwest::Client,
159158
interval: std::time::Duration,
160159
repository: &mut R,
161160
) -> Result<tokio::task::JoinHandle<()>, R::Error> {
@@ -168,32 +167,32 @@ impl MetadataCache {
168167
UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
169168
};
170169

171-
if let Err(e) = self.fetch(&http_service, &provider.issuer, verify).await {
170+
if let Err(e) = self.fetch(client, &provider.issuer, verify).await {
172171
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
173172
}
174173
}
175174

176175
// Spawn a background task to refresh the cache regularly
177176
let cache = self.clone();
177+
let client = client.clone();
178178
Ok(tokio::spawn(async move {
179179
loop {
180180
// Re-fetch the known metadata at the given interval
181181
tokio::time::sleep(interval).await;
182-
cache.refresh_all(&http_service).await;
182+
cache.refresh_all(&client).await;
183183
}
184184
}))
185185
}
186186

187187
#[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all, err)]
188188
async fn fetch(
189189
&self,
190-
http_service: &HttpService,
190+
client: &reqwest::Client,
191191
issuer: &str,
192192
verify: bool,
193193
) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
194194
if verify {
195-
let metadata =
196-
mas_oidc_client::requests::discovery::discover(http_service, issuer).await?;
195+
let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
197196
let metadata = Arc::new(metadata);
198197

199198
self.cache
@@ -204,8 +203,7 @@ impl MetadataCache {
204203
Ok(metadata)
205204
} else {
206205
let metadata =
207-
mas_oidc_client::requests::discovery::insecure_discover(http_service, issuer)
208-
.await?;
206+
mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
209207
let metadata = Arc::new(metadata);
210208

211209
self.insecure_cache
@@ -221,7 +219,7 @@ impl MetadataCache {
221219
#[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all, err)]
222220
pub async fn get(
223221
&self,
224-
http_service: &HttpService,
222+
client: &reqwest::Client,
225223
issuer: &str,
226224
verify: bool,
227225
) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
@@ -237,20 +235,20 @@ impl MetadataCache {
237235
// Drop the cache guard so that we don't deadlock when we try to fetch
238236
drop(cache);
239237

240-
let metadata = self.fetch(http_service, issuer, verify).await?;
238+
let metadata = self.fetch(client, issuer, verify).await?;
241239
Ok(metadata)
242240
}
243241

244242
#[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
245-
async fn refresh_all(&self, http_service: &HttpService) {
243+
async fn refresh_all(&self, client: &reqwest::Client) {
246244
// Grab all the keys first to avoid locking the cache for too long
247245
let keys: Vec<String> = {
248246
let cache = self.cache.read().await;
249247
cache.keys().cloned().collect()
250248
};
251249

252250
for issuer in keys {
253-
if let Err(e) = self.fetch(http_service, &issuer, true).await {
251+
if let Err(e) = self.fetch(client, &issuer, true).await {
254252
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
255253
}
256254
}
@@ -262,13 +260,14 @@ impl MetadataCache {
262260
};
263261

264262
for issuer in keys {
265-
if let Err(e) = self.fetch(http_service, &issuer, false).await {
263+
if let Err(e) = self.fetch(client, &issuer, false).await {
266264
tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
267265
}
268266
}
269267
}
270268
}
271269

270+
/* TODO: redo those tests
272271
#[cfg(test)]
273272
mod tests {
274273
#![allow(clippy::too_many_lines)]
@@ -619,3 +618,4 @@ mod tests {
619618
}
620619
}
621620
}
621+
*/

crates/handlers/src/upstream_oauth2/callback.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ pub(crate) async fn get(
134134
State(url_builder): State<UrlBuilder>,
135135
State(encrypter): State<Encrypter>,
136136
State(keystore): State<Keystore>,
137+
State(client): State<reqwest::Client>,
137138
cookie_jar: CookieJar,
138139
Path(provider_id): Path<Ulid>,
139140
Query(params): Query<QueryParams>,
@@ -186,12 +187,11 @@ pub(crate) async fn get(
186187
CodeOrError::Code { code } => code,
187188
};
188189

189-
let http_service = http_client_factory.http_service("upstream_oauth2.callback");
190-
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &http_service);
190+
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
191191

192192
// Fetch the JWKS
193193
let jwks =
194-
mas_oidc_client::requests::jose::fetch_jwks(&http_service, lazy_metadata.jwks_uri().await?)
194+
mas_oidc_client::requests::jose::fetch_jwks(&client, lazy_metadata.jwks_uri().await?)
195195
.await?;
196196

197197
// Figure out the client credentials
@@ -222,7 +222,7 @@ pub(crate) async fn get(
222222

223223
let (response, id_token) =
224224
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
225-
&http_service,
225+
&client,
226226
client_credentials,
227227
lazy_metadata.token_endpoint().await?,
228228
code,

crates/http/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ hyper.workspace = true
2323
hyper-util.workspace = true
2424
hyper-rustls = { workspace = true, optional = true }
2525
opentelemetry.workspace = true
26+
opentelemetry-http.workspace = true
2627
opentelemetry-semantic-conventions.workspace = true
2728
rustls = { workspace = true, optional = true }
2829
rustls-platform-verifier = { workspace = true, optional = true }
2930
pin-project-lite = "0.2.14"
31+
reqwest.workspace = true
3032
serde.workspace = true
3133
serde_json.workspace = true
3234
serde_urlencoded = "0.7.1"

crates/http/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
mod client;
1414
mod ext;
1515
mod layers;
16+
mod reqwest;
1617
mod service;
1718

1819
#[cfg(feature = "client")]
@@ -33,6 +34,7 @@ pub use self::{
3334
json_request::{self, JsonRequest, JsonRequestLayer},
3435
json_response::{self, JsonResponse, JsonResponseLayer},
3536
},
37+
reqwest::{client as reqwest_client, RequestBuilderExt},
3638
service::{BoxCloneSyncService, HttpService},
3739
};
3840

0 commit comments

Comments
 (0)