Skip to content

Commit d49cb54

Browse files
authored
Allow to limit the number of concurrent requests made by the sdk (#3625)
Add a new `max_concurrent_requests` parameter in the `RequestConfig` limits the number of http(s) requests the internal sdk client issues concurrently (if > 0). The default behavior is the same as before: there is no limit on concurrent requests issued. This is especially useful for resource constrained platforms (e.g. mobile platforms), and if your pattern might lead to issuing many requests at the same time (like downloading and caching all avatars at startup). - [x] Public API changes documented in changelogs (optional) Signed-off-by: Benjamin Kampmann <[email protected]>
1 parent aaccfdf commit d49cb54

File tree

3 files changed

+163
-3
lines changed

3 files changed

+163
-3
lines changed

crates/matrix-sdk/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Breaking changes:
2222

2323
Additions:
2424

25+
- new `RequestConfig.max_concurrent_requests` which allows to limit the maximum number of concurrent requests the internal HTTP client issues (all others have to wait until the number drops below that threshold again)
2526
- Expose new method `Client::Oidc::login_with_qr_code()`.
2627
([#3466](https://github.com/matrix-org/matrix-rust-sdk/pull/3466))
2728
- Add the `ClientBuilder::add_root_certificates()` method which re-exposes the

crates/matrix-sdk/src/config/request.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
use std::{
1616
fmt::{self, Debug},
17+
num::NonZeroUsize,
1718
time::Duration,
1819
};
1920

@@ -44,18 +45,21 @@ pub struct RequestConfig {
4445
pub(crate) timeout: Duration,
4546
pub(crate) retry_limit: Option<u64>,
4647
pub(crate) retry_timeout: Option<Duration>,
48+
pub(crate) max_concurrent_requests: Option<NonZeroUsize>,
4749
pub(crate) force_auth: bool,
4850
}
4951

5052
#[cfg(not(tarpaulin_include))]
5153
impl Debug for RequestConfig {
5254
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
53-
let Self { timeout, retry_limit, retry_timeout, force_auth } = self;
55+
let Self { timeout, retry_limit, retry_timeout, force_auth, max_concurrent_requests } =
56+
self;
5457

5558
let mut res = fmt.debug_struct("RequestConfig");
5659
res.field("timeout", timeout)
5760
.maybe_field("retry_limit", retry_limit)
58-
.maybe_field("retry_timeout", retry_timeout);
61+
.maybe_field("retry_timeout", retry_timeout)
62+
.maybe_field("max_concurrent_requests", max_concurrent_requests);
5963

6064
if *force_auth {
6165
res.field("force_auth", &true);
@@ -71,6 +75,7 @@ impl Default for RequestConfig {
7175
timeout: DEFAULT_REQUEST_TIMEOUT,
7276
retry_limit: Default::default(),
7377
retry_timeout: Default::default(),
78+
max_concurrent_requests: Default::default(),
7479
force_auth: false,
7580
}
7681
}
@@ -106,6 +111,15 @@ impl RequestConfig {
106111
self
107112
}
108113

114+
/// The total limit of request that are pending or run concurrently.
115+
/// Any additional request beyond that number will be waiting until another
116+
/// concurrent requests finished. Requests are queued fairly.
117+
#[must_use]
118+
pub fn max_concurrent_requests(mut self, limit: Option<NonZeroUsize>) -> Self {
119+
self.max_concurrent_requests = limit;
120+
self
121+
}
122+
109123
/// Set the timeout duration for all HTTP requests.
110124
#[must_use]
111125
pub fn timeout(mut self, timeout: Duration) -> Self {

crates/matrix-sdk/src/http_client/mod.rs

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use std::{
1616
any::type_name,
1717
fmt::Debug,
18+
num::NonZeroUsize,
1819
sync::{
1920
atomic::{AtomicU64, Ordering},
2021
Arc,
@@ -30,6 +31,7 @@ use ruma::api::{
3031
error::{FromHttpResponseError, IntoHttpError},
3132
AuthScheme, MatrixVersion, OutgoingRequest, SendAccessToken,
3233
};
34+
use tokio::sync::{Semaphore, SemaphorePermit};
3335
use tracing::{debug, field::debug, instrument, trace};
3436

3537
use crate::{config::RequestConfig, error::HttpError};
@@ -48,16 +50,48 @@ pub(crate) use native::HttpSettings;
4850

4951
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
5052

53+
#[derive(Clone, Debug)]
54+
struct MaybeSemaphore(Arc<Option<Semaphore>>);
55+
56+
#[allow(dead_code)] // false-positive lint: we never use it but only hold it for the drop
57+
struct MaybeSemaphorePermit<'a>(Option<SemaphorePermit<'a>>);
58+
59+
impl MaybeSemaphore {
60+
fn new(max: Option<NonZeroUsize>) -> Self {
61+
let inner = max.map(|i| Semaphore::new(i.into()));
62+
MaybeSemaphore(Arc::new(inner))
63+
}
64+
65+
async fn acquire(&self) -> MaybeSemaphorePermit<'_> {
66+
match self.0.as_ref() {
67+
Some(inner) => {
68+
// This can only ever error if the semaphore was closed,
69+
// which we never do, so we can safely ignore any error case
70+
MaybeSemaphorePermit(inner.acquire().await.ok())
71+
}
72+
None => MaybeSemaphorePermit(None),
73+
}
74+
}
75+
}
76+
5177
#[derive(Clone, Debug)]
5278
pub(crate) struct HttpClient {
5379
pub(crate) inner: reqwest::Client,
5480
pub(crate) request_config: RequestConfig,
81+
concurrent_request_semaphore: MaybeSemaphore,
5582
next_request_id: Arc<AtomicU64>,
5683
}
5784

5885
impl HttpClient {
5986
pub(crate) fn new(inner: reqwest::Client, request_config: RequestConfig) -> Self {
60-
HttpClient { inner, request_config, next_request_id: AtomicU64::new(0).into() }
87+
HttpClient {
88+
inner,
89+
request_config,
90+
concurrent_request_semaphore: MaybeSemaphore::new(
91+
request_config.max_concurrent_requests,
92+
),
93+
next_request_id: AtomicU64::new(0).into(),
94+
}
6195
}
6296

6397
fn get_request_id(&self) -> String {
@@ -184,6 +218,9 @@ impl HttpClient {
184218
request
185219
};
186220

221+
// will be automatically dropped at the end of this function
222+
let _handle = self.concurrent_request_semaphore.acquire().await;
223+
187224
debug!("Sending request");
188225

189226
// There's a bunch of state in send_request, factor out a pinned inner
@@ -259,3 +296,111 @@ impl tower::Service<http_old::Request<Bytes>> for HttpClient {
259296
Box::pin(fut)
260297
}
261298
}
299+
300+
#[cfg(all(test, not(target_arch = "wasm32")))]
301+
mod tests {
302+
use std::{
303+
num::NonZeroUsize,
304+
sync::{
305+
atomic::{AtomicU8, Ordering},
306+
Arc,
307+
},
308+
time::Duration,
309+
};
310+
311+
use matrix_sdk_test::{async_test, test_json};
312+
use wiremock::{
313+
matchers::{method, path},
314+
Mock, Request, ResponseTemplate,
315+
};
316+
317+
use crate::{
318+
http_client::RequestConfig,
319+
test_utils::{set_client_session, test_client_builder_with_server},
320+
};
321+
322+
#[async_test]
323+
async fn ensure_concurrent_request_limit_is_observed() {
324+
let (client_builder, server) = test_client_builder_with_server().await;
325+
let client = client_builder
326+
.request_config(RequestConfig::default().max_concurrent_requests(NonZeroUsize::new(5)))
327+
.build()
328+
.await
329+
.unwrap();
330+
331+
set_client_session(&client).await;
332+
333+
let counter = Arc::new(AtomicU8::new(0));
334+
let inner_counter = counter.clone();
335+
336+
Mock::given(method("GET"))
337+
.and(path("/_matrix/client/versions"))
338+
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
339+
.mount(&server)
340+
.await;
341+
342+
Mock::given(method("GET"))
343+
.and(path("_matrix/client/r0/account/whoami"))
344+
.respond_with(move |_req: &Request| {
345+
inner_counter.fetch_add(1, Ordering::SeqCst);
346+
// we stall the requests
347+
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
348+
})
349+
.mount(&server)
350+
.await;
351+
352+
let bg_task = tokio::spawn(async move {
353+
futures_util::future::join_all((0..10).map(|_| client.whoami())).await
354+
});
355+
356+
// give it some time to issue the requests
357+
tokio::time::sleep(Duration::from_millis(300)).await;
358+
359+
assert_eq!(
360+
counter.load(Ordering::SeqCst),
361+
5,
362+
"More requests passed than the limit we configured"
363+
);
364+
bg_task.abort();
365+
}
366+
367+
#[async_test]
368+
async fn ensure_no_max_concurrent_request_does_not_limit() {
369+
let (client_builder, server) = test_client_builder_with_server().await;
370+
let client = client_builder
371+
.request_config(RequestConfig::default().max_concurrent_requests(None))
372+
.build()
373+
.await
374+
.unwrap();
375+
376+
set_client_session(&client).await;
377+
378+
let counter = Arc::new(AtomicU8::new(0));
379+
let inner_counter = counter.clone();
380+
381+
Mock::given(method("GET"))
382+
.and(path("/_matrix/client/versions"))
383+
.respond_with(ResponseTemplate::new(200).set_body_json(&*test_json::VERSIONS))
384+
.mount(&server)
385+
.await;
386+
387+
Mock::given(method("GET"))
388+
.and(path("_matrix/client/r0/account/whoami"))
389+
.respond_with(move |_req: &Request| {
390+
inner_counter.fetch_add(1, Ordering::SeqCst);
391+
ResponseTemplate::new(200).set_delay(Duration::from_secs(60))
392+
})
393+
.mount(&server)
394+
.await;
395+
396+
let bg_task = tokio::spawn(async move {
397+
futures_util::future::join_all((0..254).map(|_| client.whoami())).await
398+
});
399+
400+
// give it some time to issue the requests
401+
tokio::time::sleep(Duration::from_secs(1)).await;
402+
403+
assert_eq!(counter.load(Ordering::SeqCst), 254, "Not all requests passed through");
404+
bg_task.abort();
405+
}
406+
}

0 commit comments

Comments
 (0)