Skip to content

Commit 1619356

Browse files
committed
Router error handling errors :((
1 parent 182d3da commit 1619356

File tree

14 files changed

+368
-212
lines changed

14 files changed

+368
-212
lines changed

Cargo.toml

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,39 @@ edition = "2021"
55

66
[dependencies]
77
tokio = { version = "1.36", features = ["full"] }
8-
virt = { version = "0.4.1", features = ["snapshot"] }
8+
virt = "0.4.1"
99
serde = { version = "1.0", features = ["derive"] }
1010
serde_json = "1.0"
1111
tracing = "0.1"
1212
tracing-subscriber = "0.3"
1313
anyhow = "1.0"
1414
async-trait = "0.1"
1515
config = "0.15.6"
16-
axum = { version = "0.8", features = ["macros"] }
16+
axum = { version = "0.8.0", features = ["macros"] }
1717
hyper = { version = "1.0", features = ["full"] }
1818
tower = { version = "0.5.2", features = ["limit", "util"] }
19-
tower-http = { version = "0.6.2", features = ["trace", "limit"] }
19+
tower-http = { version = "0.6.2", features = ["trace", "limit", "add-extension"] }
2020
clap = { version = "4.4", features = ["derive"] }
2121
colored = "3.0"
2222
thiserror = "2.0.11"
2323
chrono = "0.4"
2424
uuid = { version = "1.8.0", features = ["v4"] }
25-
libvirt = "0.1.0"
26-
governor = { version = "0.8.0", features = ["std", "nohashmap"] }
25+
governor = { version = "0.8", features = ["dashmap"] }
26+
jsonwebtoken = "8.3.0"
2727

28-
# Platform-specific dependencies
2928
[target.'cfg(target_os = "linux")'.dependencies]
30-
libvirt = "0.1.0"
31-
nvml-wrapper = "0.10.0"
32-
glob = "0.3"
29+
nvml-wrapper = { version = "0.10.0", optional = true }
3330

3431
[target.'cfg(target_os = "macos")'.dependencies]
35-
core-graphics = "0.23.2"
36-
metal = { version = "0.26.0", features = ["private"] }
32+
core-graphics = { version = "0.24.0", optional = true }
33+
metal = { version = "0.27.0", features = ["private"], optional = true }
3734

3835
[target.'cfg(target_os = "windows")'.dependencies]
39-
winapi = { version = "0.3", features = ["dxgi", "d3dcommon"] }
40-
dxgi = "0.4"
41-
42-
[lib]
43-
name = "gpu_share_vm_manager"
44-
path = "src/lib.rs"
45-
46-
[[bin]]
47-
name = "gpu-share-vm-manager"
48-
path = "src/main.rs"
36+
dxgi = { version = "0.3.0-alpha4", optional = true }
37+
winapi = { version = "0.3", features = ["dxgi", "d3dcommon"], optional = true }
38+
windows = { version = "0.48", features = ["Win32_Graphics_Dxgi"] }
4939

5040
[features]
51-
metal = ["dep:metal", "dep:core-graphics"]
52-
windows = ["dep:dxgi", "dep:winapi"]
53-
linux = ["dep:nvml-wrapper", "dep:libvirt"]
41+
default = ["metal"]
42+
metal = ["dep:core-graphics", "dep:metal"]
43+
windows = ["dep:dxgi", "winapi"]

src/api/middleware/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub mod rate_limit;

src/api/middleware/rate_limit.rs

Lines changed: 118 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@ use axum::{
22
http::StatusCode,
33
response::{IntoResponse, Response},
44
};
5-
use std::{num::NonZeroU32, time::Duration};
6-
use tower::{
7-
layer::util::{Stack, LayerFn},
8-
Limit, RateLimitLayer,
5+
pub use governor::{
6+
clock::QuantaClock,
7+
middleware::NoOpMiddleware,
8+
state::keyed::DashMapStateStore as DashMapStore,
9+
Quota, RateLimiter,
910
};
11+
use std::{num::NonZeroU32, sync::Arc, time::Duration};
12+
use tower::limit::RateLimitLayer;
13+
use std::error::Error as StdError;
14+
use std::fmt;
1015

1116
/// Rate limiting configuration for API endpoints
1217
#[derive(Debug, Clone)]
@@ -18,40 +23,63 @@ pub struct RateLimitConfig {
1823
impl RateLimitConfig {
1924
/// Creates a new rate limiter layer based on configuration
2025
pub fn layer(&self) -> RateLimitLayer {
21-
let window = Duration::from_secs(self.per_seconds);
22-
RateLimitLayer::new(self.requests.get(), window)
26+
let rate = self.requests.get() as u64;
27+
let per = Duration::from_secs(self.per_seconds);
28+
RateLimitLayer::new(rate, per)
2329
}
2430
}
2531

2632
/// Global rate limiting configuration
33+
#[derive(Clone)]
2734
pub struct GlobalRateLimit {
2835
/// General API rate limits
29-
pub api: RateLimitConfig,
36+
pub api: Arc<RateLimiter<String, DashMapStore<String>, QuantaClock, NoOpMiddleware>>,
3037
/// Stricter limits for GPU operations
31-
pub gpu_operations: RateLimitConfig,
38+
pub gpu_operations: Arc<RateLimiter<String, DashMapStore<String>, QuantaClock, NoOpMiddleware>>,
3239
/// Authentication-specific limits
33-
pub auth: RateLimitConfig,
40+
pub auth: Arc<RateLimiter<String, DashMapStore<String>, QuantaClock, NoOpMiddleware>>,
3441
}
3542

3643
impl Default for GlobalRateLimit {
3744
fn default() -> Self {
45+
let clock = QuantaClock::default();
3846
Self {
39-
api: RateLimitConfig {
40-
requests: NonZeroU32::new(100).unwrap(),
41-
per_seconds: 60,
42-
},
43-
gpu_operations: RateLimitConfig {
44-
requests: NonZeroU32::new(30).unwrap(),
45-
per_seconds: 60,
46-
},
47-
auth: RateLimitConfig {
48-
requests: NonZeroU32::new(10).unwrap(),
49-
per_seconds: 60,
50-
},
47+
api: Arc::new(
48+
RateLimiter::dashmap_with_clock(
49+
Quota::per_second(NonZeroU32::new(5).unwrap()).allow_burst(NonZeroU32::new(10).unwrap()),
50+
clock.clone(),
51+
)
52+
),
53+
gpu_operations: Arc::new(
54+
RateLimiter::dashmap_with_clock(
55+
Quota::per_minute(NonZeroU32::new(3).unwrap()).allow_burst(NonZeroU32::new(5).unwrap()),
56+
clock.clone(),
57+
)
58+
),
59+
auth: Arc::new(
60+
RateLimiter::dashmap_with_clock(
61+
Quota::per_minute(NonZeroU32::new(10).unwrap()).allow_burst(NonZeroU32::new(15).unwrap()),
62+
clock,
63+
)
64+
),
5165
}
5266
}
5367
}
5468

69+
impl GlobalRateLimit {
70+
pub fn api_quota(&self) -> Quota {
71+
Quota::per_second(NonZeroU32::new(5).unwrap()).allow_burst(NonZeroU32::new(10).unwrap())
72+
}
73+
74+
pub fn gpu_quota(&self) -> Quota {
75+
Quota::per_minute(NonZeroU32::new(3).unwrap()).allow_burst(NonZeroU32::new(5).unwrap())
76+
}
77+
78+
pub fn auth_quota(&self) -> Quota {
79+
Quota::per_minute(NonZeroU32::new(10).unwrap()).allow_burst(NonZeroU32::new(15).unwrap())
80+
}
81+
}
82+
5583
/// Custom rate limit exceeded response
5684
#[derive(Debug)]
5785
pub struct RateLimitExceeded;
@@ -68,11 +96,72 @@ impl IntoResponse for RateLimitExceeded {
6896

6997
/// Layer factory for rate limiting with custom response
7098
pub fn rate_limit_layer(
71-
config: RateLimitConfig,
72-
) -> Stack<LayerFn<fn(Limit) -> Limit>, RateLimitLayer> {
73-
let layer = config.layer();
74-
tower::ServiceBuilder::new()
75-
.layer(layer)
76-
.map_err(|_| RateLimitExceeded)
77-
.into_inner()
78-
}
99+
_limiter: Arc<RateLimiter<String, DashMapStore<String>, QuantaClock, NoOpMiddleware>>,
100+
) -> RateLimitLayer {
101+
// Sabit rate limit değerleri
102+
let rate = 100;
103+
let per = Duration::from_secs(1);
104+
RateLimitLayer::new(rate, per)
105+
}
106+
107+
// Enhanced error handling for rate limits
108+
impl StdError for RateLimitExceeded {}
109+
110+
impl fmt::Display for RateLimitExceeded {
111+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112+
write!(f, "Rate limit exceeded")
113+
}
114+
}
115+
116+
117+
#[cfg(test)]
118+
mod tests {
119+
use super::*;
120+
use axum::body::Body;
121+
use axum::http::Request;
122+
use tower::{Service, ServiceExt};
123+
124+
#[tokio::test]
125+
async fn test_rate_limiting() {
126+
let config = RateLimitConfig {
127+
requests: NonZeroU32::new(2).unwrap(),
128+
per_seconds: 1,
129+
};
130+
131+
let mut service = tower::ServiceBuilder::new()
132+
.layer(config.layer())
133+
.service(tower::service_fn(|_| async {
134+
Ok::<_, std::convert::Infallible>(Response::new(Body::empty()))
135+
}));
136+
137+
138+
let response = service
139+
.ready()
140+
.await
141+
.unwrap()
142+
.call(Request::new(Body::empty()))
143+
.await
144+
.unwrap();
145+
assert_eq!(response.status(), StatusCode::OK);
146+
147+
148+
let response = service
149+
.ready()
150+
.await
151+
.unwrap()
152+
.call(Request::new(Body::empty()))
153+
.await
154+
.unwrap();
155+
assert_eq!(response.status(), StatusCode::OK);
156+
157+
158+
let response = service
159+
.ready()
160+
.await
161+
.unwrap()
162+
.call(Request::new(Body::empty()))
163+
.await
164+
.unwrap();
165+
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
166+
}
167+
}

src/api/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod middleware;
12
pub mod routes;
23

34
pub use routes::{create_router, AppState};

0 commit comments

Comments
 (0)