@@ -14,14 +14,14 @@ use tracing::{debug, warn};
1414
1515/// Time provider trait for mocking in tests
1616#[ cfg( test) ]
17- pub ( crate ) trait TimeProvider {
17+ pub trait TimeProvider {
1818 fn now ( & self ) -> Instant ;
1919 fn advance ( & mut self , duration : Duration ) ;
2020}
2121
2222/// Real time provider for production
2323#[ cfg( test) ]
24- struct RealTimeProvider ;
24+ pub struct RealTimeProvider ;
2525
2626#[ cfg( test) ]
2727impl TimeProvider for RealTimeProvider {
@@ -36,13 +36,13 @@ impl TimeProvider for RealTimeProvider {
3636
3737/// Mock time provider for tests
3838#[ cfg( test) ]
39- struct MockTimeProvider {
39+ pub struct MockTimeProvider {
4040 current_time : Instant ,
4141}
4242
4343#[ cfg( test) ]
4444impl MockTimeProvider {
45- fn new ( ) -> Self {
45+ pub fn new ( ) -> Self {
4646 Self {
4747 current_time : Instant :: now ( ) ,
4848 }
@@ -85,15 +85,23 @@ impl Default for RateLimitConfig {
8585}
8686
8787/// Token bucket for rate limiting
88- struct TokenBucket {
88+ #[ cfg( test) ]
89+ pub struct TokenBucket {
8990 tokens : AtomicU32 ,
9091 max_tokens : u32 ,
9192 refill_rate : u32 , // tokens per refill interval
9293 last_refill : RwLock < Instant > ,
93- #[ cfg( test) ]
9494 time_provider : Arc < RwLock < Box < dyn TimeProvider + Send + Sync > > > ,
9595}
9696
97+ #[ cfg( not( test) ) ]
98+ struct TokenBucket {
99+ tokens : AtomicU32 ,
100+ max_tokens : u32 ,
101+ refill_rate : u32 , // tokens per refill interval
102+ last_refill : RwLock < Instant > ,
103+ }
104+
97105impl std:: fmt:: Debug for TokenBucket {
98106 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
99107 f. debug_struct ( "TokenBucket" )
@@ -113,19 +121,29 @@ impl std::fmt::Debug for TokenBucket {
113121}
114122
115123impl TokenBucket {
124+ #[ cfg( not( test) ) ]
116125 fn new ( max_tokens : u32 , refill_rate : u32 ) -> Self {
117126 Self {
118127 tokens : AtomicU32 :: new ( max_tokens) ,
119128 max_tokens,
120129 refill_rate,
121130 last_refill : RwLock :: new ( Instant :: now ( ) ) ,
122- #[ cfg( test) ]
131+ }
132+ }
133+
134+ #[ cfg( test) ]
135+ pub fn new ( max_tokens : u32 , refill_rate : u32 ) -> Self {
136+ Self {
137+ tokens : AtomicU32 :: new ( max_tokens) ,
138+ max_tokens,
139+ refill_rate,
140+ last_refill : RwLock :: new ( Instant :: now ( ) ) ,
123141 time_provider : Arc :: new ( RwLock :: new ( Box :: new ( RealTimeProvider ) ) ) ,
124142 }
125143 }
126144
127145 #[ cfg( test) ]
128- fn new_with_time_provider (
146+ pub fn new_with_time_provider (
129147 max_tokens : u32 ,
130148 refill_rate : u32 ,
131149 time_provider : Box < dyn TimeProvider + Send + Sync > ,
@@ -140,6 +158,7 @@ impl TokenBucket {
140158 }
141159
142160 /// Try to consume a token. Returns true if successful, false if rate limited.
161+ #[ cfg( not( test) ) ]
143162 async fn try_consume ( & self ) -> bool {
144163 self . refill ( ) . await ;
145164
@@ -166,8 +185,64 @@ impl TokenBucket {
166185 }
167186 }
168187
188+ #[ cfg( test) ]
189+ pub async fn try_consume ( & self ) -> bool {
190+ self . refill ( ) . await ;
191+
192+ // Use a loop instead of recursion to avoid boxing
193+ loop {
194+ let current_tokens = self . tokens . load ( Ordering :: Acquire ) ;
195+ if current_tokens > 0 {
196+ // Try to decrement atomically
197+ match self . tokens . compare_exchange_weak (
198+ current_tokens,
199+ current_tokens - 1 ,
200+ Ordering :: Release ,
201+ Ordering :: Relaxed ,
202+ ) {
203+ Ok ( _) => return true ,
204+ Err ( _) => {
205+ // Someone else consumed the token, try again
206+ continue ;
207+ }
208+ }
209+ } else {
210+ return false ;
211+ }
212+ }
213+ }
214+
169215 /// Refill tokens based on elapsed time
216+ #[ cfg( not( test) ) ]
170217 async fn refill ( & self ) {
218+ let now = Instant :: now ( ) ;
219+
220+ let mut last_refill = self . last_refill . write ( ) . await ;
221+
222+ let elapsed = now. duration_since ( * last_refill) ;
223+ if elapsed >= Duration :: from_secs ( 1 ) {
224+ let seconds_passed = elapsed. as_secs ( ) as u32 ;
225+ let tokens_to_add = seconds_passed * self . refill_rate ;
226+
227+ if tokens_to_add > 0 {
228+ let current_tokens = self . tokens . load ( Ordering :: Acquire ) ;
229+ let new_tokens = ( current_tokens + tokens_to_add) . min ( self . max_tokens ) ;
230+ self . tokens . store ( new_tokens, Ordering :: Release ) ;
231+ * last_refill = now;
232+
233+ // Only log if we actually added tokens and it's significant
234+ if tokens_to_add > 0 && current_tokens < self . max_tokens / 2 {
235+ debug ! (
236+ "Refilled {} tokens, current: {}/{}" ,
237+ tokens_to_add, new_tokens, self . max_tokens
238+ ) ;
239+ }
240+ }
241+ }
242+ }
243+
244+ #[ cfg( test) ]
245+ pub async fn refill ( & self ) {
171246 #[ cfg( test) ]
172247 let now = {
173248 let time_provider = self . time_provider . read ( ) . await ;
@@ -207,7 +282,7 @@ impl TokenBucket {
207282
208283 #[ cfg( test) ]
209284 /// Advance time for testing
210- async fn advance_time ( & self , duration : Duration ) {
285+ pub async fn advance_time ( & self , duration : Duration ) {
211286 let mut time_provider = self . time_provider . write ( ) . await ;
212287 time_provider. advance ( duration) ;
213288 }
@@ -360,136 +435,3 @@ impl std::fmt::Display for RateLimitStatus {
360435 )
361436 }
362437}
363-
364- #[ cfg( test) ]
365- mod tests {
366- use super :: * ;
367- use std:: net:: Ipv4Addr ;
368-
369- #[ tokio:: test]
370- async fn test_token_bucket_basic ( ) {
371- let bucket = TokenBucket :: new ( 5 , 1 ) ;
372-
373- // Should be able to consume up to max tokens
374- for _ in 0 ..5 {
375- assert ! ( bucket. try_consume( ) . await ) ;
376- }
377-
378- // Should be rate limited after consuming all tokens
379- assert ! ( !bucket. try_consume( ) . await ) ;
380- }
381-
382- #[ tokio:: test]
383- async fn test_token_bucket_refill ( ) {
384- let time_provider = Box :: new ( MockTimeProvider :: new ( ) ) ;
385- let bucket = TokenBucket :: new_with_time_provider ( 2 , 1 , time_provider) ;
386-
387- // Consume all tokens
388- assert ! ( bucket. try_consume( ) . await ) ;
389- assert ! ( bucket. try_consume( ) . await ) ;
390- assert ! ( !bucket. try_consume( ) . await ) ;
391-
392- // Advance time by 2 seconds (instant)
393- bucket. advance_time ( Duration :: from_secs ( 2 ) ) . await ;
394-
395- // Should have tokens again
396- assert ! ( bucket. try_consume( ) . await ) ;
397- }
398-
399- #[ tokio:: test]
400- async fn test_rate_limiter_global_limit ( ) {
401- let config = RateLimitConfig {
402- global_requests_per_minute : 2 ,
403- per_ip_requests_per_minute : 10 ,
404- ip_memory_duration : 3600 ,
405- refill_interval : 1 ,
406- } ;
407-
408- let limiter = RateLimiter :: new ( config) ;
409- let ip = IpAddr :: V4 ( Ipv4Addr :: new ( 127 , 0 , 0 , 1 ) ) ;
410-
411- // Should allow up to global limit
412- assert_eq ! ( limiter. check_rate_limit( ip) . await , RateLimitResult :: Allowed ) ;
413- assert_eq ! ( limiter. check_rate_limit( ip) . await , RateLimitResult :: Allowed ) ;
414-
415- // Should exceed global limit
416- assert_eq ! (
417- limiter. check_rate_limit( ip) . await ,
418- RateLimitResult :: GlobalLimitExceeded
419- ) ;
420- }
421-
422- #[ tokio:: test]
423- async fn test_rate_limiter_ip_limit ( ) {
424- let config = RateLimitConfig {
425- global_requests_per_minute : 100 ,
426- per_ip_requests_per_minute : 2 ,
427- ip_memory_duration : 3600 ,
428- refill_interval : 1 ,
429- } ;
430-
431- let limiter = RateLimiter :: new ( config) ;
432- let ip = IpAddr :: V4 ( Ipv4Addr :: new ( 127 , 0 , 0 , 1 ) ) ;
433-
434- // Should allow up to per-IP limit
435- assert_eq ! ( limiter. check_rate_limit( ip) . await , RateLimitResult :: Allowed ) ;
436- assert_eq ! ( limiter. check_rate_limit( ip) . await , RateLimitResult :: Allowed ) ;
437-
438- // Should exceed per-IP limit
439- assert_eq ! (
440- limiter. check_rate_limit( ip) . await ,
441- RateLimitResult :: IpLimitExceeded
442- ) ;
443- }
444-
445- #[ tokio:: test]
446- async fn test_rate_limiter_different_ips ( ) {
447- let config = RateLimitConfig {
448- global_requests_per_minute : 100 ,
449- per_ip_requests_per_minute : 1 ,
450- ip_memory_duration : 3600 ,
451- refill_interval : 1 ,
452- } ;
453-
454- let limiter = RateLimiter :: new ( config) ;
455- let ip1 = IpAddr :: V4 ( Ipv4Addr :: new ( 127 , 0 , 0 , 1 ) ) ;
456- let ip2 = IpAddr :: V4 ( Ipv4Addr :: new ( 127 , 0 , 0 , 2 ) ) ;
457-
458- // Each IP should have its own limit
459- assert_eq ! (
460- limiter. check_rate_limit( ip1) . await ,
461- RateLimitResult :: Allowed
462- ) ;
463- assert_eq ! (
464- limiter. check_rate_limit( ip2) . await ,
465- RateLimitResult :: Allowed
466- ) ;
467-
468- // Both should be rate limited after consuming their tokens
469- assert_eq ! (
470- limiter. check_rate_limit( ip1) . await ,
471- RateLimitResult :: IpLimitExceeded
472- ) ;
473- assert_eq ! (
474- limiter. check_rate_limit( ip2) . await ,
475- RateLimitResult :: IpLimitExceeded
476- ) ;
477- }
478-
479- #[ tokio:: test]
480- async fn test_simple_time_advancement ( ) {
481- let time_provider = Box :: new ( MockTimeProvider :: new ( ) ) ;
482- let bucket = TokenBucket :: new_with_time_provider ( 2 , 1 , time_provider) ;
483-
484- // Consume all tokens
485- assert ! ( bucket. try_consume( ) . await ) ;
486- assert ! ( bucket. try_consume( ) . await ) ;
487- assert ! ( !bucket. try_consume( ) . await ) ;
488-
489- // Advance time by 2 seconds
490- bucket. advance_time ( Duration :: from_secs ( 2 ) ) . await ;
491-
492- // Should have tokens again
493- assert ! ( bucket. try_consume( ) . await ) ;
494- }
495- }
0 commit comments