Skip to content

Commit 34f334a

Browse files
committed
add shutdown mehcanism
sender - reciever
1 parent 31be006 commit 34f334a

File tree

2 files changed

+95
-35
lines changed

2 files changed

+95
-35
lines changed

src/api/routes.rs

Lines changed: 89 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ use axum::{
6464
Json,
6565
http::StatusCode,
6666
response::IntoResponse,
67+
response::{Json, Response},
68+
http::Request,
69+
http::StatusCode,
6770
};
6871
use serde::{Deserialize, Serialize};
6972
use std::sync::Arc;
@@ -82,10 +85,13 @@ fn handle_error(err: impl std::fmt::Display) -> StatusCode {
8285
StatusCode::INTERNAL_SERVER_ERROR
8386
}
8487

88+
#[derive(Clone)]
8589
pub struct AppState {
8690
pub libvirt: Arc<Mutex<LibvirtManager>>,
8791
pub gpu_manager: Arc<Mutex<GPUManager>>,
8892
pub metrics: Arc<Mutex<MetricsCollector>>,
93+
pub shutdown_signal: Arc<Mutex<tokio::sync::oneshot::Sender<()>>>,
94+
pub shutdown_receiver: Arc<Mutex<tokio::sync::oneshot::Receiver<()>>>,
8995
}
9096

9197
#[derive(Debug, Deserialize)]
@@ -95,6 +101,8 @@ pub struct CreateVMRequest {
95101
pub memory_mb: u64,
96102
pub gpu_required: bool,
97103
pub disk_size_gb: Option<u64>,
104+
// pub username: String,
105+
// pub password: String,
98106
}
99107

100108
#[derive(Debug, Serialize)]
@@ -146,14 +154,42 @@ async fn handle_error(error: Box<dyn std::error::Error + Send + Sync>) -> impl I
146154
if error.is::<RateLimitExceeded>() {
147155
return RateLimitExceeded.into_response();
148156
}
149-
// ... existing error handling ...
157+
158+
if let Some(libvirt_error) = error.downcast_ref::<libvirt::Error>() {
159+
match libvirt_error.code() {
160+
libvirt::ErrorNumber::NO_DOMAIN => {
161+
return (StatusCode::NOT_FOUND, "VM not found").into_response()
162+
}
163+
libvirt::ErrorNumber::OPERATION_INVALID => {
164+
return (StatusCode::BAD_REQUEST, "Invalid operation").into_response()
165+
}
166+
_ => {}
167+
}
168+
}
169+
170+
if let Some(gpu_error) = error.downcast_ref::<gpu::GPUError>() {
171+
match gpu_error {
172+
gpu::GPUError::NotFound => {
173+
return (StatusCode::NOT_FOUND, "GPU not found").into_response()
174+
}
175+
gpu::GPUError::AlreadyAttached => {
176+
return (StatusCode::CONFLICT, "GPU already attached").into_response()
177+
}
178+
_ => {}
179+
}
180+
}
181+
(
182+
StatusCode::INTERNAL_SERVER_ERROR,
183+
format!("Internal server error: {}", error),
184+
)
185+
.into_response()
150186
}
151187

