178 changes: 178 additions & 0 deletions lib/llm/src/http/service/openai.rs
@@ -318,11 +318,33 @@ async fn completions(
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
use crate::protocols::openai::completions::get_prompt_batch_size;

// return a 503 if the service is not ready
check_ready(&state)?;

validate_completion_fields_generic(&request)?;

// Detect batch prompts
let batch_size = get_prompt_batch_size(&request.inner.prompt);
let n = request.inner.n.unwrap_or(1);
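// Illustrative prompt shapes (assumed request bodies, not part of this change):
//   "prompt": "hello"            -> batch_size 1, handled by the single-prompt flow below
//   "prompt": ["hello", "world"] -> batch_size 2, handled by the batch flow below
//   "prompt": [[1, 2, 3]]        -> one token-array prompt, batch_size 1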

// If single prompt or single-element batch, use original flow
if batch_size == 1 {
return completions_single(state, request, stream_handle).await;
}

// Batch processing: handle multiple prompts
completions_batch(state, request, stream_handle, batch_size, n).await
}

/// Handle single prompt completions (original logic)
#[tracing::instrument(skip_all)]
async fn completions_single(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
) -> Result<Response, ErrorResponse> {
let request_id = request.id().to_string();

// todo - decide on default
@@ -433,6 +455,162 @@ async fn completions(
}
}

/// Handle batch prompt completions (multiple prompts with n choices each)
#[tracing::instrument(skip_all)]
async fn completions_batch(
state: Arc<service_v2::State>,
request: Context<NvCreateCompletionRequest>,
stream_handle: ConnectionHandle,
batch_size: usize,
n: u8,
) -> Result<Response, ErrorResponse> {
use crate::protocols::openai::completions::extract_single_prompt;
use futures::stream::{self, StreamExt};

let request_id = request.id().to_string();
let streaming = request.inner.stream.unwrap_or(false);
let model = request.inner.model.clone();

// Create http_queue_guard early - tracks time waiting to be processed
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);

let engine = state
.manager()
.get_completions_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;

let parsing_options = state.manager().get_parsing_options(&model);

let mut response_collector = state.metrics_clone().create_response_collector(&model);

// prepare to process any annotations
let annotations = request.annotations();

// Create inflight_guard before calling engine to ensure errors are counted
let mut inflight_guard =
state
.metrics_clone()
.create_inflight_guard(&model, Endpoint::Completions, streaming);

// Generate streams for each prompt in the batch
let mut all_streams = Vec::new();
let mut first_ctx = None;

for prompt_idx in 0..batch_size {
// Extract single prompt at this index
let single_prompt = extract_single_prompt(&request.inner.prompt, prompt_idx);

// Create a new request with this single prompt
let mut single_request = request.content().clone();
single_request.inner.prompt = single_prompt;

// Generate unique request_id for each prompt: original_id-{prompt_idx}
let unique_request_id = format!("{}-{}", request.id(), prompt_idx);
let single_request_context = Context::with_id(single_request, unique_request_id);

// Generate stream for this prompt
let stream = engine
.generate(single_request_context)
.await
.map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;

// Capture context from first stream
if first_ctx.is_none() {
first_ctx = Some(stream.context());
}

// Remap choice indices: choice.index += prompt_idx * n
let prompt_idx_u32 = prompt_idx as u32;
let n_u32 = n as u32;
let remapped_stream = stream.map(move |mut response| {
if let Some(ref mut data) = response.data {
for choice in &mut data.inner.choices {
choice.index += prompt_idx_u32 * n_u32;
}
}
response
});

all_streams.push(remapped_stream);
}
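// Worked example of the remapping above (assumed values): with batch_size = 3 and n = 2,
// prompt 0 keeps choice indices 0 and 1, prompt 1 is shifted to 2 and 3, and prompt 2
// to 4 and 5, so every choice in the merged output carries a globally unique index.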

// Merge all streams
let merged_stream = stream::select_all(all_streams);
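// Note: select_all yields chunks from the per-prompt streams in arrival order; since the
// indices were remapped above, a chunk's prompt is `choice.index / n` and its choice
// within that prompt is `choice.index % n`.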

// capture the first stream's context so the merged stream can be cancelled if the client disconnects
let ctx = first_ctx.expect("At least one stream should be generated");

let annotations_vec = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::<NvCreateCompletionResponse>::from_annotation(
ANNOTATION_REQUEST_ID,
&request_id,
)
.ok()
} else {
None
}
})
.collect::<Vec<_>>()
});

// apply any annotations to the front of the stream
let merged_stream = stream::iter(annotations_vec).chain(merged_stream);

if streaming {
// For streaming, we'll drop the http_queue_guard on the first token
let mut http_queue_guard = Some(http_queue_guard);
let stream = merged_stream.map(move |response| {
// Calls observe_response() on each token
process_response_using_event_converter_and_observe_metrics(
EventConverter::from(response),
&mut response_collector,
&mut http_queue_guard,
)
});
let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);

let mut sse_stream = Sse::new(stream);

if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
}

Ok(sse_stream.into_response())
} else {
// Tap the stream to collect metrics for non-streaming requests without altering items
let mut http_queue_guard = Some(http_queue_guard);
let stream = merged_stream.inspect(move |response| {
// Calls observe_response() on each token - drops http_queue_guard on first token
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});

let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
.await
.map_err(|e| {
tracing::error!(
"Failed to fold completions stream for {}: {:?}",
request_id,
e
);
ErrorMessage::internal_server_error(&format!(
"Failed to fold completions stream for {}: {:?}",
request_id, e
))
})?;

inflight_guard.mark_ok();
Ok(Json(response).into_response())
}
}

#[tracing::instrument(skip_all)]
async fn embeddings(
State(state): State<Arc<service_v2::State>>,
39 changes: 38 additions & 1 deletion lib/llm/src/protocols/openai/completions.rs
@@ -78,6 +78,39 @@ pub fn prompt_to_string(prompt: &dynamo_async_openai::types::Prompt) -> String {
}
}

/// Get the batch size of a prompt: 1 for a single string or a single token array, or the
/// array length for batch prompts (`StringArray` / `ArrayOfIntegerArray`)
pub fn get_prompt_batch_size(prompt: &dynamo_async_openai::types::Prompt) -> usize {
match prompt {
dynamo_async_openai::types::Prompt::String(_) => 1,
dynamo_async_openai::types::Prompt::IntegerArray(_) => 1,
dynamo_async_openai::types::Prompt::StringArray(arr) => arr.len(),
dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => arr.len(),
}
}
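// Examples (illustrative, using the `Prompt` variants matched above):
//   Prompt::String("hi".into())                            -> 1
//   Prompt::IntegerArray(vec![1, 2, 3])                     -> 1 (one tokenized prompt)
//   Prompt::StringArray(vec!["a".into(), "b".into()])       -> 2
//   Prompt::ArrayOfIntegerArray(vec![vec![1], vec![2, 3]])  -> 2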

/// Extract a single prompt from a batch at the given index.
/// For single prompts (`String` and `IntegerArray`), returns a clone regardless of index.
/// For batch prompts, returns the element at the specified index as a single-prompt variant.
/// Panics if `index` is out of bounds for a batch prompt.
pub fn extract_single_prompt(
prompt: &dynamo_async_openai::types::Prompt,
index: usize,
) -> dynamo_async_openai::types::Prompt {
match prompt {
dynamo_async_openai::types::Prompt::String(s) => {
dynamo_async_openai::types::Prompt::String(s.clone())
}
dynamo_async_openai::types::Prompt::IntegerArray(arr) => {
dynamo_async_openai::types::Prompt::IntegerArray(arr.clone())
}
dynamo_async_openai::types::Prompt::StringArray(arr) => {
dynamo_async_openai::types::Prompt::String(arr[index].clone())
}
dynamo_async_openai::types::Prompt::ArrayOfIntegerArray(arr) => {
dynamo_async_openai::types::Prompt::IntegerArray(arr[index].clone())
}
}
}
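// Example (illustrative): for Prompt::StringArray(vec!["a".into(), "b".into()]), index 1
// yields Prompt::String("b"); for Prompt::ArrayOfIntegerArray(vec![vec![1], vec![2, 3]]),
// index 0 yields Prompt::IntegerArray(vec![1]).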

impl NvExtProvider for NvCreateCompletionRequest {
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
@@ -403,7 +436,11 @@ impl ValidateRequest for NvCreateCompletionRequest {
validate::validate_top_k(self.get_top_k())?;
// Cross-field validation
validate::validate_n_with_temperature(self.inner.n, self.inner.temperature)?;

// total choices validation for completions batch requests
validate::validate_total_choices(
get_prompt_batch_size(&self.inner.prompt),
self.inner.n.unwrap_or(1),
)?;
Ok(())
}
}
18 changes: 18 additions & 0 deletions lib/llm/src/protocols/openai/validate.rs
@@ -66,6 +66,9 @@ pub const MAX_N: u8 = 128;
/// Allowed range of values for `n` (number of choices)
pub const N_RANGE: (u8, u8) = (MIN_N, MAX_N);

/// Maximum allowed total number of choices (batch_size × n)
pub const MAX_TOTAL_CHOICES: usize = 128;

/// Minimum allowed value for OpenAI's `logit_bias` values
pub const MIN_LOGIT_BIAS: f32 = -100.0;
/// Maximum allowed value for OpenAI's `logit_bias` values
@@ -261,6 +264,21 @@ pub fn validate_n(n: Option<u8>) -> Result<(), anyhow::Error> {
Ok(())
}

/// Validates that the total number of choices (batch_size × n) does not exceed `MAX_TOTAL_CHOICES`
pub fn validate_total_choices(batch_size: usize, n: u8) -> Result<(), anyhow::Error> {
let total_choices = batch_size * (n as usize);
if total_choices > MAX_TOTAL_CHOICES {
anyhow::bail!(
"Total choices (batch_size × n = {} × {} = {}) exceeds maximum of {}",
batch_size,
n,
total_choices,
MAX_TOTAL_CHOICES
);
}
Ok(())
}
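// Boundary sketch (assumed numbers): 32 prompts with n = 4 gives 128 total choices and
// passes; 33 prompts with n = 4 gives 132, which exceeds MAX_TOTAL_CHOICES (128) and is
// rejected with the message above.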

/// Validates n and temperature interaction
/// When n > 1, temperature must be > 0 to ensure diverse outputs
pub fn validate_n_with_temperature(