diff --git a/engine/baml-runtime/src/tracingv2/storage/storage.rs b/engine/baml-runtime/src/tracingv2/storage/storage.rs index 56b86aa4b8..beadbc2675 100644 --- a/engine/baml-runtime/src/tracingv2/storage/storage.rs +++ b/engine/baml-runtime/src/tracingv2/storage/storage.rs @@ -277,9 +277,25 @@ fn build_function_log( let end_ms = function_end_time; let duration = end_ms.map(|end| end.saturating_sub(start_ms)); - // Build each LLMCall or LLMStreamCall - let mut calls = Vec::new(); - for (_rid, call_acc) in calls_map { + // Build each LLM call candidate first so we can compute the selected one by timestamp + struct CallCandidate { + request_id: HttpRequestId, + is_stream: bool, + client: String, + provider: String, + start_t: i64, + end_t: i64, + partial_duration: i64, + http_request: Option>, + http_response: Option>, + http_response_stream: Option>>>>, + local_usage: Usage, + is_success: bool, + } + + let mut candidates: Vec<CallCandidate> = Vec::new(); + + for (rid, call_acc) in calls_map { let (client, provider) = parse_llm_client_and_provider(call_acc.llm_request.as_ref()); let start_t = call_acc.timestamp_first_seen.unwrap_or(start_ms); let end_t = call_acc.timestamp_last_seen.unwrap_or(start_t); @@ -308,22 +324,67 @@ fn build_function_log( (None, Some(j)) => Some(j), }; - if !is_stream { - // Basic LLMCall + let is_success = call_acc + .llm_response + .as_ref() + .map(|resp| resp.error_message.is_none()) + .unwrap_or(false); + + candidates.push(CallCandidate { + request_id: rid.clone(), + is_stream, + client, + provider, + start_t, + end_t, + partial_duration, + http_request: call_acc.http_request.clone(), + http_response: call_acc.http_response.clone(), + http_response_stream: call_acc.http_response_stream.clone(), + local_usage, + is_success, + }); + } + + // Determine which candidate should be marked selected + let mut selected_idx: Option<usize> = None; + if !candidates.is_empty() { + // Filter successful candidates + let mut successful_calls: Vec<(usize, &CallCandidate)> = 
candidates + .iter() + .enumerate() + .filter(|(_, c)| c.is_success) + .collect(); + + if !successful_calls.is_empty() { + // Sort successful calls by lexicographic order of request_id (ULID UUID) + successful_calls + .sort_by(|(_, a), (_, b)| a.request_id.to_string().cmp(&b.request_id.to_string())); + + // Pick the first (earliest lexicographically) + selected_idx = Some(successful_calls[0].0); + } + } + + // Build final calls vector, marking only the selected one as selected + let mut calls = Vec::new(); + for (i, c) in candidates.into_iter().enumerate() { + let is_selected = matches!(selected_idx, Some(sel) if sel == i); + if !c.is_stream { calls.push(LLMCallKind::Basic(LLMCall { - client_name: client, - provider, + client_name: c.client, + provider: c.provider, timing: Timing { - start_time_utc_ms: start_t, - duration_ms: Some(partial_duration), + start_time_utc_ms: c.start_t, + duration_ms: Some(c.partial_duration), }, - request: call_acc.http_request.clone(), - response: call_acc.http_response.clone(), - usage: Some(local_usage), - selected: call_acc.llm_response.is_some(), + request: c.http_request, + response: c.http_response, + usage: Some(c.local_usage), + selected: is_selected, })); } else { - let sse_chunks = call_acc.http_response_stream.and_then(|chunks| { + let sse_chunks = c.http_response_stream.and_then(|chunks| { let chunks = chunks.lock().unwrap(); let request_id = chunks.first().map(|e| e.request_id.clone())?; Some(Arc::new(LLMHTTPStreamResponse { @@ -331,24 +392,22 @@ fn build_function_log( event: chunks.iter().map(|e| e.event.clone()).collect::>(), })) }); - - // Streaming call calls.push(LLMCallKind::Stream(LLMStreamCall { llm_call: LLMCall { - client_name: client, - provider, + client_name: c.client, + provider: c.provider, timing: Timing { - start_time_utc_ms: start_t, - duration_ms: Some(partial_duration), + start_time_utc_ms: c.start_t, + duration_ms: Some(c.partial_duration), }, - request: call_acc.http_request.clone(), - response: 
call_acc.http_response.clone(), - usage: Some(local_usage), - selected: call_acc.llm_response.is_some(), + request: c.http_request, + response: c.http_response, + usage: Some(c.local_usage), + selected: is_selected, }, timing: StreamTiming { - start_time_utc_ms: start_t, - duration_ms: Some(partial_duration), + start_time_utc_ms: c.start_t, + duration_ms: Some(c.partial_duration), }, sse_chunks, })); @@ -1262,6 +1321,185 @@ mod tests { collector } + #[test] + #[serial] + fn test_selected_call_prefers_success_over_failure() { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let f_id = FunctionCallId::new(); + + // Create one failed response and one successful response + let rid_fail = HttpRequestId::new(); + let rid_success = HttpRequestId::new(); + + let failed_req = LoggedLLMRequest { + request_id: rid_fail.clone(), + client_name: "client_a".into(), + client_provider: "provider_a".into(), + params: IndexMap::new(), + prompt: vec![LLMChatMessage { + role: "user".into(), + content: vec![LLMChatMessagePart::Text("hi".into())], + }], + }; + let failed_resp = LoggedLLMResponse::new_failure( + rid_fail.clone(), + "boom".into(), + Some("m1".into()), + Some("error".into()), + vec![], + ); + + let ok_req = LoggedLLMRequest { + request_id: rid_success.clone(), + client_name: "client_b".into(), + client_provider: "provider_b".into(), + params: IndexMap::new(), + prompt: vec![LLMChatMessage { + role: "user".into(), + content: vec![LLMChatMessagePart::Text("hello".into())], + }], + }; + let ok_resp = LoggedLLMResponse::new_success( + rid_success.clone(), + "m2".into(), + Some("stop".into()), + LLMUsage { + input_tokens: Some(1), + output_tokens: Some(2), + total_tokens: Some(3), + cached_input_tokens: Some(0), + }, + "ok".into(), + vec![], + ); + + let collector = inject_test_events( + &f_id, + "test_selected_call", + vec![(failed_req, failed_resp), (ok_req, ok_resp)], + ) + .await; + + let mut flog = FunctionLog::new(f_id.clone()); + let calls = 
flog.calls(); + assert_eq!(calls.len(), 2); + + // Exactly one should be marked selected, and it should be the success + let selected: Vec<_> = calls.iter().filter(|c| c.selected()).collect(); + assert_eq!(selected.len(), 1); + let sel = selected[0]; + match sel { + LLMCallKind::Basic(c) => { + assert_eq!(c.client_name, "client_b"); + assert!(c.selected); + } + LLMCallKind::Stream(s) => { + assert_eq!(s.llm_call.client_name, "client_b"); + assert!(s.llm_call.selected); + } + } + + drop(flog); + drop(collector); + { + let tracer = BAML_TRACER.lock().unwrap(); + assert_eq!(tracer.ref_count_for(&f_id), 0); + assert!(tracer.get_events(&f_id).is_none()); + } + }); + } + + #[test] + #[serial] + fn test_selected_call_chooses_earlier_success_if_last_failed() { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let f_id = FunctionCallId::new(); + + // First a successful call, then a failed call (latest is failed) + let rid_success = HttpRequestId::new(); + let rid_fail = HttpRequestId::new(); + + let ok_req = LoggedLLMRequest { + request_id: rid_success.clone(), + client_name: "client_ok".into(), + client_provider: "provider_ok".into(), + params: IndexMap::new(), + prompt: vec![LLMChatMessage { + role: "user".into(), + content: vec![LLMChatMessagePart::Text("hello".into())], + }], + }; + let ok_resp = LoggedLLMResponse::new_success( + rid_success.clone(), + "m2".into(), + Some("stop".into()), + LLMUsage { + input_tokens: Some(1), + output_tokens: Some(2), + total_tokens: Some(3), + cached_input_tokens: Some(0), + }, + "ok".into(), + vec![], + ); + + let failed_req = LoggedLLMRequest { + request_id: rid_fail.clone(), + client_name: "client_fail".into(), + client_provider: "provider_fail".into(), + params: IndexMap::new(), + prompt: vec![LLMChatMessage { + role: "user".into(), + content: vec![LLMChatMessagePart::Text("hi".into())], + }], + }; + let failed_resp = LoggedLLMResponse::new_failure( + rid_fail.clone(), + "boom".into(), + Some("m1".into()), 
+ Some("error".into()), + vec![], + ); + + // Inject in order: success first, failure second (so failure is latest by timestamp) + let collector = inject_test_events( + &f_id, + "test_selected_call_last_failed", + vec![(ok_req, ok_resp), (failed_req, failed_resp)], + ) + .await; + + let mut flog = FunctionLog::new(f_id.clone()); + let calls = flog.calls(); + assert_eq!(calls.len(), 2); + + // Latest failed, we expect selected to be the successful earlier call + let selected: Vec<_> = calls.iter().filter(|c| c.selected()).collect(); + assert_eq!(selected.len(), 1); + let sel = selected[0]; + match sel { + LLMCallKind::Basic(c) => { + assert_eq!(c.client_name, "client_ok"); + assert!(c.selected); + } + LLMCallKind::Stream(s) => { + assert_eq!(s.llm_call.client_name, "client_ok"); + assert!(s.llm_call.selected); + } + } + + drop(flog); + drop(collector); + { + let tracer = BAML_TRACER.lock().unwrap(); + assert_eq!(tracer.ref_count_for(&f_id), 0); + assert!(tracer.get_events(&f_id).is_none()); + } + }); + } + #[test] #[serial] fn test_usage_accumulation_within_function_log_retries() {