@@ -341,14 +341,15 @@ impl<T: WalrusReadClient + Send + Sync + 'static> ClientDaemon<T> {
341341 client : T ,
342342 network_address : SocketAddr ,
343343 registry : & Registry ,
344- allowed_headers : Vec < String > ,
345- allow_quilt_patch_tags_in_response : bool ,
344+ args : & AggregatorArgs ,
346345 ) -> Self {
347346 Self :: new :: < AggregatorApiDoc > ( client, network_address, registry) . with_aggregator (
348347 AggregatorResponseHeaderConfig {
349- allowed_headers : allowed_headers. into_iter ( ) . collect ( ) ,
350- allow_quilt_patch_tags_in_response,
348+ allowed_headers : args . allowed_headers . clone ( ) . into_iter ( ) . collect ( ) ,
349+ allow_quilt_patch_tags_in_response : args . allow_quilt_patch_tags_in_response ,
351350 } ,
351+ args. max_request_buffer_size ,
352+ args. max_concurrent_requests ,
352353 )
353354 }
354355
@@ -370,33 +371,61 @@ impl<T: WalrusReadClient + Send + Sync + 'static> ClientDaemon<T> {
370371 }
371372
372373 /// Specifies that the daemon should expose the aggregator interface (read blobs).
373- fn with_aggregator ( mut self , response_header_config : AggregatorResponseHeaderConfig ) -> Self {
374+ fn with_aggregator (
375+ mut self ,
376+ response_header_config : AggregatorResponseHeaderConfig ,
377+ max_request_buffer_size : usize ,
378+ max_concurrent_requests : usize ,
379+ ) -> Self {
374380 self . response_header_config = Arc :: new ( response_header_config) ;
375381 tracing:: info!(
376382 "Aggregator response header config: {:?}" ,
377383 self . response_header_config
378384 ) ;
385+ tracing:: debug!(
386+ %max_request_buffer_size,
387+ %max_concurrent_requests,
388+ "configuring the aggregator endpoint" ,
389+ ) ;
390+
391+ let aggregator_layers = ServiceBuilder :: new ( )
392+ . layer ( HandleErrorLayer :: new ( handle_aggregator_error) )
393+ // If inner service isn't ready, fail fast (no pile-ups)
394+ . layer ( LoadShedLayer :: new ( ) )
395+ // Small bounded queue to smooth tiny bursts
396+ . layer ( BufferLayer :: new ( max_request_buffer_size) )
397+ // Cap total in-flight requests across the aggregator
398+ . layer ( ConcurrencyLimitLayer :: new ( max_concurrent_requests) ) ;
399+
379400 self . router = self
380401 . router
381- . route ( BLOB_GET_ENDPOINT , get ( routes:: get_blob) )
402+ . route (
403+ BLOB_GET_ENDPOINT ,
404+ get ( routes:: get_blob) . route_layer ( aggregator_layers. clone ( ) ) ,
405+ )
382406 . route (
383407 BLOB_OBJECT_GET_ENDPOINT ,
384408 get ( routes:: get_blob_by_object_id)
385- . with_state ( ( self . client . clone ( ) , self . response_header_config . clone ( ) ) ) ,
409+ . with_state ( ( self . client . clone ( ) , self . response_header_config . clone ( ) ) )
410+ . route_layer ( aggregator_layers. clone ( ) ) ,
386411 )
387412 . route (
388413 QUILT_PATCH_BY_ID_GET_ENDPOINT ,
389414 get ( routes:: get_patch_by_quilt_patch_id)
390- . with_state ( ( self . client . clone ( ) , self . response_header_config . clone ( ) ) ) ,
415+ . with_state ( ( self . client . clone ( ) , self . response_header_config . clone ( ) ) )
416+ . route_layer ( aggregator_layers. clone ( ) ) ,
391417 )
392418 . route (
393419 QUILT_PATCH_BY_IDENTIFIER_GET_ENDPOINT ,
394420 get ( routes:: get_patch_by_quilt_id_and_identifier)
395- . with_state ( ( self . client . clone ( ) , self . response_header_config . clone ( ) ) ) ,
421+ . with_state ( ( self . client . clone ( ) , self . response_header_config . clone ( ) ) )
422+ . route_layer ( aggregator_layers. clone ( ) ) ,
396423 )
397424 . route (
398425 LIST_PATCHES_IN_QUILT_ENDPOINT ,
399- get ( routes:: list_patches_in_quilt) . with_state ( self . client . clone ( ) ) ,
426+ get ( routes:: list_patches_in_quilt)
427+ . with_state ( self . client . clone ( ) )
428+ . route_layer ( aggregator_layers) ,
400429 ) ;
401430 self
402431 }
@@ -456,15 +485,19 @@ impl<T: WalrusWriteClient + Send + Sync + 'static> ClientDaemon<T> {
456485 aggregator_args : & AggregatorArgs ,
457486 ) -> Self {
458487 Self :: new :: < DaemonApiDoc > ( client, publisher_args. daemon_args . bind_address , registry)
459- . with_aggregator ( AggregatorResponseHeaderConfig {
460- allowed_headers : aggregator_args
461- . allowed_headers
462- . clone ( )
463- . into_iter ( )
464- . collect ( ) ,
465- allow_quilt_patch_tags_in_response : aggregator_args
466- . allow_quilt_patch_tags_in_response ,
467- } )
488+ . with_aggregator (
489+ AggregatorResponseHeaderConfig {
490+ allowed_headers : aggregator_args
491+ . allowed_headers
492+ . clone ( )
493+ . into_iter ( )
494+ . collect ( ) ,
495+ allow_quilt_patch_tags_in_response : aggregator_args
496+ . allow_quilt_patch_tags_in_response ,
497+ } ,
498+ aggregator_args. max_request_buffer_size ,
499+ aggregator_args. max_concurrent_requests ,
500+ )
468501 . with_publisher (
469502 auth_config,
470503 publisher_args. max_body_size_kib ,
@@ -566,18 +599,195 @@ pub(crate) async fn auth_layer(
566599 }
567600}
568601
569- async fn handle_publisher_error ( error : BoxError ) -> Response {
602+ /// Handles errors from Tower middleware layers for service endpoints.
603+ ///
604+ /// Returns HTTP 429 for overload errors, and HTTP 500 with error details for other errors.
605+ async fn handle_service_error ( error : BoxError , service_name : & str ) -> Response {
570606 if error. is :: < Overloaded > ( ) {
571607 (
572608 StatusCode :: TOO_MANY_REQUESTS ,
573- "the publisher is receiving too many requests; please try again later" ,
609+ format ! ( "the {service_name} is receiving too many requests; please try again later" ) ,
574610 )
575611 . into_response ( )
576612 } else {
577613 (
578614 StatusCode :: INTERNAL_SERVER_ERROR ,
579- "something went wrong while storing the blob" ,
615+ format ! ( "{service_name} internal server error: {error}" ) ,
580616 )
581617 . into_response ( )
582618 }
583619}
620+
621+ async fn handle_aggregator_error ( error : BoxError ) -> Response {
622+ handle_service_error ( error, "aggregator" ) . await
623+ }
624+
625+ async fn handle_publisher_error ( error : BoxError ) -> Response {
626+ handle_service_error ( error, "publisher" ) . await
627+ }
628+
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        time::Duration,
    };

    use axum::http::StatusCode as HttpStatusCode;
    use tower::ServiceExt;
    use walrus_core::BlobId;

    use super::*;

    /// Mock client that simulates slow blob reads to test concurrency limits.
    #[derive(Clone)]
    struct MockSlowClient {
        /// Tracks the maximum number of concurrent requests observed.
        max_concurrent: Arc<AtomicUsize>,
        /// Tracks the current number of active requests.
        active_requests: Arc<AtomicUsize>,
        /// Artificial delay for read operations.
        delay: Duration,
    }

    impl MockSlowClient {
        /// Creates a mock whose `read_blob` sleeps for `delay` before returning.
        fn new(delay: Duration) -> Self {
            Self {
                max_concurrent: Arc::new(AtomicUsize::new(0)),
                active_requests: Arc::new(AtomicUsize::new(0)),
                delay,
            }
        }
    }

    impl WalrusReadClient for MockSlowClient {
        async fn read_blob(&self, _blob_id: &BlobId) -> ClientResult<Vec<u8>> {
            // Bump the active counter and record the high-water mark so the test
            // can verify the concurrency limit was actually enforced.
            let now_active = self.active_requests.fetch_add(1, Ordering::SeqCst) + 1;
            self.max_concurrent.fetch_max(now_active, Ordering::SeqCst);

            // Simulate a slow read so concurrent requests overlap.
            tokio::time::sleep(self.delay).await;

            self.active_requests.fetch_sub(1, Ordering::SeqCst);

            Ok(b"mock data".to_vec())
        }

        async fn get_blob_by_object_id(
            &self,
            _blob_object_id: &ObjectID,
        ) -> ClientResult<walrus_sui::types::move_structs::BlobWithAttribute> {
            unimplemented!("not needed for rate limit tests")
        }
    }

    // NOTE(review): this test depends on real wall-clock overlap of spawned
    // requests; the 100 ms delay makes flakiness unlikely but not impossible.
    #[tokio::test]
    async fn test_aggregator_rate_limiting_returns_429() {
        let registry = Registry::new(prometheus::Registry::new());

        // Deliberately tiny limits so rate limiting is easy to trigger:
        // more requests than the concurrency cap plus the buffer can absorb.
        let max_concurrent = 2;
        let max_buffer = 1;
        let num_requests = 5; // More than max_concurrent + max_buffer

        let mock_client = MockSlowClient::new(Duration::from_millis(100));
        let active_counter = mock_client.active_requests.clone();
        let max_concurrent_counter = mock_client.max_concurrent.clone();

        let args = AggregatorArgs {
            allowed_headers: vec![],
            allow_quilt_patch_tags_in_response: false,
            max_blob_size: None,
            max_request_buffer_size: max_buffer,
            max_concurrent_requests: max_concurrent,
        };

        let daemon = ClientDaemon::new_aggregator(
            mock_client,
            "127.0.0.1:0".parse().unwrap(),
            &registry,
            &args,
        );

        // Exercise the router directly (no global middleware) for simpler testing.
        let app = daemon.router.with_state(daemon.client);

        let blob_id = walrus_core::test_utils::random_blob_id();

        // Fire all requests concurrently so they contend for the limit.
        let mut handles = vec![];
        for _ in 0..num_requests {
            let app = app.clone();
            handles.push(tokio::spawn(async move {
                let request = axum::http::Request::builder()
                    .uri(format!("/v1/blobs/{}", blob_id))
                    .body(axum::body::Body::empty())
                    .unwrap();

                app.oneshot(request).await
            }));
        }

        // Tally outcomes once every request has resolved.
        let mut success_count = 0;
        let mut rate_limited_count = 0;
        for joined in futures::future::join_all(handles).await {
            let response = joined.expect("request should complete");
            match response.unwrap().status() {
                HttpStatusCode::OK => success_count += 1,
                HttpStatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1,
                status => panic!("unexpected status code: {}", status),
            }
        }

        // At least one request must have been shed.
        assert!(
            rate_limited_count > 0,
            "Expected some requests to be rate limited, but got {} successes and {} rate limited",
            success_count,
            rate_limited_count
        );

        // Every request produced exactly one of the two outcomes.
        assert_eq!(
            success_count + rate_limited_count,
            num_requests,
            "Total responses should equal number of requests"
        );

        // Successes are bounded by what the limiter plus the buffer can hold.
        assert!(
            success_count <= max_concurrent + max_buffer,
            "Success count {} should not exceed max_concurrent + max_buffer = {}",
            success_count,
            max_concurrent + max_buffer
        );

        // No request should still be in flight.
        assert_eq!(
            active_counter.load(Ordering::SeqCst),
            0,
            "All requests should have completed"
        );

        // The mock never observed more overlap than the configured cap.
        let observed_max_concurrent = max_concurrent_counter.load(Ordering::SeqCst);
        assert!(
            observed_max_concurrent <= max_concurrent,
            "Observed max concurrent {} should not exceed limit {}",
            observed_max_concurrent,
            max_concurrent
        );
    }
}
0 commit comments