
Commit 93ada89

feat: enable HTTP completion endpoint to accept arrays of prompts and generate multiple completions per prompt (#3953)
Signed-off-by: zhongdaor <[email protected]>
1 parent 6bccf09 commit 93ada89
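
As a rough illustration of the feature described in the commit title (not taken from this repository's docs), a batched completions request body might look like the sketch below, assuming the endpoint keeps the standard OpenAI completions fields; the model name and max_tokens value are placeholders. With 2 prompts and n = 2, the response carries 2 × 2 = 4 choices.

// Sketch only: illustrative request body for the batched completions endpoint.
use serde_json::json;

fn main() {
    let body = json!({
        "model": "example-model",   // placeholder model name
        "prompt": ["Hello, my name is", "The capital of France is"],
        "n": 2,                     // completions generated per prompt
        "max_tokens": 16,
        "stream": false
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}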

5 files changed: +410 −6 lines changed


lib/llm/src/http/service/openai.rs

Lines changed: 178 additions & 0 deletions
@@ -318,11 +318,33 @@ async fn completions(
     request: Context<NvCreateCompletionRequest>,
     stream_handle: ConnectionHandle,
 ) -> Result<Response, ErrorResponse> {
+    use crate::protocols::openai::completions::get_prompt_batch_size;
+
     // return a 503 if the service is not ready
     check_ready(&state)?;
 
     validate_completion_fields_generic(&request)?;
 
+    // Detect batch prompts
+    let batch_size = get_prompt_batch_size(&request.inner.prompt);
+    let n = request.inner.n.unwrap_or(1);
+
+    // If single prompt or single-element batch, use original flow
+    if batch_size == 1 {
+        return completions_single(state, request, stream_handle).await;
+    }
+
+    // Batch processing: handle multiple prompts
+    completions_batch(state, request, stream_handle, batch_size, n).await
+}
+
+/// Handle single prompt completions (original logic)
+#[tracing::instrument(skip_all)]
+async fn completions_single(
+    state: Arc<service_v2::State>,
+    request: Context<NvCreateCompletionRequest>,
+    stream_handle: ConnectionHandle,
+) -> Result<Response, ErrorResponse> {
     let request_id = request.id().to_string();
 
     // todo - decide on default
@@ -433,6 +455,162 @@ async fn completions(
     }
 }
 
+/// Handle batch prompt completions (multiple prompts with n choices each)
+#[tracing::instrument(skip_all)]
+async fn completions_batch(
+    state: Arc<service_v2::State>,
+    request: Context<NvCreateCompletionRequest>,
+    stream_handle: ConnectionHandle,
+    batch_size: usize,
+    n: u8,
+) -> Result<Response, ErrorResponse> {
+    use crate::protocols::openai::completions::extract_single_prompt;
+    use futures::stream::{self, StreamExt};
+
+    let request_id = request.id().to_string();
+    let streaming = request.inner.stream.unwrap_or(false);
+    let model = request.inner.model.clone();
+
+    // Create http_queue_guard early - tracks time waiting to be processed
+    let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
+
+    let engine = state
+        .manager()
+        .get_completions_engine(&model)
+        .map_err(|_| ErrorMessage::model_not_found())?;
+
+    let parsing_options = state.manager().get_parsing_options(&model);
+
+    let mut response_collector = state.metrics_clone().create_response_collector(&model);
+
+    // prepare to process any annotations
+    let annotations = request.annotations();
+
+    // Create inflight_guard before calling engine to ensure errors are counted
+    let mut inflight_guard =
+        state
+            .metrics_clone()
+            .create_inflight_guard(&model, Endpoint::Completions, streaming);
+
+    // Generate streams for each prompt in the batch
+    let mut all_streams = Vec::new();
+    let mut first_ctx = None;
+
+    for prompt_idx in 0..batch_size {
+        // Extract single prompt at this index
+        let single_prompt = extract_single_prompt(&request.inner.prompt, prompt_idx);
+
+        // Create a new request with this single prompt
+        let mut single_request = request.content().clone();
+        single_request.inner.prompt = single_prompt;
+
+        // Generate unique request_id for each prompt: original_id-{prompt_idx}
+        let unique_request_id = format!("{}-{}", request.id(), prompt_idx);
+        let single_request_context = Context::with_id(single_request, unique_request_id);
+
+        // Generate stream for this prompt
+        let stream = engine
+            .generate(single_request_context)
+            .await
+            .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
+
+        // Capture context from first stream
+        if first_ctx.is_none() {
+            first_ctx = Some(stream.context());
+        }
+
+        // Remap choice indices: choice.index += prompt_idx * n
+        let prompt_idx_u32 = prompt_idx as u32;
+        let n_u32 = n as u32;
+        let remapped_stream = stream.map(move |mut response| {
+            if let Some(ref mut data) = response.data {
+                for choice in &mut data.inner.choices {
+                    choice.index += prompt_idx_u32 * n_u32;
+                }
+            }
+            response
+        });
+
+        all_streams.push(remapped_stream);
+    }
+
+    // Merge all streams
+    let merged_stream = stream::select_all(all_streams);
+
+    // capture the context to cancel the stream if the client disconnects
+    let ctx = first_ctx.expect("At least one stream should be generated");
+
+    let annotations_vec = annotations.map_or(Vec::new(), |annotations| {
+        annotations
+            .iter()
+            .filter_map(|annotation| {
+                if annotation == ANNOTATION_REQUEST_ID {
+                    Annotated::<NvCreateCompletionResponse>::from_annotation(
+                        ANNOTATION_REQUEST_ID,
+                        &request_id,
+                    )
+                    .ok()
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>()
+    });
+
+    // apply any annotations to the front of the stream
+    let merged_stream = stream::iter(annotations_vec).chain(merged_stream);
+
+    if streaming {
+        // For streaming, we'll drop the http_queue_guard on the first token
+        let mut http_queue_guard = Some(http_queue_guard);
+        let stream = merged_stream.map(move |response| {
+            // Calls observe_response() on each token
+            process_response_using_event_converter_and_observe_metrics(
+                EventConverter::from(response),
+                &mut response_collector,
+                &mut http_queue_guard,
+            )
+        });
+        let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
+
+        let mut sse_stream = Sse::new(stream);
+
+        if let Some(keep_alive) = state.sse_keep_alive() {
+            sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
+        }
+
+        Ok(sse_stream.into_response())
+    } else {
+        // Tap the stream to collect metrics for non-streaming requests without altering items
+        let mut http_queue_guard = Some(http_queue_guard);
+        let stream = merged_stream.inspect(move |response| {
+            // Calls observe_response() on each token - drops http_queue_guard on first token
+            process_response_and_observe_metrics(
+                response,
+                &mut response_collector,
+                &mut http_queue_guard,
+            );
+        });
+
+        let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    "Failed to fold completions stream for {}: {:?}",
+                    request_id,
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold completions stream for {}: {:?}",
+                    request_id, e
+                ))
+            })?;
+
+        inflight_guard.mark_ok();
+        Ok(Json(response).into_response())
+    }
+}
+
 #[tracing::instrument(skip_all)]
 async fn embeddings(
     State(state): State<Arc<service_v2::State>>,
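
completions_batch fans the request out into one stream per prompt and remaps each stream's choice indices with choice.index += prompt_idx * n, so the merged response exposes a flat index space of batch_size × n choices. A minimal sketch of that arithmetic and its inverse; the helper names below are illustrative, not part of the diff:

// Mirrors the remapping in completions_batch: flat index = prompt_idx * n + local index.
fn flatten_index(prompt_idx: u32, local_choice_idx: u32, n: u32) -> u32 {
    prompt_idx * n + local_choice_idx
}

// Inverse mapping a client could use to recover which prompt a choice belongs to.
fn unflatten_index(flat_idx: u32, n: u32) -> (u32, u32) {
    (flat_idx / n, flat_idx % n)
}

fn main() {
    let n = 2; // e.g. 3 prompts with n = 2 choices each: flat indices 0..=5
    assert_eq!(flatten_index(2, 1, n), 5);
    assert_eq!(unflatten_index(5, n), (2, 1));
}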

lib/llm/src/protocols/openai/completions.rs

Lines changed: 38 additions & 1 deletion
@@ -78,6 +78,39 @@ pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
     }
 }
 
+/// Get the batch size from a prompt (1 for single prompts, array length for batch prompts)
+pub fn get_prompt_batch_size(prompt: &dynamo_async_openai::types::Prompt) -> usize {
+    match prompt {
+        dynamo_async_openai::types::Prompt::String(_) => 1,
+        dynamo_async_openai::types::Prompt::IntegerArray(_) => 1,
+        dynamo_async_openai::types::Prompt::StringArray(arr) => arr.len(),
+        dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr.len(),
+    }
+}
+
+/// Extract a single prompt from a batch at the given index.
+/// For single prompts, returns a clone regardless of index.
+/// For batch prompts, returns the prompt at the specified index.
+pub fn extract_single_prompt(
+    prompt: &dynamo_async_openai::types::Prompt,
+    index: usize,
+) -> dynamo_async_openai::types::Prompt {
+    match prompt {
+        dynamo_async_openai::types::Prompt::String(s) => {
+            dynamo_async_openai::types::Prompt::String(s.clone())
+        }
+        dynamo_async_openai::types::Prompt::IntegerArray(arr) => {
+            dynamo_async_openai::types::Prompt::IntegerArray(arr.clone())
+        }
+        dynamo_async_openai::types::Prompt::StringArray(arr) => {
+            dynamo_async_openai::types::Prompt::String(arr[index].clone())
+        }
+        dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => {
+            dynamo_async_openai::types::Prompt::IntegerArray(arr[index].clone())
+        }
+    }
+}
+
 impl NvExtProvider for NvCreateCompletionRequest {
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
@@ -403,7 +436,11 @@ impl ValidateRequest for NvCreateCompletionRequest {
         validate::validate_top_k(self.get_top_k())?;
         // Cross-field validation
         validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?;
-
+        // total choices validation for completions batch requests
+        validate::validate_total_choices(
+            get_prompt_batch_size(&self.inner.prompt),
+            self.inner.n.unwrap_or(1),
+        )?;
         Ok(())
     }
 }
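
A minimal sketch of how the two helpers above behave on the Prompt variants, assuming dynamo_async_openai and both functions are in scope (the imports are an assumption, not shown in the diff):

// Sketch only: assumes `use dynamo_async_openai::types::Prompt;` and that
// get_prompt_batch_size / extract_single_prompt (defined above) are in scope.
fn main() {
    // Single string prompt: batch size 1; extraction returns a clone regardless of index.
    let single = Prompt::String("hello".to_string());
    assert_eq!(get_prompt_batch_size(&single), 1);

    // Array of strings: batch size is the array length, and element i comes back
    // as a plain String prompt.
    let batch = Prompt::StringArray(vec!["a".to_string(), "b".to_string()]);
    assert_eq!(get_prompt_batch_size(&batch), 2);
    assert!(matches!(extract_single_prompt(&batch, 1), Prompt::String(s) if s == "b"));
}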

lib/llm/src/protocols/openai/validate.rs

Lines changed: 18 additions & 0 deletions
@@ -66,6 +66,9 @@ pub const MAX_N: u8 = 128;
 /// Allowed range of values for `n` (number of choices)
 pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);
 
+/// Maximum allowed total number of choices (batch_size × n)
+pub const MAX_TOTAL_CHOICES: usize = 128;
+
 /// Minimum allowed value for OpenAI's `logit_bias` values
 pub const MIN_LOGIT_BIAS: f32 = -100.0;
 /// Maximum allowed value for OpenAI's `logit_bias` values
@@ -261,6 +264,21 @@ pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Validates total choices (batch_size × n) doesn't exceed maximum
+pub fn validate_total_choices(batch_size: usize, n: u8) -> Result<(), anyhow::Error> {
+    let total_choices = batch_size * (n as usize);
+    if total_choices > MAX_TOTAL_CHOICES {
+        anyhow::bail!(
+            "Total choices (batch_size × n = {} × {} = {}) exceeds maximum of {}",
+            batch_size,
+            n,
+            total_choices,
+            MAX_TOTAL_CHOICES
+        );
+    }
+    Ok(())
+}
+
 /// Validates n and temperature interaction
 /// When n > 1, temperature must be > 0 to ensure diverse outputs
 pub fn validate_n_with_temperature(
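
A short worked example of the new cap, assuming validate_total_choices and MAX_TOTAL_CHOICES (= 128) from this file are in scope:

// Sketch only: 4 prompts × n = 32 is exactly 128 total choices and passes;
// 5 prompts × n = 32 is 160, which exceeds the cap and is rejected with the error above.
fn main() {
    assert!(validate_total_choices(4, 32).is_ok());
    assert!(validate_total_choices(5, 32).is_err());
}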

0 commit comments