5 changes: 5 additions & 0 deletions backends/vllm/src/llm_service.rs
@@ -396,6 +396,11 @@ impl LlmService {
Ok(self.validation_service.validate(request).await?)
}

/// Returns the model name
pub fn model(&self) -> String {
self.model_config.model_name.clone()
}

/// Stops the running instance
#[instrument(skip(self))]
pub async fn stop(self) -> Result<(), LlmServiceError> {
2 changes: 2 additions & 0 deletions server/src/api.rs
@@ -1,5 +1,7 @@
pub mod chat_completions;
pub mod model_info;
pub mod validate_schema;

pub use chat_completions::*;
pub use model_info::*;
pub use validate_schema::*;
10 changes: 10 additions & 0 deletions server/src/api/model_info.rs
@@ -0,0 +1,10 @@
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;

#[derive(Debug, Deserialize, Serialize, ToSchema)]
pub struct ModelInfo {
pub name: String,
pub object: String,
pub created: u64,
pub owned_by: String,
}
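
As a rough illustration of the wire format (not part of the diff), the `ModelInfo` struct above serializes to a flat JSON object via serde. The sketch below is a minimal, self-contained copy (the `utoipa` derive is dropped so it compiles on its own), and the field values are made up:

```rust
use serde::{Deserialize, Serialize};

// Copy of the struct from server/src/api/model_info.rs, minus ToSchema,
// so this snippet compiles standalone.
#[derive(Debug, Deserialize, Serialize)]
pub struct ModelInfo {
    pub name: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

fn main() -> serde_json::Result<()> {
    // Example values only; the real `name` comes from the model config.
    let info = ModelInfo {
        name: "example-model".to_string(),
        object: "model".to_string(),
        created: 1_700_000_000,
        owned_by: "system".to_string(),
    };
    // Prints: {"name":"example-model","object":"model","created":1700000000,"owned_by":"system"}
    println!("{}", serde_json::to_string(&info)?);
    Ok(())
}
```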
2 changes: 2 additions & 0 deletions server/src/main.rs
@@ -50,6 +50,7 @@ async fn main() -> anyhow::Result<()> {
.await
.map_err(|e| anyhow::anyhow!("Failed to start `LlmService`, with error: {e}"))?;

let model_name = llm_service.model();
let join_handle = tokio::spawn(async move {
llm_service
.run()
@@ -60,6 +61,7 @@
let app_state = AppState {
request_counter: Arc::new(AtomicU64::new(0)),
llm_service_sender,
model_name,
shutdown_signal_sender,
streaming_interval_in_millis: env::var("STREAMING_INTERVAL_IN_MILLIS")
.ok()
33 changes: 30 additions & 3 deletions server/src/server.rs
@@ -12,7 +12,7 @@ use axum::{
extract::State,
http::{header, HeaderMap, StatusCode},
response::{sse::KeepAlive, IntoResponse, Sse},
routing::post,
routing::{get, post},
Json, Router,
};
use serde_json::{json, Value};
@@ -33,12 +33,15 @@ use crate::{
api::{
chat_completions::{ChatCompletionResponse, RequestBody},
validate_schema::validate_with_schema,
ModelInfo,
},
stream::Streamer,
};

/// The URL path to POST JSON for model chat completions.
pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
/// The URL path to GET JSON for model information.
pub const MODEL_INFO_PATH: &str = "/v1/models";
pub const AUTH_BEARER_PREFIX: &str = "Bearer ";

/// Represents the shared state of the application.
@@ -56,6 +59,8 @@ pub struct AppState {
/// This channel is used to send generate requests to the LLM service and receive
/// the output through a oneshot channel.
pub llm_service_sender: UnboundedSender<ServiceRequest>,
/// The name of the model.
pub model_name: String,
/// A sender for the shutdown signal.
///
/// This channel is used to send a shutdown signal to gracefully stop the server.
@@ -79,9 +84,10 @@ pub struct AppState {
#[openapi(
paths(
completion_handler,
validate_completion_handler
validate_completion_handler,
model_handler
),
components(schemas(ChatCompletionResponse, RequestBody)),
components(schemas(ChatCompletionResponse, RequestBody, ModelInfo)),
tags(
(name = "Atoma's Chat Completions", description = "Atoma's Chat completion API")
)
@@ -129,6 +135,7 @@ pub async fn run_server(
&format!("{CHAT_COMPLETIONS_PATH}/validate"),
post(validate_completion_handler),
)
.route(MODEL_INFO_PATH, get(model_handler))
.with_state(app_state)
.merge(SwaggerUi::new("/docs").url("/api-docs/openapi.json", ApiDoc::openapi()));

@@ -304,6 +311,26 @@ pub async fn validate_completion_handler(
Json(json!({"status": "success"}))
}

#[utoipa::path(
get,
path = MODEL_INFO_PATH,
responses(
(status = 200, description = "Model information", body = ModelInfo)
)
)]
pub async fn model_handler(app_state: State<AppState>) -> impl IntoResponse {
let model_info = ModelInfo {
created: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("Time went backwards")
.as_secs(),
name: app_state.model_name.clone(),
object: "model".to_string(),
owned_by: "system".to_string(),
};
Json(model_info)
}

/// Handles a generate request by sending it to the LLM service and processing the response.
///
/// This function is responsible for:
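
For completeness, here is a hypothetical client-side sketch of the new endpoint. Neither `reqwest` nor the localhost address/port is part of this PR; they are assumptions for illustration (reqwest with the `json` feature and a tokio runtime). The response body mirrors the `ModelInfo` struct registered with the route above:

```rust
use serde::Deserialize;

// Mirrors server/src/api/model_info.rs so the GET /v1/models body
// can be deserialized directly.
#[derive(Debug, Deserialize)]
struct ModelInfo {
    name: String,
    object: String,
    created: u64,
    owned_by: String,
}

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // The port is a guess; use whatever address run_server binds to.
    let info: ModelInfo = reqwest::get("http://localhost:8080/v1/models")
        .await?
        .json()
        .await?;
    println!("serving model: {} (owned_by: {})", info.name, info.owned_by);
    Ok(())
}
```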