@@ -2,11 +2,16 @@ use axum::{
22 http:: StatusCode ,
33 response:: { IntoResponse , Response } ,
44} ;
5- use std:: { num:: NonZeroU32 , time:: Duration } ;
6- use tower:: {
7- layer:: util:: { Stack , LayerFn } ,
8- Limit , RateLimitLayer ,
5+ pub use governor:: {
6+ clock:: QuantaClock ,
7+ middleware:: NoOpMiddleware ,
8+ state:: keyed:: DashMapStateStore as DashMapStore ,
9+ Quota , RateLimiter ,
910} ;
11+ use std:: { num:: NonZeroU32 , sync:: Arc , time:: Duration } ;
12+ use tower:: limit:: RateLimitLayer ;
13+ use std:: error:: Error as StdError ;
14+ use std:: fmt;
1015
1116/// Rate limiting configuration for API endpoints
1217#[ derive( Debug , Clone ) ]
@@ -18,40 +23,63 @@ pub struct RateLimitConfig {
1823impl RateLimitConfig {
1924 /// Creates a new rate limiter layer based on configuration
2025 pub fn layer ( & self ) -> RateLimitLayer {
21- let window = Duration :: from_secs ( self . per_seconds ) ;
22- RateLimitLayer :: new ( self . requests . get ( ) , window)
26+ let rate = self . requests . get ( ) as u64 ;
27+ let per = Duration :: from_secs ( self . per_seconds ) ;
28+ RateLimitLayer :: new ( rate, per)
2329 }
2430}
2531
2632/// Global rate limiting configuration
33+ #[ derive( Clone ) ]
2734pub struct GlobalRateLimit {
2835 /// General API rate limits
29- pub api : RateLimitConfig ,
36+ pub api : Arc < RateLimiter < String , DashMapStore < String > , QuantaClock , NoOpMiddleware > > ,
3037 /// Stricter limits for GPU operations
31- pub gpu_operations : RateLimitConfig ,
38+ pub gpu_operations : Arc < RateLimiter < String , DashMapStore < String > , QuantaClock , NoOpMiddleware > > ,
3239 /// Authentication-specific limits
33- pub auth : RateLimitConfig ,
40+ pub auth : Arc < RateLimiter < String , DashMapStore < String > , QuantaClock , NoOpMiddleware > > ,
3441}
3542
3643impl Default for GlobalRateLimit {
3744 fn default ( ) -> Self {
45+ let clock = QuantaClock :: default ( ) ;
3846 Self {
39- api : RateLimitConfig {
40- requests : NonZeroU32 :: new ( 100 ) . unwrap ( ) ,
41- per_seconds : 60 ,
42- } ,
43- gpu_operations : RateLimitConfig {
44- requests : NonZeroU32 :: new ( 30 ) . unwrap ( ) ,
45- per_seconds : 60 ,
46- } ,
47- auth : RateLimitConfig {
48- requests : NonZeroU32 :: new ( 10 ) . unwrap ( ) ,
49- per_seconds : 60 ,
50- } ,
47+ api : Arc :: new (
48+ RateLimiter :: dashmap_with_clock (
49+ Quota :: per_second ( NonZeroU32 :: new ( 5 ) . unwrap ( ) ) . allow_burst ( NonZeroU32 :: new ( 10 ) . unwrap ( ) ) ,
50+ clock. clone ( ) ,
51+ )
52+ ) ,
53+ gpu_operations : Arc :: new (
54+ RateLimiter :: dashmap_with_clock (
55+ Quota :: per_minute ( NonZeroU32 :: new ( 3 ) . unwrap ( ) ) . allow_burst ( NonZeroU32 :: new ( 5 ) . unwrap ( ) ) ,
56+ clock. clone ( ) ,
57+ )
58+ ) ,
59+ auth : Arc :: new (
60+ RateLimiter :: dashmap_with_clock (
61+ Quota :: per_minute ( NonZeroU32 :: new ( 10 ) . unwrap ( ) ) . allow_burst ( NonZeroU32 :: new ( 15 ) . unwrap ( ) ) ,
62+ clock,
63+ )
64+ ) ,
5165 }
5266 }
5367}
5468
69+ impl GlobalRateLimit {
70+ pub fn api_quota ( & self ) -> Quota {
71+ Quota :: per_second ( NonZeroU32 :: new ( 5 ) . unwrap ( ) ) . allow_burst ( NonZeroU32 :: new ( 10 ) . unwrap ( ) )
72+ }
73+
74+ pub fn gpu_quota ( & self ) -> Quota {
75+ Quota :: per_minute ( NonZeroU32 :: new ( 3 ) . unwrap ( ) ) . allow_burst ( NonZeroU32 :: new ( 5 ) . unwrap ( ) )
76+ }
77+
78+ pub fn auth_quota ( & self ) -> Quota {
79+ Quota :: per_minute ( NonZeroU32 :: new ( 10 ) . unwrap ( ) ) . allow_burst ( NonZeroU32 :: new ( 15 ) . unwrap ( ) )
80+ }
81+ }
82+
/// Custom rate limit exceeded response.
///
/// A zero-sized marker type; it carries no data about which limit was hit.
#[derive(Debug)]
pub struct RateLimitExceeded;
@@ -68,11 +96,72 @@ impl IntoResponse for RateLimitExceeded {
6896
6997/// Layer factory for rate limiting with custom response
7098pub fn rate_limit_layer (
71- config : RateLimitConfig ,
72- ) -> Stack < LayerFn < fn ( Limit ) -> Limit > , RateLimitLayer > {
73- let layer = config. layer ( ) ;
74- tower:: ServiceBuilder :: new ( )
75- . layer ( layer)
76- . map_err ( |_| RateLimitExceeded )
77- . into_inner ( )
78- }
99+ _limiter : Arc < RateLimiter < String , DashMapStore < String > , QuantaClock , NoOpMiddleware > > ,
100+ ) -> RateLimitLayer {
101+ // Sabit rate limit değerleri
102+ let rate = 100 ;
103+ let per = Duration :: from_secs ( 1 ) ;
104+ RateLimitLayer :: new ( rate, per)
105+ }
106+
107+ // Enhanced error handling for rate limits
108+ impl StdError for RateLimitExceeded { }
109+
110+ impl fmt:: Display for RateLimitExceeded {
111+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
112+ write ! ( f, "Rate limit exceeded" )
113+ }
114+ }
115+
116+
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::{Service, ServiceExt};

    /// Checks that the layer built by `RateLimitConfig::layer` lets the
    /// configured number of requests through and then holds back further ones.
    ///
    /// tower's rate limiter applies backpressure: once the budget for the
    /// window is used up, the service stays "not ready" until the window
    /// resets. It never produces an error response such as 429, so the test
    /// asserts on readiness rather than on a status code.
    #[tokio::test]
    async fn test_rate_limiting() {
        let config = RateLimitConfig {
            requests: NonZeroU32::new(2).unwrap(),
            per_seconds: 1,
        };

        let mut service = tower::ServiceBuilder::new()
            .layer(config.layer())
            .service(tower::service_fn(|_req: Request<Body>| async {
                Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
            }));

        // The first two requests fit within the budget for the window.
        for _ in 0..2 {
            let response = service
                .ready()
                .await
                .unwrap()
                .call(Request::new(Body::empty()))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }

        // The budget is exhausted: the service must not become ready again
        // while the one-second window is still open.
        let throttled = tokio::time::timeout(Duration::from_millis(100), service.ready())
            .await
            .is_err();
        assert!(throttled, "third request should be held back by the rate limiter");

        // Once the window has elapsed the service accepts requests again.
        let response = service
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}
0 commit comments