Skip to content

Commit 86dba78

Browse files
authored
feat(cherry pick): basic load shedding for aggregators (#2628) (#2641)
feat: basic load shedding for aggregators (#2628) * basic load shedding for aggregators * Address comments
1 parent cc13b4b commit 86dba78

File tree

3 files changed

+256
-24
lines changed

3 files changed

+256
-24
lines changed

crates/walrus-service/src/client/cli/args.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,19 @@ pub struct AggregatorArgs {
887887
#[arg(long)]
888888
#[serde(default)]
889889
pub max_blob_size: Option<u64>,
890+
/// The maximum number of requests that can be buffered before the server starts rejecting new
891+
/// ones.
892+
#[arg(long = "max-buffer-size", default_value_t = default::max_aggregator_buffer_size())]
893+
#[serde(default = "default::max_aggregator_buffer_size")]
894+
pub max_request_buffer_size: usize,
895+
/// The maximum number of requests the aggregator can process concurrently.
896+
///
897+
/// If more requests than this maximum are received, the excess requests are buffered up to
898+
/// `--max-buffer-size`. Any request received beyond that buffer capacity will result in a
899+
/// 429 HTTP status code.
900+
#[arg(long, default_value_t = default::max_aggregator_concurrent_requests())]
901+
#[serde(default = "default::max_aggregator_concurrent_requests")]
902+
pub max_concurrent_requests: usize,
890903
}
891904

892905
/// The arguments for the publisher service.
@@ -1875,6 +1888,14 @@ pub(crate) mod default {
18751888
pub(crate) fn concurrent_requests_for_health() -> usize {
18761889
60
18771890
}
1891+
1892+
pub(crate) fn max_aggregator_concurrent_requests() -> usize {
1893+
256
1894+
}
1895+
1896+
pub(crate) fn max_aggregator_buffer_size() -> usize {
1897+
384
1898+
}
18781899
}
18791900

18801901
#[cfg(test)]
@@ -1967,6 +1988,8 @@ mod tests {
19671988
allowed_headers: default::allowed_headers(),
19681989
allow_quilt_patch_tags_in_response: false,
19691990
max_blob_size: None,
1991+
max_request_buffer_size: default::max_aggregator_buffer_size(),
1992+
max_concurrent_requests: default::max_aggregator_concurrent_requests(),
19701993
},
19711994
})
19721995
}

crates/walrus-service/src/client/cli/runner.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,8 +1234,7 @@ impl ClientCommandRunner {
12341234
client,
12351235
daemon_args.bind_address,
12361236
registry,
1237-
aggregator_args.allowed_headers,
1238-
aggregator_args.allow_quilt_patch_tags_in_response,
1237+
&aggregator_args,
12391238
))
12401239
}
12411240

crates/walrus-service/src/client/daemon.rs

Lines changed: 232 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -341,14 +341,15 @@ impl<T: WalrusReadClient + Send + Sync + 'static> ClientDaemon<T> {
341341
client: T,
342342
network_address: SocketAddr,
343343
registry: &Registry,
344-
allowed_headers: Vec<String>,
345-
allow_quilt_patch_tags_in_response: bool,
344+
args: &AggregatorArgs,
346345
) -> Self {
347346
Self::new::<AggregatorApiDoc>(client, network_address, registry).with_aggregator(
348347
AggregatorResponseHeaderConfig {
349-
allowed_headers: allowed_headers.into_iter().collect(),
350-
allow_quilt_patch_tags_in_response,
348+
allowed_headers: args.allowed_headers.clone().into_iter().collect(),
349+
allow_quilt_patch_tags_in_response: args.allow_quilt_patch_tags_in_response,
351350
},
351+
args.max_request_buffer_size,
352+
args.max_concurrent_requests,
352353
)
353354
}
354355

@@ -370,33 +371,61 @@ impl<T: WalrusReadClient + Send + Sync + 'static> ClientDaemon<T> {
370371
}
371372

