Skip to content

Commit 516ef6f

Browse files
shawn-mcdonald-dev, cursoragent, and aaronvg
authored
fix: select earliest successful LLM call by lexicographic request_id order (#2692)
### Issue Reference This PR fixes #2451 ### Summary Updated LLM call selection logic to correctly pick the earliest successful call by lexicographic `HttpRequestId` ordering. This ensures deterministic selection when multiple successful calls exist. ### Changes This PR addresses an issue where the code prioritized completion time instead of the lexicographically earliest ULID. 1. Filter to keep only candidates where `is_success = true` 2. Sort these successful candidates by HTTP request ID (a ULID, compared lexicographically) 3. Choose the first candidate from this sorted list ### Details - Filters candidates to only successful calls - Sorts by `request_id.to_string()` lexicographically (since `HttpRequestId` lacks `Ord`) - Picks the first entry in the sorted list - Verified via `cargo test -p baml-runtime -q` <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Refactors LLM call selection to pick the earliest successful call by lexicographic HttpRequestId and marks only that call as selected, with new tests covering success-vs-failure and ordering. > > - **Tracing/FunctionLog (storage.rs)** > - Refactor LLM call assembly: introduce `CallCandidate` to gather data for each `request_id`. > - Selection logic: filter successful candidates, sort by `request_id.to_string()` lexicographically, select the first; mark only that call as `selected` across `Basic` and `Stream` kinds. > - **Tests** > - Add `test_selected_call_prefers_success_over_failure` and `test_selected_call_chooses_earlier_success_if_last_failed` to verify selection behavior. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 9a3143e. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Co-authored-by: Cursor Agent <[email protected]> Co-authored-by: aaron <[email protected]>
1 parent 50c7026 commit 516ef6f

File tree

1 file changed

+264
-26
lines changed
  • engine/baml-runtime/src/tracingv2/storage

1 file changed

+264
-26
lines changed

engine/baml-runtime/src/tracingv2/storage/storage.rs

Lines changed: 264 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,25 @@ fn build_function_log(
277277
let end_ms = function_end_time;
278278
let duration = end_ms.map(|end| end.saturating_sub(start_ms));
279279

280-
// Build each LLMCall or LLMStreamCall
281-
let mut calls = Vec::new();
282-
for (_rid, call_acc) in calls_map {
280+
// Build each LLM call candidate first so we can compute the selected one by timestamp
281+
struct CallCandidate {
282+
request_id: HttpRequestId,
283+
is_stream: bool,
284+
client: String,
285+
provider: String,
286+
start_t: i64,
287+
end_t: i64,
288+
partial_duration: i64,
289+
http_request: Option<Arc<HTTPRequest>>,
290+
http_response: Option<Arc<HTTPResponse>>,
291+
http_response_stream: Option<Arc<Mutex<Vec<Arc<HTTPResponseStream>>>>>,
292+
local_usage: Usage,
293+
is_success: bool,
294+
}
295+
296+
let mut candidates: Vec<CallCandidate> = Vec::new();
297+
298+
for (rid, call_acc) in calls_map {
283299
let (client, provider) = parse_llm_client_and_provider(call_acc.llm_request.as_ref());
284300
let start_t = call_acc.timestamp_first_seen.unwrap_or(start_ms);
285301
let end_t = call_acc.timestamp_last_seen.unwrap_or(start_t);
@@ -308,47 +324,90 @@ fn build_function_log(
308324
(None, Some(j)) => Some(j),
309325
};
310326

311-
if !is_stream {
312-
// Basic LLMCall
327+
let is_success = call_acc
328+
.llm_response
329+
.as_ref()
330+
.map(|resp| resp.error_message.is_none())
331+
.unwrap_or(false);
332+
333+
candidates.push(CallCandidate {
334+
request_id: rid.clone(),
335+
is_stream,
336+
client,
337+
provider,
338+
start_t,
339+
end_t,
340+
partial_duration,
341+
http_request: call_acc.http_request.clone(),
342+
http_response: call_acc.http_response.clone(),
343+
http_response_stream: call_acc.http_response_stream.clone(),
344+
local_usage,
345+
is_success,
346+
});
347+
}
348+
349+
// Determine which candidate should be marked selected
350+
let mut selected_idx: Option<usize> = None;
351+
if !candidates.is_empty() {
352+
// Filter successful candidates
353+
let mut successful_calls: Vec<(usize, &CallCandidate)> = candidates
354+
.iter()
355+
.enumerate()
356+
.filter(|(_, c)| c.is_success)
357+
.collect();
358+
359+
if !successful_calls.is_empty() {
360+
// Sort successful calls by lexicographic order of request_id (ULID UUID)
361+
successful_calls
362+
.sort_by(|(_, a), (_, b)| a.request_id.to_string().cmp(&b.request_id.to_string()));
363+
364+
// Pick the first (earliest lexicographically)
365+
selected_idx = Some(successful_calls[0].0);
366+
}
367+
}
368+
369+
// Build final calls vector, marking only the selected one as selected
370+
let mut calls = Vec::new();
371+
for (i, c) in candidates.into_iter().enumerate() {
372+
let is_selected = matches!(selected_idx, Some(sel) if sel == i);
373+
if !c.is_stream {
313374
calls.push(LLMCallKind::Basic(LLMCall {
314-
client_name: client,
315-
provider,
375+
client_name: c.client,
376+
provider: c.provider,
316377
timing: Timing {
317-
start_time_utc_ms: start_t,
318-
duration_ms: Some(partial_duration),
378+
start_time_utc_ms: c.start_t,
379+
duration_ms: Some(c.partial_duration),
319380
},
320-
request: call_acc.http_request.clone(),
321-
response: call_acc.http_response.clone(),
322-
usage: Some(local_usage),
323-
selected: call_acc.llm_response.is_some(),
381+
request: c.http_request,
382+
response: c.http_response,
383+
usage: Some(c.local_usage),
384+
selected: is_selected,
324385
}));
325386
} else {
326-
let sse_chunks = call_acc.http_response_stream.and_then(|chunks| {
387+
let sse_chunks = c.http_response_stream.and_then(|chunks| {
327388
let chunks = chunks.lock().unwrap();
328389
let request_id = chunks.first().map(|e| e.request_id.clone())?;
329390
Some(Arc::new(LLMHTTPStreamResponse {
330391
request_id,
331392
event: chunks.iter().map(|e| e.event.clone()).collect::<Vec<_>>(),
332393
}))
333394
});
334-
335-
// Streaming call
336395
calls.push(LLMCallKind::Stream(LLMStreamCall {
337396
llm_call: LLMCall {
338-
client_name: client,
339-
provider,
397+
client_name: c.client,
398+
provider: c.provider,
340399
timing: Timing {
341-
start_time_utc_ms: start_t,
342-
duration_ms: Some(partial_duration),
400+
start_time_utc_ms: c.start_t,
401+
duration_ms: Some(c.partial_duration),
343402
},
344-
request: call_acc.http_request.clone(),
345-
response: call_acc.http_response.clone(),
346-
usage: Some(local_usage),
347-
selected: call_acc.llm_response.is_some(),
403+
request: c.http_request,
404+
response: c.http_response,
405+
usage: Some(c.local_usage),
406+
selected: is_selected,
348407
},
349408
timing: StreamTiming {
350-
start_time_utc_ms: start_t,
351-
duration_ms: Some(partial_duration),
409+
start_time_utc_ms: c.start_t,
410+
duration_ms: Some(c.partial_duration),
352411
},
353412
sse_chunks,
354413
}));
@@ -1262,6 +1321,185 @@ mod tests {
12621321
collector
12631322
}
12641323

1324+
#[test]
1325+
#[serial]
1326+
fn test_selected_call_prefers_success_over_failure() {
1327+
let rt = tokio::runtime::Runtime::new().unwrap();
1328+
rt.block_on(async {
1329+
let f_id = FunctionCallId::new();
1330+
1331+
// Create one failed response and one successful response
1332+
let rid_fail = HttpRequestId::new();
1333+
let rid_success = HttpRequestId::new();
1334+
1335+
let failed_req = LoggedLLMRequest {
1336+
request_id: rid_fail.clone(),
1337+
client_name: "client_a".into(),
1338+
client_provider: "provider_a".into(),
1339+
params: IndexMap::new(),
1340+
prompt: vec![LLMChatMessage {
1341+
role: "user".into(),
1342+
content: vec![LLMChatMessagePart::Text("hi".into())],
1343+
}],
1344+
};
1345+
let failed_resp = LoggedLLMResponse::new_failure(
1346+
rid_fail.clone(),
1347+
"boom".into(),
1348+
Some("m1".into()),
1349+
Some("error".into()),
1350+
vec![],
1351+
);
1352+
1353+
let ok_req = LoggedLLMRequest {
1354+
request_id: rid_success.clone(),
1355+
client_name: "client_b".into(),
1356+
client_provider: "provider_b".into(),
1357+
params: IndexMap::new(),
1358+
prompt: vec![LLMChatMessage {
1359+
role: "user".into(),
1360+
content: vec![LLMChatMessagePart::Text("hello".into())],
1361+
}],
1362+
};
1363+
let ok_resp = LoggedLLMResponse::new_success(
1364+
rid_success.clone(),
1365+
"m2".into(),
1366+
Some("stop".into()),
1367+
LLMUsage {
1368+
input_tokens: Some(1),
1369+
output_tokens: Some(2),
1370+
total_tokens: Some(3),
1371+
cached_input_tokens: Some(0),
1372+
},
1373+
"ok".into(),
1374+
vec![],
1375+
);
1376+
1377+
let collector = inject_test_events(
1378+
&f_id,
1379+
"test_selected_call",
1380+
vec![(failed_req, failed_resp), (ok_req, ok_resp)],
1381+
)
1382+
.await;
1383+
1384+
let mut flog = FunctionLog::new(f_id.clone());
1385+
let calls = flog.calls();
1386+
assert_eq!(calls.len(), 2);
1387+
1388+
// Exactly one should be marked selected, and it should be the success
1389+
let selected: Vec<_> = calls.iter().filter(|c| c.selected()).collect();
1390+
assert_eq!(selected.len(), 1);
1391+
let sel = selected[0];
1392+
match sel {
1393+
LLMCallKind::Basic(c) => {
1394+
assert_eq!(c.client_name, "client_b");
1395+
assert!(c.selected);
1396+
}
1397+
LLMCallKind::Stream(s) => {
1398+
assert_eq!(s.llm_call.client_name, "client_b");
1399+
assert!(s.llm_call.selected);
1400+
}
1401+
}
1402+
1403+
drop(flog);
1404+
drop(collector);
1405+
{
1406+
let tracer = BAML_TRACER.lock().unwrap();
1407+
assert_eq!(tracer.ref_count_for(&f_id), 0);
1408+
assert!(tracer.get_events(&f_id).is_none());
1409+
}
1410+
});
1411+
}
1412+
1413+
#[test]
1414+
#[serial]
1415+
fn test_selected_call_chooses_earlier_success_if_last_failed() {
1416+
let rt = tokio::runtime::Runtime::new().unwrap();
1417+
rt.block_on(async {
1418+
let f_id = FunctionCallId::new();
1419+
1420+
// First a successful call, then a failed call (latest is failed)
1421+
let rid_success = HttpRequestId::new();
1422+
let rid_fail = HttpRequestId::new();
1423+
1424+
let ok_req = LoggedLLMRequest {
1425+
request_id: rid_success.clone(),
1426+
client_name: "client_ok".into(),
1427+
client_provider: "provider_ok".into(),
1428+
params: IndexMap::new(),
1429+
prompt: vec![LLMChatMessage {
1430+
role: "user".into(),
1431+
content: vec![LLMChatMessagePart::Text("hello".into())],
1432+
}],
1433+
};
1434+
let ok_resp = LoggedLLMResponse::new_success(
1435+
rid_success.clone(),
1436+
"m2".into(),
1437+
Some("stop".into()),
1438+
LLMUsage {
1439+
input_tokens: Some(1),
1440+
output_tokens: Some(2),
1441+
total_tokens: Some(3),
1442+
cached_input_tokens: Some(0),
1443+
},
1444+
"ok".into(),
1445+
vec![],
1446+
);
1447+
1448+
let failed_req = LoggedLLMRequest {
1449+
request_id: rid_fail.clone(),
1450+
client_name: "client_fail".into(),
1451+
client_provider: "provider_fail".into(),
1452+
params: IndexMap::new(),
1453+
prompt: vec![LLMChatMessage {
1454+
role: "user".into(),
1455+
content: vec![LLMChatMessagePart::Text("hi".into())],
1456+
}],
1457+
};
1458+
let failed_resp = LoggedLLMResponse::new_failure(
1459+
rid_fail.clone(),
1460+
"boom".into(),
1461+
Some("m1".into()),
1462+
Some("error".into()),
1463+
vec![],
1464+
);
1465+
1466+
// Inject in order: success first, failure second (so failure is latest by timestamp)
1467+
let collector = inject_test_events(
1468+
&f_id,
1469+
"test_selected_call_last_failed",
1470+
vec![(ok_req, ok_resp), (failed_req, failed_resp)],
1471+
)
1472+
.await;
1473+
1474+
let mut flog = FunctionLog::new(f_id.clone());
1475+
let calls = flog.calls();
1476+
assert_eq!(calls.len(), 2);
1477+
1478+
// Latest failed, we expect selected to be the successful earlier call
1479+
let selected: Vec<_> = calls.iter().filter(|c| c.selected()).collect();
1480+
assert_eq!(selected.len(), 1);
1481+
let sel = selected[0];
1482+
match sel {
1483+
LLMCallKind::Basic(c) => {
1484+
assert_eq!(c.client_name, "client_ok");
1485+
assert!(c.selected);
1486+
}
1487+
LLMCallKind::Stream(s) => {
1488+
assert_eq!(s.llm_call.client_name, "client_ok");
1489+
assert!(s.llm_call.selected);
1490+
}
1491+
}
1492+
1493+
drop(flog);
1494+
drop(collector);
1495+
{
1496+
let tracer = BAML_TRACER.lock().unwrap();
1497+
assert_eq!(tracer.ref_count_for(&f_id), 0);
1498+
assert!(tracer.get_events(&f_id).is_none());
1499+
}
1500+
});
1501+
}
1502+
12651503
#[test]
12661504
#[serial]
12671505
fn test_usage_accumulation_within_function_log_retries() {

0 commit comments

Comments
 (0)