Skip to content

Commit 2412c8a

Browse files
pkg: add ability to provide number of choices sampled.
To support the N parameter, every choice in the response carries the same text content; this also simplifies testing and response construction. Signed-off-by: Madhav Jivrajani <[email protected]>
1 parent 5c58b12 commit 2412c8a

File tree

4 files changed

+231
-96
lines changed

4 files changed

+231
-96
lines changed

pkg/llm-d-inference-sim/simulator.go

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,9 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
570570
// usageData - usage (tokens statistics) for this response
571571
// modelName - display name returned to the client and used in metrics. It is either the first alias
572572
// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
573+
// numCompletionOptions - number of choices to return in the response.
573574
func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
574-
finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
575+
finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool, numCompletionOptions *int) openaiserverapi.CompletionResponse {
575576
baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
576577
time.Now().Unix(), modelName, usageData)
577578

@@ -588,8 +589,6 @@ func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion
588589
baseResp.KVParams.TPSize = 1
589590
}
590591

591-
baseChoice := openaiserverapi.CreateBaseResponseChoice(0, finishReason)
592-
593592
respText := strings.Join(respTokens, "")
594593
if isChatCompletion {
595594
baseResp.Object = chatCompletionObject
@@ -601,41 +600,51 @@ func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion
601600
message.Content = openaiserverapi.Content{Raw: respText}
602601
}
603602

604-
choice := openaiserverapi.CreateChatRespChoice(baseChoice, message)
605-
606-
// Generate logprobs if requested
607-
if logprobs != nil && toolCalls == nil {
608-
if logprobsData := common.GenerateChatLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Content) > 0 {
609-
choice.Logprobs = logprobsData
603+
// Generate numCompletionOptions choices in the response.
604+
choices := []openaiserverapi.ChatRespChoice{}
605+
for i := range *numCompletionOptions {
606+
baseChoice := openaiserverapi.CreateBaseResponseChoice(i, finishReason)
607+
choice := openaiserverapi.CreateChatRespChoice(baseChoice, message)
608+
// Generate logprobs if requested
609+
if logprobs != nil && toolCalls == nil {
610+
if logprobsData := common.GenerateChatLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Content) > 0 {
611+
choice.Logprobs = logprobsData
612+
} else {
613+
// Set to nil if generation failed or content is empty
614+
choice.Logprobs = nil
615+
}
610616
} else {
611-
// Set to nil if generation failed or content is empty
617+
// Explicitly ensure logprobs is nil when not requested
612618
choice.Logprobs = nil
613619
}
614-
} else {
615-
// Explicitly ensure logprobs is nil when not requested
616-
choice.Logprobs = nil
620+
choices = append(choices, choice)
617621
}
618622

619-
return openaiserverapi.CreateChatCompletionResponse(baseResp, []openaiserverapi.ChatRespChoice{choice})
623+
return openaiserverapi.CreateChatCompletionResponse(baseResp, choices)
620624
}
621625

622-
choice := openaiserverapi.CreateTextRespChoice(baseChoice, respText)
623-
624-
// Generate logprobs if requested for text completion
625-
if logprobs != nil && *logprobs > 0 {
626-
if logprobsData := common.GenerateTextLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Tokens) > 0 {
627-
choice.Logprobs = logprobsData
626+
// Generate numCompletionOptions choices in the response.
627+
choices := []openaiserverapi.TextRespChoice{}
628+
for i := range *numCompletionOptions {
629+
baseChoice := openaiserverapi.CreateBaseResponseChoice(i, finishReason)
630+
choice := openaiserverapi.CreateTextRespChoice(baseChoice, respText)
631+
// Generate logprobs if requested for text completion
632+
if logprobs != nil && *logprobs > 0 {
633+
if logprobsData := common.GenerateTextLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Tokens) > 0 {
634+
choice.Logprobs = logprobsData
635+
} else {
636+
// Set to nil if generation failed or tokens is empty
637+
choice.Logprobs = nil
638+
}
628639
} else {
629-
// Set to nil if generation failed or tokens is empty
640+
// Explicitly ensure logprobs is nil when not requested
630641
choice.Logprobs = nil
631642
}
632-
} else {
633-
// Explicitly ensure logprobs is nil when not requested
634-
choice.Logprobs = nil
643+
choices = append(choices, choice)
635644
}
636645

637646
baseResp.Object = textCompletionObject
638-
return openaiserverapi.CreateTextCompletionResponse(baseResp, []openaiserverapi.TextRespChoice{choice})
647+
return openaiserverapi.CreateTextCompletionResponse(baseResp, choices)
639648
}
640649

641650
// sendResponse sends response for completion API, supports both completions (text and chat)
@@ -655,7 +664,7 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
655664
}
656665

657666
resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
658-
reqCtx.CompletionReq.IsDoRemoteDecode())
667+
reqCtx.CompletionReq.IsDoRemoteDecode(), reqCtx.CompletionReq.GetNumCompletionOptions())
659668

660669
// calculate how long to wait before returning the response, time is based on number of tokens
661670
nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
@@ -668,7 +677,11 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
668677
common.WriteToChannel(s.metrics.reqPrefillTimeChan, time.Since(startPrefill).Seconds(), s.logger, "metrics.reqPrefillTimeChan")
669678

670679
startDecode := time.Now()
671-
for range usageData.CompletionTokens - 1 {
680+
// CompletionTokens accounts for all tokens across all choices in the response.
681+
// Each choice is going to have the same set of tokens from the simulator, therefore
682+
// 'preferred' choice is just the requisite share of tokens from the total CompletionTokens.
683+
actualComplCount := usageData.CompletionTokens / *reqCtx.CompletionReq.GetNumCompletionOptions()
684+
for range actualComplCount - 1 {
672685
perTokenLatency := s.getInterTokenLatency()
673686
time.Sleep(time.Duration(perTokenLatency) * time.Millisecond)
674687

0 commit comments

Comments
 (0)