Skip to content

Commit f18c1ac

Browse files
committed
#issue10: Rate Limiting
Tiered rate limiting strategy: General API limits (100 requests/minute) Tighter limits for GPU operations (30 requests/minute) Special limits for Auth endpoints (10 requests/minute) Dynamic configuration: All limits can be set via configuration files Different limits for different endpoint groups
1 parent 0ce9fec commit f18c1ac

File tree

6 files changed

+167
-4
lines changed

6 files changed

+167
-4
lines changed

Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@ async-trait = "0.1"
1515
config = "0.15.6"
1616
axum = { version = "0.8", features = ["macros"] }
1717
hyper = { version = "1.0", features = ["full"] }
18-
tower = "0.5.2"
19-
tower-http = { version = "0.6.2", features = ["trace"] }
18+
tower = { version = "0.5.2", features = ["limit", "util"] }
19+
tower-http = { version = "0.6.2", features = ["trace", "limit"] }
2020
clap = { version = "4.4", features = ["derive"] }
2121
colored = "3.0"
2222
thiserror = "2.0.11"
2323
chrono = "0.4"
2424
uuid = { version = "1.8.0", features = ["v4"] }
2525
libvirt = "0.1.0"
26+
governor = { version = "0.8.0", features = ["std", "nohashmap"] }
2627

2728
[lib]
2829
name = "gpu_share_vm_manager"

src/api/middleware/rate_limit.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use axum::{
2+
http::StatusCode,
3+
response::{IntoResponse, Response},
4+
};
5+
use std::{num::NonZeroU32, time::Duration};
6+
use tower::{
7+
layer::util::{Stack, LayerFn},
8+
Limit, RateLimitLayer,
9+
};
10+
11+
/// Rate limiting configuration for API endpoints
12+
#[derive(Debug, Clone)]
13+
pub struct RateLimitConfig {
14+
pub requests: NonZeroU32,
15+
pub per_seconds: u64,
16+
}
17+
18+
impl RateLimitConfig {
19+
/// Creates a new rate limiter layer based on configuration
20+
pub fn layer(&self) -> RateLimitLayer {
21+
let window = Duration::from_secs(self.per_seconds);
22+
RateLimitLayer::new(self.requests.get(), window)
23+
}
24+
}
25+
26+
/// Global rate limiting configuration
27+
pub struct GlobalRateLimit {
28+
/// General API rate limits
29+
pub api: RateLimitConfig,
30+
/// Stricter limits for GPU operations
31+
pub gpu_operations: RateLimitConfig,
32+
/// Authentication-specific limits
33+
pub auth: RateLimitConfig,
34+
}
35+
36+
impl Default for GlobalRateLimit {
37+
fn default() -> Self {
38+
Self {
39+
api: RateLimitConfig {
40+
requests: NonZeroU32::new(100).unwrap(),
41+
per_seconds: 60,
42+
},
43+
gpu_operations: RateLimitConfig {
44+
requests: NonZeroU32::new(30).unwrap(),
45+
per_seconds: 60,
46+
},
47+
auth: RateLimitConfig {
48+
requests: NonZeroU32::new(10).unwrap(),
49+
per_seconds: 60,
50+
},
51+
}
52+
}
53+
}
54+
55+
/// Custom rate limit exceeded response
56+
#[derive(Debug)]
57+
pub struct RateLimitExceeded;
58+
59+
impl IntoResponse for RateLimitExceeded {
60+
fn into_response(self) -> Response {
61+
(
62+
StatusCode::TOO_MANY_REQUESTS,
63+
"Rate limit exceeded. Please try again later.",
64+
)
65+
.into_response()
66+
}
67+
}
68+
69+
/// Layer factory for rate limiting with custom response
70+
pub fn rate_limit_layer(
71+
config: RateLimitConfig,
72+
) -> Stack<LayerFn<fn(Limit) -> Limit>, RateLimitLayer> {
73+
let layer = config.layer();
74+
tower::ServiceBuilder::new()
75+
.layer(layer)
76+
.map_err(|_| RateLimitExceeded)
77+
.into_inner()
78+
}

src/api/routes.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ use axum::{
6363
extract::{Path, State},
6464
Json,
6565
http::StatusCode,
66+
response::{IntoResponse},
6667
};
6768
use serde::{Deserialize, Serialize};
6869
use std::sync::Arc;
@@ -74,6 +75,7 @@ use crate::core::libvirt::LibvirtManager;
7475
use crate::core::vm::{VMStatus, VMConfig};
7576
use crate::gpu::device::{GPUManager, GPUDevice, GPUConfig};
7677
use crate::monitoring::metrics::{MetricsCollector, ResourceMetrics};
78+
use crate::api::middleware::rate_limit::{rate_limit_layer, GlobalRateLimit, RateLimitExceeded};
7779

