Skip to content

Commit 905dda5

Browse files
larock22 and tunahorse authored
(feat) Add multiple rounds support for evaluation runs (#19)
* feat: Add multiple rounds support for evaluation runs
  - Add rounds query parameter to execute_eval_run endpoint
  - Generate unique run_id (UUID) for each round
  - Use existing run_id column to group evaluation results
  - Update frontend to support rounds input field
  - Add comprehensive integration tests for the feature
  - Default to 1 round for backward compatibility
* fix: Maintain backward compatibility for eval runs API
  - Return single PromptEvalExecutionRunResponse when rounds=1 (default)
  - Return array of responses only when rounds > 1
  - Update frontend to handle both response types
  - Add test for backward compatibility
* docs: Add comprehensive feature summary for eval rounds
* chore: Remove documentation files
* fix: Address PR review feedback for eval rounds
  - Remove unwraps and handle errors properly with map_err
  - Reverse loop order: iterate over evals first, then rounds
  - Add tool calling support with proper error handling
  - Maintain backward compatibility for single round requests

---------

Co-authored-by: tunahorse <[email protected]>
1 parent 32580d0 commit 905dda5

File tree

6 files changed

+861
-33
lines changed

6 files changed

+861
-33
lines changed

backend/src/controllers/prompt_eval_run.rs

Lines changed: 81 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
use axum::{
2-
extract::{Path, State},
2+
extract::{Path, State, Query},
33
Json,
44
};
5+
use serde::Deserialize;
56
use uuid::Uuid;
67

8+
#[derive(Deserialize)]
9+
pub struct EvalRunParams {
10+
pub rounds: Option<i64>,
11+
}
12+
713
use super::types::{
814
request::prompt_eval_run::UpdateEvalRunRequest,
915
response::prompt_eval_run::{
@@ -12,19 +18,34 @@ use super::types::{
1218
};
1319

1420
use crate::{
15-
common::types::chat_request::ChatCompletionRequest,
21+
common::types::chat_request::{
22+
ChatCompletionRequest, ChatCompletionRequestTool, ChatCompletionRequestFunctionDescription
23+
},
1624
services::{llm::Llm, types::llm_service::LlmServiceRequest},
1725
AppError, AppState,
1826
};
1927

2028
pub async fn execute_eval_run(
2129
Path((prompt_id, prompt_version_id)): Path<(i64, i64)>,
2230
State(state): State<AppState>,
23-
) -> Result<Json<PromptEvalExecutionRunResponse>, AppError> {
31+
Query(params): Query<EvalRunParams>,
32+
) -> Result<Json<serde_json::Value>, AppError> {
2433
let prompt = state.db.prompt.get_prompt(prompt_id).await?;
2534
let evals = state.db.prompt_eval.get_by_prompt(prompt_id).await?;
26-
let run_id = Uuid::new_v4().to_string();
27-
let mut eval_runs = vec![];
35+
let tools_list = state.db.tool.get_tools_by_prompt_version(prompt.version_id).await?;
36+
let tools = tools_list.into_iter().map(|t| {
37+
ChatCompletionRequestTool::Function {
38+
function: ChatCompletionRequestFunctionDescription {
39+
name: t.tool_name,
40+
description: Some(t.description),
41+
parameters: serde_json::from_str(&t.parameters).unwrap_or_default(),
42+
strict: Some(t.strict)
43+
}
44+
}
45+
}).collect::<Vec<_>>();
46+
47+
let rounds = params.rounds.unwrap_or(1);
48+
let mut all_runs: Vec<PromptEvalExecutionRunResponse> = Vec::new();
2849

2950
for e in evals.iter() {
3051
// Parse system_prompt_input if present
@@ -40,20 +61,18 @@ pub async fn execute_eval_run(
4061
let chat_request = ChatCompletionRequest {
4162
model: prompt.key.clone(),
4263
messages: vec![
43-
// System message with context
4464
crate::common::types::chat_request::ChatCompletionRequestMessage::System {
4565
content: system_content,
4666
name: None,
4767
},
48-
// User message with content
4968
crate::common::types::chat_request::ChatCompletionRequestMessage::User {
5069
content: user_content,
5170
name: None,
5271
},
5372
],
5473
stream: None,
5574
response_format: None,
56-
tools: None,
75+
tools: Some(tools.clone()),
5776
provider: None,
5877
models: None,
5978
transforms: None,
@@ -67,26 +86,63 @@ pub async fn execute_eval_run(
6786
})?;
6887

6988
let llm = Llm::new(llm_props, state.db.log.clone());
70-
let res = llm
71-
.text()
72-
.await
73-
.map_err(|_| AppError::InternalServerError("Something went wrong".to_string()))?;
74-
75-
if let Some(c) = res.0.choices.first() {
76-
// TODO: We should make the DB field nullable so we don't have to hack this
77-
let content = c.message.content.clone().map(|c| c.to_string()).unwrap_or("".to_string());
78-
79-
let eval_run = state
80-
.db
81-
.prompt_eval_run
82-
.create(&run_id, prompt_version_id, e.id, None, &content)
83-
.await?;
84-
85-
eval_runs.push(eval_run);
89+
let mut eval_runs = Vec::new();
90+
91+
for _ in 0..rounds {
92+
let run_id = Uuid::new_v4().to_string();
93+
94+
let res = llm
95+
.text()
96+
.await
97+
.map_err(|e| {
98+
tracing::error!("LLM service error: {}", e);
99+
AppError::InternalServerError(format!("LLM service error: {}", e))
100+
})?;
101+
102+
if let Some(c) = res.0.choices.first() {
103+
if let Some(content) = &c.message.content {
104+
let eval_run = state
105+
.db
106+
.prompt_eval_run
107+
.create(&run_id, prompt_version_id, e.id, None, &content)
108+
.await?;
109+
110+
eval_runs.push(eval_run);
111+
}
112+
113+
// for now just stringify the tool calls
114+
if let Some(tool_calls) = &c.message.tool_calls {
115+
let tool_calls_string = serde_json::to_string(&tool_calls)
116+
.map_err(|e| AppError::InternalServerError(format!("Failed to serialize tool calls: {}", e)))?;
117+
118+
let eval_run = state
119+
.db
120+
.prompt_eval_run
121+
.create(&run_id, prompt_version_id, e.id, None, &tool_calls_string)
122+
.await?;
123+
124+
eval_runs.push(eval_run);
125+
}
126+
}
127+
86128
}
129+
130+
all_runs.push(eval_runs.into());
87131
}
88132

89-
Ok(Json(eval_runs.into()))
133+
// Maintain backward compatibility: return single response when rounds=1
134+
if rounds == 1 && !all_runs.is_empty() {
135+
// Return single response for backward compatibility
136+
Ok(Json(serde_json::to_value(all_runs.into_iter().next()
137+
.ok_or_else(|| AppError::InternalServerError("No eval runs generated".to_string()))?)
138+
.map_err(|e| AppError::InternalServerError(format!("Failed to serialize response: {}", e)))?)
139+
)
140+
} else {
141+
// Return array of responses for multiple rounds
142+
Ok(Json(serde_json::to_value(all_runs)
143+
.map_err(|e| AppError::InternalServerError(format!("Failed to serialize response: {}", e)))?)
144+
)
145+
}
90146
}
91147

92148
pub async fn get_eval_run_by_id(

backend/src/controllers/types/response/prompt_eval_run.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ impl From<PromptEvalRun> for PromptEvalRunResponse {
3535
// EXECUTION RESPONSE
3636
#[derive(Debug, Serialize, Deserialize)]
3737
pub struct PromptEvalExecutionRunResponse {
38-
run_id: String,
39-
runs: Vec<PromptEvalRunResponse>
38+
pub run_id: String,
39+
pub runs: Vec<PromptEvalRunResponse>
4040
}
4141

4242
impl From<Vec<PromptEvalRun>> for PromptEvalExecutionRunResponse {

0 commit comments

Comments
 (0)