42 changes: 42 additions & 0 deletions src/ai/llms.rs
@@ -9,6 +9,7 @@ use async_openai::{
},
};
use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
@@ -171,6 +172,24 @@ pub fn format_prompt(prompt: String, variables: HashMap<String, String>) -> String {
result
}

#[derive(Debug, Clone, Deserialize)]
struct ModelInfoRaw {
pub id: String,
pub name: String,
pub provider: String,
pub max_input_tokens: u32,
pub max_output_tokens: u32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
pub id: String,
pub name: String,
pub provider: RemoteLLMProvider,
pub max_input_tokens: u32,
pub max_output_tokens: u32,
}

#[derive(Debug)]
pub struct LLMService {
pub local_vllm_client: Option<async_openai::Client<OpenAIConfig>>,
@@ -180,6 +199,7 @@ pub struct LLMService {
pub default_remote_models: Option<HashMap<RemoteLLMProvider, String>>,
pub local_gpu_manager: Arc<LocalGPUManager>,
pub is_unified_remote: bool,
pub models_info: HashMap<String, ModelInfo>,
}

impl LLMService {
@@ -415,6 +435,27 @@ impl LLMService {
_ => Some(default_remote_models),
};

let models_info_file = include_str!("models_info.json");
let models_info_values = serde_json::from_str::<Vec<ModelInfoRaw>>(models_info_file)
.context("Unable to parse models info JSON")?;

let mut models_info: HashMap<String, ModelInfo> = HashMap::new();
for model in models_info_values {
models_info.insert(
model.id.clone(),
ModelInfo {
id: model.id.clone(),
name: model.name.clone(),
provider: model
.provider
.parse()
.context("Invalid provider in models_info.json")?,
max_input_tokens: model.max_input_tokens,
max_output_tokens: model.max_output_tokens,
},
);
}

Ok(Self {
local_vllm_client,
unified_remote_client,
@@ -423,6 +464,7 @@ impl LLMService {
model: llm_config.model,
local_gpu_manager,
is_unified_remote,
models_info,
})
}

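The constructor embeds `models_info.json` at compile time via `include_str!`, so a malformed entry or an unknown provider string only surfaces when `LLMService::new` runs. A minimal test sketch that moves that failure to `cargo test` instead (hypothetical; it assumes `RemoteLLMProvider` implements `FromStr`, as the `.parse()` call above implies, and uses the field names from this diff):

```rust
// Hypothetical regression test for the embedded model table. Lives in the
// same module as ModelInfoRaw, so the private struct is in scope.
#[cfg(test)]
mod models_info_tests {
    use super::*;

    #[test]
    fn every_entry_parses_and_names_a_known_provider() {
        let raw: Vec<ModelInfoRaw> =
            serde_json::from_str(include_str!("models_info.json"))
                .expect("models_info.json must be valid JSON");
        assert!(!raw.is_empty(), "model table should not be empty");
        for model in &raw {
            // Each provider string must map onto a RemoteLLMProvider variant,
            // mirroring the .parse() done in LLMService::new.
            assert!(
                model.provider.parse::<RemoteLLMProvider>().is_ok(),
                "unknown provider {:?} for model {:?}",
                model.provider,
                model.id
            );
        }
    }
}
```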
107 changes: 107 additions & 0 deletions src/ai/models_info.json
@@ -0,0 +1,107 @@
[
{
"name": "Haiku 4.5",
"provider": "anthropic",
"max_input_tokens": 200000,
"max_output_tokens": 64000,
"id": "claude-haiku-4-5"
},
{
"name": "Kimi K2",
"provider": "groq",
"max_input_tokens": 262144,
"max_output_tokens": 16384,
"id": "moonshotai/kimi-k2-instruct-0905"
},
{
"name": "Gemini 2.5 Flash",
"provider": "google",
"max_input_tokens": 1048576,
"max_output_tokens": 65535,
"id": "gemini-2.5-flash"
},
{
"name": "Sonnet 4.5",
"provider": "anthropic",
"max_input_tokens": 200000,
"max_output_tokens": 64000,
"id": "claude-sonnet-4-5"
},
{
"name": "GPT-5",
"provider": "openai",
"max_input_tokens": 400000,
"max_output_tokens": 128000,
"id": "gpt-5"
},
{
"name": "Qwen 3 32B",
"provider": "groq",
"max_input_tokens": 131072,
"max_output_tokens": 40960,
"id": "qwen/qwen3-32b"
},
{
"name": "GPT-5 Nano",
"provider": "openai",
"max_input_tokens": 400000,
"max_output_tokens": 128000,
"id": "gpt-5-nano"
},
{
"name": "Gemini 2.5 Flash Lite",
"provider": "google",
"max_input_tokens": 1048576,
"max_output_tokens": 65536,
"id": "gemini-2.5-flash-lite"
},
{
"name": "Llama 4 Maverick",
"provider": "groq",
"max_input_tokens": 131072,
"max_output_tokens": 8192,
"id": "meta-llama/llama-4-maverick-17b-128e-instruct"
},
{
"name": "GPT-OSS 120B",
"provider": "groq",
"max_input_tokens": 131072,
"max_output_tokens": 65536,
"id": "openai/gpt-oss-120b"
},
{
"name": "GPT-5 Pro",
"provider": "openai",
"max_input_tokens": 400000,
"max_output_tokens": 272000,
"id": "gpt-5-pro"
},
{
"name": "Llama 3.3 70B",
"provider": "groq",
"max_input_tokens": 131072,
"max_output_tokens": 32760,
"id": "llama-3.3-70b-versatile"
},
{
"name": "Gemini 2.5 Pro",
"provider": "google",
"max_input_tokens": 1048576,
"max_output_tokens": 65535,
"id": "gemini-2.5-pro"
},
{
"name": "Opus 4.1",
"provider": "anthropic",
"max_input_tokens": 200000,
"max_output_tokens": 32000,
"id": "claude-opus-4-1"
},
{
"name": "GPT-5 Mini",
"provider": "openai",
"max_input_tokens": 400000,
"max_output_tokens": 128000,
"id": "gpt-5-mini"
}
]
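Each JSON object mirrors `ModelInfoRaw` field-for-field, and after `LLMService::new` runs the table is keyed by `id`. A small usage sketch (the helper is hypothetical; `LLMService` and `ModelInfo` come from the diff above):

```rust
// Hypothetical helper: fetch a model's context-window limits from the
// table built in LLMService::new. Returns None for unknown model ids.
fn context_limits(service: &LLMService, model_id: &str) -> Option<(u32, u32)> {
    service
        .models_info
        .get(model_id)
        .map(|info| (info.max_input_tokens, info.max_output_tokens))
}

// e.g. context_limits(&service, "claude-haiku-4-5") == Some((200_000, 64_000))
```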
13 changes: 13 additions & 0 deletions src/ai/state_machines/answer.rs
@@ -1383,6 +1383,19 @@ impl AnswerStateMachine {
.await
.map_err(|e| AnswerError::RagAtError(format!("{e:?}")))?
} else {
let model_info = self
.llm_service
.models_info
.get(&_llm_config.model)
.ok_or_else(|| {
AnswerError::LLMServiceError(format!(
"Model info not found for model: {}",
_llm_config.model
))
})?;

dbg!(model_info);

let max_documents = Limit(interaction.max_documents.unwrap_or(5));
let min_similarity = Similarity(interaction.min_similarity.unwrap_or(0.5));

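For now the looked-up `model_info` is only `dbg!`-printed. One plausible follow-up (an assumption about intent, not part of this diff) is to size the retrieval context from the model's input window:

```rust
// Hypothetical sketch, not in this PR: reserve headroom for the system
// prompt and the question, then spend the rest of the input window on
// retrieved documents. The 2_048-token reserve is an illustrative guess.
let reserved_tokens: u32 = 2_048;
let document_budget = model_info.max_input_tokens.saturating_sub(reserved_tokens);
```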