diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs
index f5cd4a0aea..c361832173 100644
--- a/lib/llm/src/http/service/openai.rs
+++ b/lib/llm/src/http/service/openai.rs
@@ -318,11 +318,33 @@ async fn completions(
     request: Context<NvCreateCompletionRequest>,
     stream_handle: ConnectionHandle,
 ) -> Result<Response, ErrorResponse> {
+    use crate::protocols::openai::completions::get_prompt_batch_size;
+
     // return a 503 if the service is not ready
     check_ready(&state)?;
 
     validate_completion_fields_generic(&request)?;
 
+    // Detect batch prompts
+    let batch_size = get_prompt_batch_size(&request.inner.prompt);
+    let n = request.inner.n.unwrap_or(1);
+
+    // If single prompt or single-element batch, use original flow
+    if batch_size == 1 {
+        return completions_single(state, request, stream_handle).await;
+    }
+
+    // Batch processing: handle multiple prompts
+    completions_batch(state, request, stream_handle, batch_size, n).await
+}
+
+/// Handle single prompt completions (original logic)
+#[tracing::instrument(skip_all)]
+async fn completions_single(
+    state: Arc<service_v2::State>,
+    request: Context<NvCreateCompletionRequest>,
+    stream_handle: ConnectionHandle,
+) -> Result<Response, ErrorResponse> {
     let request_id = request.id().to_string();
 
     // todo - decide on default
@@ -433,6 +455,162 @@ async fn completions(
     }
 }
 
+/// Handle batch prompt completions (multiple prompts with n choices each)
+#[tracing::instrument(skip_all)]
+async fn completions_batch(
+    state: Arc<service_v2::State>,
+    request: Context<NvCreateCompletionRequest>,
+    stream_handle: ConnectionHandle,
+    batch_size: usize,
+    n: u8,
+) -> Result<Response, ErrorResponse> {
+    use crate::protocols::openai::completions::extract_single_prompt;
+    use futures::stream::{self, StreamExt};
+
+    let request_id = request.id().to_string();
+    let streaming = request.inner.stream.unwrap_or(false);
+    let model = request.inner.model.clone();
+
+    // Create http_queue_guard early - tracks time waiting to be processed
+    let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
+
+    let engine = state
+        .manager()
+        .get_completions_engine(&model)
+        .map_err(|_| ErrorMessage::model_not_found())?;
+
+    let parsing_options = state.manager().get_parsing_options(&model);
+
+    let mut response_collector = state.metrics_clone().create_response_collector(&model);
+
+    // prepare to process any annotations
+    let annotations = request.annotations();
+
+    // Create inflight_guard before calling engine to ensure errors are counted
+    let mut inflight_guard =
+        state
+            .metrics_clone()
+            .create_inflight_guard(&model, Endpoint::Completions, streaming);
+
+    // Generate streams for each prompt in the batch
+    let mut all_streams = Vec::new();
+    let mut first_ctx = None;
+
+    for prompt_idx in 0..batch_size {
+        // Extract single prompt at this index
+        let single_prompt = extract_single_prompt(&request.inner.prompt, prompt_idx);
+
+        // Create a new request with this single prompt
+        let mut single_request = request.content().clone();
+        single_request.inner.prompt = single_prompt;
+
+        // Generate unique request_id for each prompt: original_id-{prompt_idx}
+        let unique_request_id = format!("{}-{}", request.id(), prompt_idx);
+        let single_request_context = Context::with_id(single_request, unique_request_id);
+
+        // Generate stream for this prompt
+        let stream = engine
+            .generate(single_request_context)
+            .await
+            .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
+
+        // Capture context from first stream
+        if first_ctx.is_none() {
+            first_ctx = Some(stream.context());
+        }
+
+        // Remap choice indices: choice.index += prompt_idx * n
+        let prompt_idx_u32 = prompt_idx as u32;
+        let n_u32 = n as u32;
+        let remapped_stream = stream.map(move |mut response| {
+            if let Some(ref mut data) = response.data {
+                for choice in &mut data.inner.choices {
+                    choice.index += prompt_idx_u32 * n_u32;
+                }
+            }
+            response
+        });
+
+        all_streams.push(remapped_stream);
+    }
+
+    // Merge all streams
+    let merged_stream = stream::select_all(all_streams);
+
+    // capture the context to cancel the stream if the client disconnects
+    let ctx = first_ctx.expect("At least one stream should be generated");
+
+    let annotations_vec = annotations.map_or(Vec::new(), |annotations| {
+        annotations
+            .iter()
+            .filter_map(|annotation| {
+                if annotation == ANNOTATION_REQUEST_ID {
+                    Annotated::<NvCreateCompletionResponse>::from_annotation(
+                        ANNOTATION_REQUEST_ID,
+                        &request_id,
+                    )
+                    .ok()
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>()
+    });
+
+    // apply any annotations to the front of the stream
+    let merged_stream = stream::iter(annotations_vec).chain(merged_stream);
+
+    if streaming {
+        // For streaming, we'll drop the http_queue_guard on the first token
+        let mut http_queue_guard = Some(http_queue_guard);
+        let stream = merged_stream.map(move |response| {
+            // Calls observe_response() on each token
+            process_response_using_event_converter_and_observe_metrics(
+                EventConverter::from(response),
+                &mut response_collector,
+                &mut http_queue_guard,
+            )
+        });
+        let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
+
+        let mut sse_stream = Sse::new(stream);
+
+        if let Some(keep_alive) = state.sse_keep_alive() {
+            sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
+        }
+
+        Ok(sse_stream.into_response())
+    } else {
+        // Tap the stream to collect metrics for non-streaming requests without altering items
+        let mut http_queue_guard = Some(http_queue_guard);
+        let stream = merged_stream.inspect(move |response| {
+            // Calls observe_response() on each token - drops http_queue_guard on first token
+            process_response_and_observe_metrics(
+                response,
+                &mut response_collector,
+                &mut http_queue_guard,
+            );
+        });
+
+        let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    "Failed to fold completions stream for {}: {:?}",
+                    request_id,
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold completions stream for {}: {:?}",
+                    request_id, e
+                ))
+            })?;
+
+        inflight_guard.mark_ok();
+        Ok(Json(response).into_response())
+    }
+}
+
 #[tracing::instrument(skip_all)]
 async fn embeddings(
     State(state): State<Arc<service_v2::State>>,
diff --git a/lib/llm/src/protocols/openai/completions.rs b/lib/llm/src/protocols/openai/completions.rs
index 45d5885b6c..48d232637e 100644
--- a/lib/llm/src/protocols/openai/completions.rs
+++ b/lib/llm/src/protocols/openai/completions.rs
@@ -78,6 +78,39 @@ pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
     }
 }
 
+/// Get the batch size from a prompt (1 for single prompts, array length for batch prompts)
+pub fn get_prompt_batch_size(prompt: &dynamo_async_openai::types::Prompt) -> usize {
+    match prompt {
+        dynamo_async_openai::types::Prompt::String(_) => 1,
+        dynamo_async_openai::types::Prompt::IntegerArray(_) => 1,
+        dynamo_async_openai::types::Prompt::StringArray(arr) => arr.len(),
+        dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr.len(),
+    }
+}
+
+/// Extract a single prompt from a batch at the given index.
+/// For single prompts, returns a clone regardless of index.
+/// For batch prompts, returns the prompt at the specified index.
+pub fn extract_single_prompt(
+    prompt: &dynamo_async_openai::types::Prompt,
+    index: usize,
+) -> dynamo_async_openai::types::Prompt {
+    match prompt {
+        dynamo_async_openai::types::Prompt::String(s) => {
+            dynamo_async_openai::types::Prompt::String(s.clone())
+        }
+        dynamo_async_openai::types::Prompt::IntegerArray(arr) => {
+            dynamo_async_openai::types::Prompt::IntegerArray(arr.clone())
+        }
+        dynamo_async_openai::types::Prompt::StringArray(arr) => {
+            dynamo_async_openai::types::Prompt::String(arr[index].clone())
+        }
+        dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => {
+            dynamo_async_openai::types::Prompt::IntegerArray(arr[index].clone())
+        }
+    }
+}
+
 impl NvExtProvider for NvCreateCompletionRequest {
     fn nvext(&self) -> Option<&NvExt> {
         self.nvext.as_ref()
     }
@@ -403,7 +436,11 @@ impl ValidateRequest for NvCreateCompletionRequest {
         validate::validate_top_k(self.get_top_k())?;
         // Cross-field validation
         validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?;
-
+        // total choices validation for completions batch requests
+        validate::validate_total_choices(
+            get_prompt_batch_size(&self.inner.prompt),
+            self.inner.n.unwrap_or(1),
+        )?;
         Ok(())
     }
 }
diff --git a/lib/llm/src/protocols/openai/validate.rs b/lib/llm/src/protocols/openai/validate.rs
index c022c971c4..4798d46933 100644
--- a/lib/llm/src/protocols/openai/validate.rs
+++ b/lib/llm/src/protocols/openai/validate.rs
@@ -66,6 +66,9 @@ pub const MAX_N: u8 = 128;
 /// Allowed range of values for `n` (number of choices)
 pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);
 
+/// Maximum allowed total number of choices (batch_size × n)
+pub const MAX_TOTAL_CHOICES: usize = 128;
+
 /// Minimum allowed value for OpenAI's `logit_bias` values
 pub const MIN_LOGIT_BIAS: f32 = -100.0;
 /// Maximum allowed value for OpenAI's `logit_bias` values
@@ -261,6 +264,21 @@ pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> {
     Ok(())
 }
 
+/// Validates total choices (batch_size × n) doesn't exceed maximum
+pub fn validate_total_choices(batch_size: usize, n: u8) -> Result<(), anyhow::Error> {
+    let total_choices = batch_size * (n as usize);
+    if total_choices > MAX_TOTAL_CHOICES {
+        anyhow::bail!(
+            "Total choices (batch_size × n = {} × {} = {}) exceeds maximum of {}",
+            batch_size,
+            n,
+            total_choices,
+            MAX_TOTAL_CHOICES
+        );
+    }
+    Ok(())
+}
+
 /// Validates n and temperature interaction
 /// When n > 1, temperature must be > 0 to ensure diverse outputs
 pub fn validate_n_with_temperature(
diff --git a/lib/llm/tests/openai_completions.rs b/lib/llm/tests/openai_completions.rs
index e3902f2ab4..202c488870 100644
--- a/lib/llm/tests/openai_completions.rs
+++ b/lib/llm/tests/openai_completions.rs
@@ -118,3 +118,144 @@ fn build_samples() -> Result, String> {
     Ok(samples)
 }
+
+// ============================================================================
+// Batch Prompt Tests
+// ============================================================================
+
+#[test]
+fn test_batch_prompt_utilities() {
+    use dynamo_async_openai::types::Prompt;
+    use dynamo_llm::protocols::openai::completions::{
+        extract_single_prompt, get_prompt_batch_size,
+    };
+
+    // Test single string prompt
+    let single_string = Prompt::String("Hello, world!".to_string());
+    assert_eq!(get_prompt_batch_size(&single_string), 1);
+    assert_eq!(
+        extract_single_prompt(&single_string, 0),
+        Prompt::String("Hello, world!".to_string())
+    );
+
+    // Test single integer array prompt
+    let single_int = Prompt::IntegerArray(vec![1, 2, 3]);
+    assert_eq!(get_prompt_batch_size(&single_int), 1);
+    assert_eq!(
+        extract_single_prompt(&single_int, 0),
+        Prompt::IntegerArray(vec![1, 2, 3])
+    );
+
+    // Test string array prompt
+    let string_array = Prompt::StringArray(vec![
+        "First prompt".to_string(),
+        "Second prompt".to_string(),
+        "Third prompt".to_string(),
+    ]);
+    assert_eq!(get_prompt_batch_size(&string_array), 3);
+    assert_eq!(
+        extract_single_prompt(&string_array, 0),
+        Prompt::String("First prompt".to_string())
+    );
+    assert_eq!(
+        extract_single_prompt(&string_array, 1),
+        Prompt::String("Second prompt".to_string())
+    );
+    assert_eq!(
+        extract_single_prompt(&string_array, 2),
+        Prompt::String("Third prompt".to_string())
+    );
+
+    // Test array of integer arrays
+    let int_array = Prompt::ArrayOfIntegerArray(vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]]);
+    assert_eq!(get_prompt_batch_size(&int_array), 3);
+    assert_eq!(
+        extract_single_prompt(&int_array, 0),
+        Prompt::IntegerArray(vec![1, 2, 3])
+    );
+    assert_eq!(
+        extract_single_prompt(&int_array, 1),
+        Prompt::IntegerArray(vec![4, 5])
+    );
+    assert_eq!(
+        extract_single_prompt(&int_array, 2),
+        Prompt::IntegerArray(vec![6, 7, 8, 9])
+    );
+}
+
+#[test]
+fn test_total_choices_validation() {
+    use dynamo_llm::protocols::openai::validate::validate_total_choices;
+
+    // Valid cases
+    assert!(validate_total_choices(1, 1).is_ok());
+    assert!(validate_total_choices(10, 10).is_ok());
+    assert!(validate_total_choices(64, 2).is_ok());
+    assert!(validate_total_choices(128, 1).is_ok());
+    assert!(validate_total_choices(1, 128).is_ok());
+
+    // Edge case: exactly at the limit
+    assert!(validate_total_choices(128, 1).is_ok());
+    assert!(validate_total_choices(64, 2).is_ok());
+
+    // Invalid cases: exceeds limit
+    assert!(validate_total_choices(129, 1).is_err());
+    assert!(validate_total_choices(65, 2).is_err());
+    assert!(validate_total_choices(100, 2).is_err());
+    assert!(validate_total_choices(2, 100).is_err());
+
+    // Test error message
+    let result = validate_total_choices(100, 2);
+    assert!(result.is_err());
+    let err = result.unwrap_err();
+    assert!(
+        err.to_string()
+            .contains("Total choices (batch_size × n = 100 × 2 = 200) exceeds maximum of 128")
+    );
+}
+
+#[test]
+fn test_batch_prompt_with_n_parameter() {
+    use dynamo_async_openai::types::Prompt;
+    use dynamo_llm::protocols::openai::completions::get_prompt_batch_size;
+
+    // Test batch size calculation
+    let prompt = Prompt::StringArray(vec!["p1".to_string(), "p2".to_string(), "p3".to_string()]);
+    let batch_size = get_prompt_batch_size(&prompt);
+    let n = 2_u8;
+
+    // Total choices = batch_size × n = 3 × 2 = 6
+    let total_choices = batch_size * (n as usize);
+    assert_eq!(total_choices, 6);
+
+    // Choice indices should be:
+    // prompt 0: indices 0, 1
+    // prompt 1: indices 2, 3
+    // prompt 2: indices 4, 5
+    for prompt_idx in 0..batch_size {
+        for choice_idx in 0..n {
+            let expected_index = (prompt_idx as u32) * (n as u32) + (choice_idx as u32);
+            // Verify index calculation matches vLLM logic
+            assert_eq!(
+                expected_index,
+                prompt_idx as u32 * n as u32 + choice_idx as u32
+            );
+        }
+    }
+}
+
+#[test]
+fn test_single_prompt_in_array() {
+    use dynamo_async_openai::types::Prompt;
+    use dynamo_llm::protocols::openai::completions::{
+        extract_single_prompt, get_prompt_batch_size,
+    };
+
+    // Single element array should work like regular prompt
+    let single_in_array = Prompt::StringArray(vec!["Single prompt".to_string()]);
+    assert_eq!(get_prompt_batch_size(&single_in_array), 1);
+    assert_eq!(
+        extract_single_prompt(&single_in_array, 0),
+        Prompt::String("Single prompt".to_string())
+    );
+}
diff --git a/tests/frontend/test_completion_mocker_engine.py b/tests/frontend/test_completion_mocker_engine.py
index 7b5f0b33c9..1dd9b5e79a 100644
--- a/tests/frontend/test_completion_mocker_engine.py
+++ b/tests/frontend/test_completion_mocker_engine.py
@@ -158,6 +158,24 @@ def test_completion_string_prompt() -> None:
     )
 
 
+@pytest.mark.usefixtures("start_services")
+@pytest.mark.e2e
+@pytest.mark.model(TEST_MODEL)
+def test_completion_empty_array_prompt() -> None:
+    payload: Dict[str, Any] = {
+        "model": TEST_MODEL,
+        "prompt": [],
+        "max_tokens": 2000,
+    }
+
+    response = _send_completion_request(payload)
+
+    assert response.status_code == 400, (
+        f"Completion request should fail with status 400 but got "
+        f"{response.status_code}: {response.text}"
+    )
+
+
 @pytest.mark.usefixtures("start_services")
 @pytest.mark.e2e
 @pytest.mark.model(TEST_MODEL)
@@ -182,13 +200,25 @@ def test_completion_single_element_array_prompt() -> None:
 def test_completion_multi_element_array_prompt() -> None:
     payload: Dict[str, Any] = {
         "model": TEST_MODEL,
-        "prompt": ["Tell me about Mars", "Tell me about Ceres"],
-        "max_tokens": 2000,
+        "prompt": [
+            "Tell me about Mars",
+            "Tell me about Ceres",
+            "Tell me about Jupiter",
+        ],
+        "max_tokens": 300,
    }
 
     response = _send_completion_request(payload)
+    response_data = response.json()
+
+    assert response.status_code == 200, (
+        f"Completion request failed with status "
+        f"{response.status_code}: {response.text}"
+    )
+
+    expected_choices = len(payload.get("prompt"))  # type: ignore
+    choices = len(response_data.get("choices", []))
 
-    # request should fail because we are sending multiple prompts
     assert (
-        response.status_code == 500
-    ), f"Request should fail with code 500; response:{response.text}"
+        expected_choices == choices
+    ), f"Expected {expected_choices} choices, got {choices}"
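
For reviewers, here is a minimal client-side sketch of the behaviour this diff enables. It is illustrative only and not part of the change: the base URL, the model name, and the use of the standard OpenAI-style /v1/completions route are assumptions about a particular deployment. Each prompt in the batch contributes n consecutive choices, so the originating prompt of any choice can be recovered as choice["index"] // n, mirroring the choice.index += prompt_idx * n remapping in completions_batch.

import requests

# Assumptions (hypothetical values): adjust the address and model to your deployment.
BASE_URL = "http://localhost:8000"
MODEL = "your-model-name"

payload = {
    "model": MODEL,
    "prompt": ["Tell me about Mars", "Tell me about Ceres", "Tell me about Jupiter"],
    "n": 2,  # 3 prompts x 2 choices each = 6 total choices (must not exceed MAX_TOTAL_CHOICES = 128)
    "max_tokens": 50,
}

# Send one request carrying the whole prompt batch.
response = requests.post(f"{BASE_URL}/v1/completions", json=payload, timeout=60)
response.raise_for_status()

n = payload["n"]
for choice in response.json()["choices"]:
    prompt_idx = choice["index"] // n  # undo the prompt_idx * n + choice_idx remapping
    print(f"prompt {prompt_idx}: {choice['text']!r}")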