152188
#[axum::debug_handler]
153189
async fn create_vm(
154190
State(state): State<Arc<AppState>>,
155-
Json(params): Json<CreateVMRequest>,
156-
) -> Result<Json<VMResponse>, StatusCode> {
191+
Json(params): Json<CreateVMRequest>
192+
) -> Result<impl IntoResponse, StatusCode> {
157193
let libvirt = state.libvirt.lock().await;
158194

159195
let config = VMConfig {
@@ -188,8 +224,8 @@ async fn create_vm(
188224

189225
#[axum::debug_handler]
190226
async fn list_vms(
191-
State(state): State<Arc<AppState>>,
192-
) -> Result<Json<Vec<VMResponse>>, StatusCode> {
227+
State(state): State<Arc<AppState>>
228+
) -> Result<impl IntoResponse, StatusCode> {
193229
let libvirt = state.libvirt.lock().await;
194230

195231
let domains = libvirt.list_domains()
@@ -220,8 +256,8 @@ async fn list_vms(
220256
#[axum::debug_handler]
221257
async fn get_vm(
222258
State(state): State<Arc<AppState>>,
223-
Path(id): Path<String>,
224-
) -> Result<Json<VMResponse>, StatusCode> {
259+
Path(id): Path<String>
260+
) -> Result<impl IntoResponse, StatusCode> {
225261
let libvirt = state.libvirt.lock().await;
226262

227263
let domain = libvirt.lookup_domain(&id)
@@ -258,53 +294,64 @@ async fn start_vm(
258294
}
259295

260296
#[axum::debug_handler]
261-
async fn stop_vm(
297+
async fn start_vm(
262298
State(state): State<Arc<AppState>>,
263-
Path(id): Path<String>,
264-
) -> Result<StatusCode, StatusCode> {
299+
Path(id): Path<String>
300+
) -> Result<impl IntoResponse, StatusCode> {
265301
let libvirt = state.libvirt.lock().await;
266-
267-
libvirt.stop_domain(&id)
302+
libvirt.start_vm(&id)
268303
.await
269-
.map_err(handle_error)?;
270-
304+
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
271305
Ok(StatusCode::OK)
272306
}
273307

274308
#[axum::debug_handler]
275-
async fn delete_vm(
309+
async fn stop_vm(
276310
State(state): State<Arc<AppState>>,
277-
Path(id): Path<String>,
278-
) -> Result<StatusCode, StatusCode> {
311+
Path(id): Path<String>
312+
) -> Result<impl IntoResponse, StatusCode> {
279313
let libvirt = state.libvirt.lock().await;
280-
281-
libvirt.delete_domain(&id)
314+
libvirt.stop_vm(&id)
282315
.await
283-
.map_err(handle_error)?;
284-
316+
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
285317
Ok(StatusCode::OK)
286318
}
287319

288320
#[axum::debug_handler]
289-
async fn list_gpus(
321+
async fn login(
290322
State(state): State<Arc<AppState>>,
291-
) -> Result<Json<Vec<GPUDevice>>, StatusCode> {
292-
let mut gpu_manager = state.gpu_manager.lock().await;
293-
294-
let gpus = gpu_manager.discover_gpus()
323+
Json(credentials): Json<LoginRequest>,
324+
) -> Result<Json<LoginResponse>, StatusCode> {
325+
let mut libvirt = state.libvirt.lock().await;
326+
327+
let domain = libvirt.lookup_domain(&credentials.username)
295328
.map_err(handle_error)?;
296329

297-
Ok(Json(gpus))
330+
let info = domain.get_info()
331+
.map_err(handle_error)?;
298332
}
299333

334+
// async fn login(
335+
// #[axum::debug_handler]
336+
// async fn list_gpus(
337+
// State(state): State<Arc<AppState>>
338+
// ) -> Result<impl IntoResponse, StatusCode> {
339+
// let gpu_manager = state.gpu_manager.lock().await;
340+
// let gpus = gpu_manager.list_gpus()
341+
// .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
342+
// Ok(Json(gpus))
343+
// }
344+
300345
#[axum::debug_handler]
301346
async fn attach_gpu(
302347
State(state): State<Arc<AppState>>,
303348
Path(id): Path<String>,
304-
Json(request): Json<AttachGPURequest>,
305-
) -> Result<StatusCode, StatusCode> {
306-
let libvirt = state.libvirt.lock().await;
349+
Json(request): Json<GPUConfig>
350+
) -> Result<impl IntoResponse, StatusCode> {
307351
let mut gpu_manager = state.gpu_manager.lock().await;
352+
353+
let gpu_id = request.gpu_id.clone();
354+
let gpu_manager = state.gpu_manager.lock().await;
308355

309356
let domain = libvirt.lookup_domain(&id)
310357
.map_err(handle_error)?;
@@ -323,15 +370,22 @@ async fn attach_gpu(
323370
Ok(StatusCode::OK)
324371
}
325372

373+
#[axum::debug_handler]
374+
async fn fallback_handler(
375+
State(state): State<Arc<AppState>>,
376+
req: Request,
377+
) -> Result<Response, StatusCode> {
378+
error!("Fallback handler called for request: {:?}", req);
379+
Err(StatusCode::NOT_FOUND)
380+
}
381+
326382
#[axum::debug_handler]
327383
async fn get_metrics(
328384
State(state): State<Arc<AppState>>,
329-
Path(id): Path<String>,
330-
) -> Result<Json<Vec<ResourceMetrics>>, StatusCode> {
385+
Path(id): Path<String>
386+
) -> Result<impl IntoResponse, StatusCode> {
331387
let metrics = state.metrics.lock().await;
332-
333388
let vm_metrics = metrics.get_vm_metrics(&id)
334-
.map_err(handle_error)?;
335-
389+
.map_err(|_| StatusCode::NOT_FOUND)?;
336390
Ok(Json(vm_metrics))
337391
}

src/main.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::sync::Arc;
22
use tokio::sync::Mutex;
33
use tracing::info;
44
use tokio::net::TcpListener;
5+
use tokio::sync::oneshot;
56

67
mod core;
78
mod gpu;
@@ -24,11 +25,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2425
24, // 24 hour retention
2526
)));
2627

28+
// Shutdown mechanism for graceful shutdown
29+
let (shutdown_sender, shutdown_receiver) = oneshot::channel();
30+
2731
// Initialize application state
2832
let state = Arc::new(api::AppState {
2933
libvirt,
3034
gpu_manager,
3135
metrics,
36+
shutdown_signal: Arc::new(Mutex::new(shutdown_sender)),
37+
shutdown_receiver: Arc::new(Mutex::new(shutdown_receiver)),
3238
});
3339

3440
// Create API router

0 commit comments

Comments
 (0)