@@ -318,11 +318,33 @@ async fn completions(
     request: Context<NvCreateCompletionRequest>,
     stream_handle: ConnectionHandle,
 ) -> Result<Response, ErrorResponse> {
+    use crate::protocols::openai::completions::get_prompt_batch_size;
+
     // return a 503 if the service is not ready
     check_ready(&state)?;
 
     validate_completion_fields_generic(&request)?;
 
+    // Detect batch prompts
+    let batch_size = get_prompt_batch_size(&request.inner.prompt);
+    let n = request.inner.n.unwrap_or(1);
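+    // OpenAI semantics: a batch of batch_size prompts with n choices each
+    // yields batch_size * n choices in total, indexed prompt_idx * n + choice_idx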
+
+    // If single prompt or single-element batch, use original flow
+    if batch_size == 1 {
+        return completions_single(state, request, stream_handle).await;
+    }
+
+    // Batch processing: handle multiple prompts
+    completions_batch(state, request, stream_handle, batch_size, n).await
+}
+
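+// A minimal sketch of what get_prompt_batch_size could look like; the real
+// helper lives in crate::protocols::openai::completions and is not shown in
+// this hunk. Assumes async-openai-style Prompt variants:
+//
+//     fn get_prompt_batch_size(prompt: &Prompt) -> usize {
+//         match prompt {
+//             // a single string or a single token array is a batch of one
+//             Prompt::String(_) | Prompt::IntegerArray(_) => 1,
+//             // a list of strings or of token arrays is a batch of its length
+//             Prompt::StringArray(texts) => texts.len(),
+//             Prompt::ArrayOfIntegerArray(lists) => lists.len(),
+//         }
+//     }
+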
+/// Handle single prompt completions (original logic)
+#[tracing::instrument(skip_all)]
+async fn completions_single(
+    state: Arc<service_v2::State>,
+    request: Context<NvCreateCompletionRequest>,
+    stream_handle: ConnectionHandle,
+) -> Result<Response, ErrorResponse> {
     let request_id = request.id().to_string();
 
     // todo - decide on default
@@ -433,6 +455,162 @@ async fn completions(
     }
 }
 
+/// Handle batch prompt completions (multiple prompts with n choices each)
+#[tracing::instrument(skip_all)]
+async fn completions_batch(
+    state: Arc<service_v2::State>,
+    request: Context<NvCreateCompletionRequest>,
+    stream_handle: ConnectionHandle,
+    batch_size: usize,
+    n: u8,
+) -> Result<Response, ErrorResponse> {
+    use crate::protocols::openai::completions::extract_single_prompt;
+    use futures::stream::{self, StreamExt};
+
+    let request_id = request.id().to_string();
+    let streaming = request.inner.stream.unwrap_or(false);
+    let model = request.inner.model.clone();
+
+    // Create http_queue_guard early - tracks time waiting to be processed
+    let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
+
+    let engine = state
+        .manager()
+        .get_completions_engine(&model)
+        .map_err(|_| ErrorMessage::model_not_found())?;
+
+    let parsing_options = state.manager().get_parsing_options(&model);
+
+    let mut response_collector = state.metrics_clone().create_response_collector(&model);
+
+    // prepare to process any annotations
+    let annotations = request.annotations();
+
+    // Create inflight_guard before calling engine to ensure errors are counted
+    let mut inflight_guard =
+        state
+            .metrics_clone()
+            .create_inflight_guard(&model, Endpoint::Completions, streaming);
+
+    // Generate streams for each prompt in the batch
+    let mut all_streams = Vec::new();
+    let mut first_ctx = None;
+
+    for prompt_idx in 0..batch_size {
+        // Extract single prompt at this index
+        let single_prompt = extract_single_prompt(&request.inner.prompt, prompt_idx);
+
+        // Create a new request with this single prompt
+        let mut single_request = request.content().clone();
+        single_request.inner.prompt = single_prompt;
+
+        // Generate unique request_id for each prompt: original_id-{prompt_idx}
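+        // e.g. request id "abc123" with a batch of three prompts yields
+        // child ids "abc123-0", "abc123-1", "abc123-2"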
+        let unique_request_id = format!("{}-{}", request.id(), prompt_idx);
+        let single_request_context = Context::with_id(single_request, unique_request_id);
+
+        // Generate stream for this prompt
+        let stream = engine
+            .generate(single_request_context)
+            .await
+            .map_err(|e| ErrorMessage::from_anyhow(e, "Failed to generate completions"))?;
+
+        // Capture context from first stream
+        if first_ctx.is_none() {
+            first_ctx = Some(stream.context());
+        }
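+        // Note: only the first prompt's context is retained; it is the handle
+        // later passed to monitor_for_disconnects for cancellation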
+
+        // Remap choice indices: choice.index += prompt_idx * n
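+        // e.g. with batch_size = 2 and n = 2, prompt 0 keeps choice indices
+        // 0..=1 while prompt 1's local indices 0..=1 become 2..=3; the u32
+        // casts below match the type of choice.index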
+        let prompt_idx_u32 = prompt_idx as u32;
+        let n_u32 = n as u32;
+        let remapped_stream = stream.map(move |mut response| {
+            if let Some(ref mut data) = response.data {
+                for choice in &mut data.inner.choices {
+                    choice.index += prompt_idx_u32 * n_u32;
+                }
+            }
+            response
+        });
+
+        all_streams.push(remapped_stream);
+    }
+
+    // Merge all streams
+    let merged_stream = stream::select_all(all_streams);
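+    // select_all polls every stream and yields items in readiness order, so
+    // chunks from different prompts interleave; consumers should key on
+    // choice.index rather than arrival order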
+
+    // capture the context to cancel the stream if the client disconnects
+    let ctx = first_ctx.expect("At least one stream should be generated");
+
+    let annotations_vec = annotations.map_or(Vec::new(), |annotations| {
+        annotations
+            .iter()
+            .filter_map(|annotation| {
+                if annotation == ANNOTATION_REQUEST_ID {
+                    Annotated::<NvCreateCompletionResponse>::from_annotation(
+                        ANNOTATION_REQUEST_ID,
+                        &request_id,
+                    )
+                    .ok()
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>()
+    });
+
+    // apply any annotations to the front of the stream
+    let merged_stream = stream::iter(annotations_vec).chain(merged_stream);
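+    // annotations (e.g. the request id) are emitted once for the whole batch,
+    // ahead of any prompt's tokens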
+
+    if streaming {
+        // For streaming, we'll drop the http_queue_guard on the first token
+        let mut http_queue_guard = Some(http_queue_guard);
+        let stream = merged_stream.map(move |response| {
+            // Calls observe_response() on each token
+            process_response_using_event_converter_and_observe_metrics(
+                EventConverter::from(response),
+                &mut response_collector,
+                &mut http_queue_guard,
+            )
+        });
+        let stream = monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
+
+        let mut sse_stream = Sse::new(stream);
+
+        if let Some(keep_alive) = state.sse_keep_alive() {
+            sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
+        }
+
+        Ok(sse_stream.into_response())
+    } else {
+        // Tap the stream to collect metrics for non-streaming requests without altering items
+        let mut http_queue_guard = Some(http_queue_guard);
+        let stream = merged_stream.inspect(move |response| {
+            // Calls observe_response() on each token - drops http_queue_guard on first token
+            process_response_and_observe_metrics(
+                response,
+                &mut response_collector,
+                &mut http_queue_guard,
+            );
+        });
+
+        let response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
+            .await
+            .map_err(|e| {
+                tracing::error!(
+                    "Failed to fold completions stream for {}: {:?}",
+                    request_id,
+                    e
+                );
+                ErrorMessage::internal_server_error(&format!(
+                    "Failed to fold completions stream for {}: {:?}",
+                    request_id, e
+                ))
+            })?;
+
+        inflight_guard.mark_ok();
+        Ok(Json(response).into_response())
+    }
+}
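+
+// A minimal sketch of what extract_single_prompt could look like; like
+// get_prompt_batch_size, the real helper lives in
+// crate::protocols::openai::completions and is not shown in this hunk:
+//
+//     fn extract_single_prompt(prompt: &Prompt, idx: usize) -> Prompt {
+//         match prompt {
+//             Prompt::StringArray(texts) => Prompt::String(texts[idx].clone()),
+//             Prompt::ArrayOfIntegerArray(lists) => Prompt::IntegerArray(lists[idx].clone()),
+//             // single prompts only ever reach idx 0; pass them through
+//             other => other.clone(),
+//         }
+//     }
+//
+// e.g. extract_single_prompt(&Prompt::StringArray(vec!["a".into(), "b".into()]), 1)
+// would return Prompt::String("b".into())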
+
 #[tracing::instrument(skip_all)]
 async fn embeddings(
     State(state): State<Arc<service_v2::State>>,