372373
/// Specifies that the daemon should expose the aggregator interface (read blobs).
373-
fn with_aggregator(mut self, response_header_config: AggregatorResponseHeaderConfig) -> Self {
374+
fn with_aggregator(
375+
mut self,
376+
response_header_config: AggregatorResponseHeaderConfig,
377+
max_request_buffer_size: usize,
378+
max_concurrent_requests: usize,
379+
) -> Self {
374380
self.response_header_config = Arc::new(response_header_config);
375381
tracing::info!(
376382
"Aggregator response header config: {:?}",
377383
self.response_header_config
378384
);
385+
tracing::debug!(
386+
%max_request_buffer_size,
387+
%max_concurrent_requests,
388+
"configuring the aggregator endpoint",
389+
);
390+
391+
let aggregator_layers = ServiceBuilder::new()
392+
.layer(HandleErrorLayer::new(handle_aggregator_error))
393+
// If inner service isn't ready, fail fast (no pile-ups)
394+
.layer(LoadShedLayer::new())
395+
// Small bounded queue to smooth tiny bursts
396+
.layer(BufferLayer::new(max_request_buffer_size))
397+
// Cap total in-flight requests across the aggregator
398+
.layer(ConcurrencyLimitLayer::new(max_concurrent_requests));
399+
379400
self.router = self
380401
.router
381-
.route(BLOB_GET_ENDPOINT, get(routes::get_blob))
402+
.route(
403+
BLOB_GET_ENDPOINT,
404+
get(routes::get_blob).route_layer(aggregator_layers.clone()),
405+
)
382406
.route(
383407
BLOB_OBJECT_GET_ENDPOINT,
384408
get(routes::get_blob_by_object_id)
385-
.with_state((self.client.clone(), self.response_header_config.clone())),
409+
.with_state((self.client.clone(), self.response_header_config.clone()))
410+
.route_layer(aggregator_layers.clone()),
386411
)
387412
.route(
388413
QUILT_PATCH_BY_ID_GET_ENDPOINT,
389414
get(routes::get_patch_by_quilt_patch_id)
390-
.with_state((self.client.clone(), self.response_header_config.clone())),
415+
.with_state((self.client.clone(), self.response_header_config.clone()))
416+
.route_layer(aggregator_layers.clone()),
391417
)
392418
.route(
393419
QUILT_PATCH_BY_IDENTIFIER_GET_ENDPOINT,
394420
get(routes::get_patch_by_quilt_id_and_identifier)
395-
.with_state((self.client.clone(), self.response_header_config.clone())),
421+
.with_state((self.client.clone(), self.response_header_config.clone()))
422+
.route_layer(aggregator_layers.clone()),
396423
)
397424
.route(
398425
LIST_PATCHES_IN_QUILT_ENDPOINT,
399-
get(routes::list_patches_in_quilt).with_state(self.client.clone()),
426+
get(routes::list_patches_in_quilt)
427+
.with_state(self.client.clone())
428+
.route_layer(aggregator_layers),
400429
);
401430
self
402431
}
@@ -456,15 +485,19 @@ impl<T: WalrusWriteClient + Send + Sync + 'static> ClientDaemon<T> {
456485
aggregator_args: &AggregatorArgs,
457486
) -> Self {
458487
Self::new::<DaemonApiDoc>(client, publisher_args.daemon_args.bind_address, registry)
459-
.with_aggregator(AggregatorResponseHeaderConfig {
460-
allowed_headers: aggregator_args
461-
.allowed_headers
462-
.clone()
463-
.into_iter()
464-
.collect(),
465-
allow_quilt_patch_tags_in_response: aggregator_args
466-
.allow_quilt_patch_tags_in_response,
467-
})
488+
.with_aggregator(
489+
AggregatorResponseHeaderConfig {
490+
allowed_headers: aggregator_args
491+
.allowed_headers
492+
.clone()
493+
.into_iter()
494+
.collect(),
495+
allow_quilt_patch_tags_in_response: aggregator_args
496+
.allow_quilt_patch_tags_in_response,
497+
},
498+
aggregator_args.max_request_buffer_size,
499+
aggregator_args.max_concurrent_requests,
500+
)
468501
.with_publisher(
469502
auth_config,
470503
publisher_args.max_body_size_kib,
@@ -566,18 +599,195 @@ pub(crate) async fn auth_layer(
566599
}
567600
}
568601

569-
async fn handle_publisher_error(error: BoxError) -> Response {
602+
/// Handles errors from Tower middleware layers for service endpoints.
603+
///
604+
/// Returns HTTP 429 for overload errors, and HTTP 500 with error details for other errors.
605+
async fn handle_service_error(error: BoxError, service_name: &str) -> Response {
570606
if error.is::<Overloaded>() {
571607
(
572608
StatusCode::TOO_MANY_REQUESTS,
573-
"the publisher is receiving too many requests; please try again later",
609+
format!("the {service_name} is receiving too many requests; please try again later"),
574610
)
575611
.into_response()
576612
} else {
577613
(
578614
StatusCode::INTERNAL_SERVER_ERROR,
579-
"something went wrong while storing the blob",
615+
format!("{service_name} internal server error: {error}"),
580616
)
581617
.into_response()
582618
}
583619
}
620+
621+
async fn handle_aggregator_error(error: BoxError) -> Response {
622+
handle_service_error(error, "aggregator").await
623+
}
624+
625+
async fn handle_publisher_error(error: BoxError) -> Response {
626+
handle_service_error(error, "publisher").await
627+
}
628+
629+
#[cfg(test)]
630+
mod tests {
631+
use std::{
632+
sync::atomic::{AtomicUsize, Ordering},
633+
time::Duration,
634+
};
635+
636+
use axum::http::StatusCode as HttpStatusCode;
637+
use tower::ServiceExt;
638+
use walrus_core::BlobId;
639+
640+
use super::*;
641+
642+
/// Mock client that simulates slow blob reads to test concurrency limits.
643+
#[derive(Clone)]
644+
struct MockSlowClient {
645+
/// Tracks the maximum number of concurrent requests observed.
646+
max_concurrent: Arc<AtomicUsize>,
647+
/// Tracks the current number of active requests.
648+
active_requests: Arc<AtomicUsize>,
649+
/// Artificial delay for read operations.
650+
delay: Duration,
651+
}
652+
653+
impl MockSlowClient {
654+
fn new(delay: Duration) -> Self {
655+
Self {
656+
max_concurrent: Arc::new(AtomicUsize::new(0)),
657+
active_requests: Arc::new(AtomicUsize::new(0)),
658+
delay,
659+
}
660+
}
661+
}
662+
663+
impl WalrusReadClient for MockSlowClient {
664+
async fn read_blob(&self, _blob_id: &BlobId) -> ClientResult<Vec<u8>> {
665+
// Increment active request counter and track max
666+
let current = self.active_requests.fetch_add(1, Ordering::SeqCst) + 1;
667+
self.max_concurrent.fetch_max(current, Ordering::SeqCst);
668+
669+
// Simulate slow read
670+
tokio::time::sleep(self.delay).await;
671+
672+
// Decrement active request counter
673+
self.active_requests.fetch_sub(1, Ordering::SeqCst);
674+
675+
Ok(b"mock data".to_vec())
676+
}
677+
678+
async fn get_blob_by_object_id(
679+
&self,
680+
_blob_object_id: &ObjectID,
681+
) -> ClientResult<walrus_sui::types::move_structs::BlobWithAttribute> {
682+
unimplemented!("not needed for rate limit tests")
683+
}
684+
}
685+
686+
#[tokio::test]
687+
async fn test_aggregator_rate_limiting_returns_429() {
688+
// Create a registry for metrics
689+
let registry = Registry::new(prometheus::Registry::new());
690+
691+
// Configure very low limits to easily trigger rate limiting
692+
let max_concurrent = 2;
693+
let max_buffer = 1;
694+
let num_requests = 5; // More than max_concurrent + max_buffer
695+
696+
// Create mock client with slow responses
697+
let mock_client = MockSlowClient::new(Duration::from_millis(100));
698+
let active_counter = mock_client.active_requests.clone();
699+
let max_concurrent_counter = mock_client.max_concurrent.clone();
700+
701+
// Create aggregator with low limits
702+
let args = AggregatorArgs {
703+
allowed_headers: vec![],
704+
allow_quilt_patch_tags_in_response: false,
705+
max_blob_size: None,
706+
max_request_buffer_size: max_buffer,
707+
max_concurrent_requests: max_concurrent,
708+
};
709+
710+
let daemon = ClientDaemon::new_aggregator(
711+
mock_client,
712+
"127.0.0.1:0".parse().unwrap(),
713+
&registry,
714+
&args,
715+
);
716+
717+
// Get the router (without global middleware for simpler testing)
718+
let app = daemon.router.with_state(daemon.client);
719+
720+
// Create a random blob ID for testing
721+
let blob_id = walrus_core::test_utils::random_blob_id();
722+
723+
// Launch concurrent requests
724+
let mut handles = vec![];
725+
for _ in 0..num_requests {
726+
let app = app.clone();
727+
let handle = tokio::spawn(async move {
728+
let request = axum::http::Request::builder()
729+
.uri(format!("/v1/blobs/{}", blob_id))
730+
.body(axum::body::Body::empty())
731+
.unwrap();
732+
733+
app.oneshot(request).await
734+
});
735+
handles.push(handle);
736+
}
737+
738+
// Wait for all requests to complete
739+
let results = futures::future::join_all(handles).await;
740+
741+
// Count successful and rate-limited responses
742+
let mut success_count = 0;
743+
let mut rate_limited_count = 0;
744+
745+
for result in results {
746+
let response = result.expect("request should complete");
747+
match response.unwrap().status() {
748+
HttpStatusCode::OK => success_count += 1,
749+
HttpStatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1,
750+
status => panic!("unexpected status code: {}", status),
751+
}
752+
}
753+
754+
// Verify that some requests were rate limited
755+
assert!(
756+
rate_limited_count > 0,
757+
"Expected some requests to be rate limited, but got {} successes and {} rate limited",
758+
success_count,
759+
rate_limited_count
760+
);
761+
762+
// Verify the total adds up
763+
assert_eq!(
764+
success_count + rate_limited_count,
765+
num_requests,
766+
"Total responses should equal number of requests"
767+
);
768+
769+
// The number of successful requests should not exceed max_concurrent + max_buffer
770+
assert!(
771+
success_count <= max_concurrent + max_buffer,
772+
"Success count {} should not exceed max_concurrent + max_buffer = {}",
773+
success_count,
774+
max_concurrent + max_buffer
775+
);
776+
777+
// Ensure no requests are still active
778+
assert_eq!(
779+
active_counter.load(Ordering::SeqCst),
780+
0,
781+
"All requests should have completed"
782+
);
783+
784+
// Verify the concurrency limit was enforced
785+
let observed_max_concurrent = max_concurrent_counter.load(Ordering::SeqCst);
786+
assert!(
787+
observed_max_concurrent <= max_concurrent,
788+
"Observed max concurrent {} should not exceed limit {}",
789+
observed_max_concurrent,
790+
max_concurrent
791+
);
792+
}
793+
}

0 commit comments

Comments (0)