7880
fn handle_error(err: impl std::fmt::Display) -> StatusCode {
7981
error!("Operation failed: {}", err);
@@ -112,17 +114,39 @@ pub struct AttachGPURequest {
112114
}
113115

114116
pub fn create_router(state: Arc<AppState>) -> Router {
117+
let rate_limits = GlobalRateLimit::default();
118+
115119
Router::new()
120+
// Public endpoints with stricter limits
121+
.route("/api/v1/auth/login", post(login))
122+
.layer(rate_limit_layer(rate_limits.auth.clone()))
123+
124+
// GPU operations with specific limits
125+
.route("/api/v1/gpus", get(list_gpus))
126+
.route("/api/v1/vms/:id/attach_gpu", post(attach_gpu))
127+
.layer(rate_limit_layer(rate_limits.gpu_operations.clone()))
128+
129+
// General API endpoints
116130
.route("/api/v1/vms", post(create_vm))
117131
.route("/api/v1/vms", get(list_vms))
118132
.route("/api/v1/vms/:id", get(get_vm))
119133
.route("/api/v1/vms/:id", delete(delete_vm))
120134
.route("/api/v1/vms/:id/start", post(start_vm))
121135
.route("/api/v1/vms/:id/stop", post(stop_vm))
122-
.route("/api/v1/gpus", get(list_gpus))
123-
.route("/api/v1/vms/:id/attach_gpu", post(attach_gpu))
124136
.route("/api/v1/metrics/:id", get(get_metrics))
137+
.layer(rate_limit_layer(rate_limits.api.clone()))
138+
139+
// Shared state and fallback
125140
.with_state(state)
141+
.fallback(fallback_handler)
142+
.layer(HandleErrorLayer::new(handle_error))
143+
}
144+
145+
async fn handle_error(error: Box<dyn std::error::Error + Send + Sync>) -> impl IntoResponse {
146+
if error.is::<RateLimitExceeded>() {
147+
return RateLimitExceeded.into_response();
148+
}
149+
// ... existing error handling ...
126150
}
127151

128152
#[axum::debug_handler]

src/config.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
44
pub struct Config {
55
pub server: ServerConfig,
66
pub metrics: MetricsConfig,
7+
pub rate_limits: RateLimitConfig,
78
}
89

910
#[derive(Debug, Serialize, Deserialize)]
@@ -16,4 +17,11 @@ pub struct ServerConfig {
1617
pub struct MetricsConfig {
1718
pub collection_interval_secs: u64,
1819
pub retention_hours: u64,
20+
}
21+
22+
#[derive(Debug, Serialize, Deserialize)]
23+
pub struct RateLimitConfig {
24+
pub api_requests_per_minute: u32,
25+
pub gpu_requests_per_minute: u32,
26+
pub auth_requests_per_minute: u32,
1927
}

src/config/settings.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@
4242
* - vm_image_path: Where VM images go to hibernate
4343
* - max_storage_gb: Because someone will try to store their entire Steam library
4444
*
45+
* 5. RateLimitSettings:
46+
* - api_requests_per_minute: Rate limit for general API requests
47+
* - gpu_requests_per_minute: Rate limit for GPU-related requests
48+
* - auth_requests_per_minute: Rate limit for authentication-related requests
49+
*
4550
* Implementation Details:
4651
* --------------------
4752
* - Using serde for serialization (because writing parsers is so 1990s)
@@ -84,6 +89,7 @@ pub struct Settings {
8489
pub libvirt: LibvirtSettings,
8590
pub monitoring: MonitoringSettings,
8691
pub storage: StorageSettings,
92+
pub rate_limits: RateLimitSettings,
8793
}
8894

8995
#[derive(Debug, Serialize, Deserialize)]
@@ -114,6 +120,23 @@ pub struct StorageSettings {
114120
pub max_storage_gb: u64,
115121
}
116122

123+
#[derive(Debug, Serialize, Deserialize)]
124+
pub struct RateLimitSettings {
125+
pub api_requests_per_minute: u32,
126+
pub gpu_requests_per_minute: u32,
127+
pub auth_requests_per_minute: u32,
128+
}
129+
130+
impl Default for RateLimitSettings {
131+
fn default() -> Self {
132+
Self {
133+
api_requests_per_minute: 100,
134+
gpu_requests_per_minute: 30,
135+
auth_requests_per_minute: 10,
136+
}
137+
}
138+
}
139+
117140
impl Settings {
118141
pub fn new() -> Result<Self, ConfigError> {
119142
let config_path = std::env::var("CONFIG_PATH")
@@ -161,5 +184,10 @@ pub fn generate_default_config() -> Settings {
161184
vm_image_path: PathBuf::from("/var/lib/gpu-share/images"),
162185
max_storage_gb: 100,
163186
},
187+
rate_limits: RateLimitSettings {
188+
api_requests_per_minute: 100,
189+
gpu_requests_per_minute: 30,
190+
auth_requests_per_minute: 10,
191+
},
164192
}
165193
}

src/core/errors/handlers.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#[derive(Clone)]
2+
pub struct CircuitBreaker {
3+
state: Arc<Mutex<CircuitState>>,
4+
failure_threshold: u32,
5+
reset_timeout: Duration,
6+
}
7+
8+
impl CircuitBreaker {
9+
pub fn new(failure_threshold: u32, reset_timeout: Duration) -> Self {
10+
Self {
11+
state: Arc::new(Mutex::new(CircuitState::Closed)),
12+
failure_threshold,
13+
reset_timeout,
14+
}
15+
}
16+
17+
pub async fn execute<F, T, E>(&self, mut operation: F) -> Result<T, GpuShareError>
18+
where
19+
F: FnMut() -> Result<T, E>,
20+
E: Into<GpuShareError>,
21+
{
22+
// TODO: Circuit breaker implementation -@virjilakrum
23+
}
24+
}

0 commit comments

Comments
 (0)