Skip to content

Commit f7fc558

Browse files
authored
Fix: last completion chunk with finish_reason not sent in edge case scenario (#515)
Signed-off-by: declark1 <[email protected]>
1 parent 999b24c commit f7fc558

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

src/orchestrator/handlers/completions_detection/streaming.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,12 +615,19 @@ async fn process_detection_batch_stream(
615615
mut detection_batch_stream: DetectionBatchStream,
616616
response_tx: mpsc::Sender<Result<Option<Completion>, Error>>,
617617
) {
618+
let mut batch_tracker: HashMap<u32, Vec<(usize, usize)>> = HashMap::new();
618619
while let Some(result) = detection_batch_stream.next().await {
619620
match result {
620621
Ok((choice_index, chunk, detections)) => {
622+
let indices = (chunk.input_start_index, chunk.input_end_index);
621623
match output_detection_response(&completion_state, choice_index, chunk, detections)
622624
{
623625
Ok(completion) => {
626+
// Record indices for this batch
627+
batch_tracker
628+
.entry(choice_index)
629+
.and_modify(|entry| entry.push(indices))
630+
.or_insert(vec![indices]);
624631
// Send completion to response channel
625632
debug!(%trace_id, %choice_index, ?completion, "sending completion chunk to response channel");
626633
if response_tx.send(Ok(Some(completion))).await.is_err() {
@@ -644,5 +651,35 @@ async fn process_detection_batch_stream(
644651
}
645652
}
646653
}
654+
// Ensure the last completion chunk including finish_reason is sent for each choice.
655+
//
656+
// An edge case exists where the last completion chunk would not be included in the final batch
657+
// if it has empty choice text. This is because chunks without choice text are not sent to the detection pipeline.
658+
for (choice_index, indices) in batch_tracker {
659+
// Lookup the last completion chunk received
660+
let completions = completion_state.completions.get(&choice_index).unwrap();
661+
let (last_index, completion) = completions
662+
.last_key_value()
663+
.map(|(index, completion)| (*index, completion))
664+
.unwrap();
665+
// Get the index of last completion chunk included in the last batch
666+
let (_start_index, end_index) = indices.last().copied().unwrap();
667+
if last_index != end_index {
668+
// The last batch didn't include the last completion chunk, send it to the response channel
669+
if last_index != end_index + 1 {
670+
warn!(%trace_id, %choice_index, %last_index, %end_index, "unexpected number of completion chunks remaining for choice");
671+
debug!(%trace_id, ?completions);
672+
}
673+
debug!(%trace_id, %choice_index, ?completion, "sending last completion chunk to response channel");
674+
if response_tx
675+
.send(Ok(Some(completion.clone())))
676+
.await
677+
.is_err()
678+
{
679+
info!(%trace_id, "task completed: client disconnected");
680+
return;
681+
}
682+
}
683+
}
647684
info!(%trace_id, "task completed: detection batch stream closed");
648685
}

0 commit comments

Comments
 (0)