@@ -351,6 +351,24 @@ impl<'de> Deserialize<'de> for RateLimiter {
351351}
352352
353353impl RateLimiter {
354+ // description
355+ fn make_bucket (
356+ total_capacity : u64 ,
357+ one_time_burst : Option < u64 > ,
358+ complete_refill_time_ms : u64 ,
359+ ) -> Option < TokenBucket > {
360+ // If either token bucket capacity or refill time is 0, disable limiting.
361+ if total_capacity != 0 && complete_refill_time_ms != 0 {
362+ Some ( TokenBucket :: new (
363+ total_capacity,
364+ one_time_burst,
365+ complete_refill_time_ms,
366+ ) )
367+ } else {
368+ None
369+ }
370+ }
371+
354372 /// Creates a new Rate Limiter that can limit on both bytes/s and ops/s.
355373 ///
356374 /// # Arguments
@@ -380,28 +398,17 @@ impl RateLimiter {
380398 ops_one_time_burst : Option < u64 > ,
381399 ops_complete_refill_time_ms : u64 ,
382400 ) -> io:: Result < Self > {
383- // If either bytes token bucket capacity or refill time is 0, disable limiting on bytes/s.
384- let bytes_token_bucket = if bytes_total_capacity != 0 && bytes_complete_refill_time_ms != 0
385- {
386- Some ( TokenBucket :: new (
387- bytes_total_capacity,
388- bytes_one_time_burst,
389- bytes_complete_refill_time_ms,
390- ) )
391- } else {
392- None
393- } ;
401+ let bytes_token_bucket = Self :: make_bucket (
402+ bytes_total_capacity,
403+ bytes_one_time_burst,
404+ bytes_complete_refill_time_ms,
405+ ) ;
394406
395- // If either ops token bucket capacity or refill time is 0, disable limiting on ops/s.
396- let ops_token_bucket = if ops_total_capacity != 0 && ops_complete_refill_time_ms != 0 {
397- Some ( TokenBucket :: new (
398- ops_total_capacity,
399- ops_one_time_burst,
400- ops_complete_refill_time_ms,
401- ) )
402- } else {
403- None
404- } ;
407+ let ops_token_bucket = Self :: make_bucket (
408+ ops_total_capacity,
409+ ops_one_time_burst,
410+ ops_complete_refill_time_ms,
411+ ) ;
405412
406413 // If limiting is disabled on all token types, don't even create a timer fd.
407414 let timer_fd = if bytes_token_bucket. is_some ( ) || ops_token_bucket. is_some ( ) {
@@ -501,6 +508,22 @@ impl RateLimiter {
501508 }
502509 }
503510
511+ /// Updates the parameters of the token buckets associated with this RateLimiter.
512+ // TODO: Pls note that, right now, the buckets buckets become full after being updated.
513+ pub fn update_buckets ( & mut self , bytes : Option < TokenBucket > , ops : Option < TokenBucket > ) {
514+ // TODO: We have to call make_bucket instead of directly assigning the bytes and/or ops
515+ // because the input buckets are likely build via deserialization, which currently does not
516+ // properly set up their internal state.
517+
518+ if let Some ( b) = bytes {
519+ self . bandwidth = Self :: make_bucket ( b. size , b. one_time_burst , b. refill_time ) ;
520+ }
521+
522+ if let Some ( b) = ops {
523+ self . ops = Self :: make_bucket ( b. size , b. one_time_burst , b. refill_time ) ;
524+ }
525+ }
526+
504527 /// Returns an immutable view of the inner bandwidth token bucket.
505528 pub fn bandwidth ( & self ) -> Option < & TokenBucket > {
506529 self . bandwidth . as_ref ( )
@@ -834,4 +857,34 @@ mod tests {
834857 }"# ;
835858 assert ! ( serde_json:: from_str:: <RateLimiter >( jstr) . is_ok( ) ) ;
836859 }
860+
861+ #[ test]
862+ fn test_update_buckets ( ) {
863+ let jstr = r#"{
864+ "bandwidth": { "size": 1000, "one_time_burst": 2000, "refill_time": 1000 },
865+ "ops": { "size": 10, "one_time_burst": 20, "refill_time": 1000 }
866+ }"# ;
867+
868+ let mut x: RateLimiter = serde_json:: from_str ( jstr) . unwrap ( ) ;
869+
870+ let initial_bw = x. bandwidth . clone ( ) ;
871+ let initial_ops = x. ops . clone ( ) ;
872+
873+ x. update_buckets ( None , None ) ;
874+ assert_eq ! ( x. bandwidth, initial_bw) ;
875+ assert_eq ! ( x. ops, initial_ops) ;
876+
877+ let new_bw = TokenBucket :: new ( 123 , None , 57 ) ;
878+ let new_ops = TokenBucket :: new ( 321 , Some ( 12346 ) , 89 ) ;
879+ x. update_buckets ( Some ( new_bw. clone ( ) ) , Some ( new_ops. clone ( ) ) ) ;
880+
881+ // We have manually adjust the last_update field, because it changes when update_buckets()
882+ // constructs new buckets (and thus gets a different value for last_update). We do this so
883+ // it makes sense to test the following assertions.
884+ x. bandwidth . as_mut ( ) . unwrap ( ) . last_update = new_bw. last_update ;
885+ x. ops . as_mut ( ) . unwrap ( ) . last_update = new_ops. last_update ;
886+
887+ assert_eq ! ( x. bandwidth, Some ( new_bw) ) ;
888+ assert_eq ! ( x. ops, Some ( new_ops) ) ;
889+ }
837890}
0 commit comments