feat: Log probabilities support (with cache) #215
Conversation
Signed-off-by: Rui Vieira <[email protected]>
```go
// Check cache stats
hits, misses, hitRate := processor.GetCacheStats()
if hits != 1 {
	t.Errorf("Expected 1 cache hit, got %d", hits)
```
In all tests in this project we use the Expect function from onsi/gomega, e.g. Expect(err).NotTo(HaveOccurred()), Expect(totalBlocks).To(Equal(unusedBlocks)), etc.
Can you please change this to be consistent with the other tests?
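A sketch of the same check in the Gomega style the reviewer mentions; GetCacheStats and the expected hit count come from the hunk above, while the extra assertions on misses and hitRate are illustrative only:

```go
// Assumes the test file dot-imports Gomega, as the other tests do:
//   . "github.com/onsi/gomega"
hits, misses, hitRate := processor.GetCacheStats()
Expect(hits).To(Equal(int64(1)))
Expect(misses).To(BeNumerically(">=", int64(0)))
Expect(hitRate).To(BeNumerically(">", 0.0))
```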
```go
if i == len(genTokens)-1 && (finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason) {
	finishReasonToSend = &finishReason
}
```
This part does not change during the loop; it could be computed once before the loop starts.
Good point! Moved it outside the loop.
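For illustration, a sketch of the hoisting being discussed; the surrounding loop is hypothetical, and only the hoisted condition mirrors the hunk above:

```go
// Computed once: this condition does not depend on the loop index.
sendFinishReason := finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason

for i := range genTokens {
	var finishReasonToSend *string
	if i == len(genTokens)-1 && sendFinishReason {
		finishReasonToSend = &finishReason
	}
	// ... create and send the chunk using finishReasonToSend ...
	_ = finishReasonToSend
}
```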
@ruivieira Thank you for your contribution!
Signed-off-by: Rui Vieira <[email protected]>
```diff
 // API response, for chat completion. It sets either role, or token, or tool call info in the message.
 func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall,
-	role string, finishReason *string) openaiserverapi.CompletionRespChunk {
+	role string, finishReason *string, includeLogprobs bool, topK int) openaiserverapi.CompletionRespChunk {
```
The boolean that defines whether logprobs should be returned and the top-k value are constant during request processing; we store values of this kind in streamingContext.
If the token is empty there is no need to create the logprob-related fields; otherwise the relevant part of the response will be created.
```go
nPromptTokens:       usageData.PromptTokens,
nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(),
requestID:           req.GetRequestID(),
request:             req,
```
Please add shouldIncludeLogprobs and topK (maybe rename it to reflect its context) to streamingContext instead of the request.
(And thanks for adding the missing request ID.)
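A sketch of what this could look like; the existing streamingContext fields are elided, and the field names simply follow the ones used in this comment:

```go
type streamingContext struct {
	// ... existing fields (request id, token counts, etc.) ...

	// Constant for the whole request, so they are resolved once when the
	// context is built instead of being passed to every chunk-creation call.
	shouldIncludeLogprobs bool
	topK                  int
}
```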
```diff
 func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
 	modelName string, finishReason string, usageData *openaiserverapi.Usage) {
-	resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
+	resp := s.createCompletionResponse(reqCtx.CompletionReq, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
```
Please pass shouldIncludeLogprobs and topK instead of request
```diff
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
-func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
+func (s *VllmSimulator) createCompletionResponse(req openaiserverapi.CompletionRequest, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
```
Please pass shouldIncludeLogprobs and topK instead of request
```go
type LogprobData struct {
	MainLogprob float64                                 `json:"main_logprob"`
	TopLogprobs []openaiserverapi.ChatCompletionLogProb `json:"top_logprobs"`
}
```
Why do the struct and the fields have to be public? And why are the json tags needed?
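A sketch of the unexported alternative, assuming the struct is only used inside the package and is never marshalled to JSON directly:

```go
// logprobData is internal to the processor, so the type and its fields can be
// unexported and the json tags dropped.
type logprobData struct {
	mainLogprob float64
	topLogprobs []openaiserverapi.ChatCompletionLogProb
}
```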
```go
hits2, misses2, _ := processor.GetCacheStats()
Expect(hits2).To(BeNumerically(">=", 0))
Expect(misses2).To(BeNumerically(">=", 3))
}
```
Could you please add tests for the simulator with logprobs:
completions and chat completions requests with logprobs, with and without streaming, and check the responses?
```go
cacheMutex sync.RWMutex

// cacheHits and cacheMisses for metrics
cacheHits int64
```
cacheHits and cacheMisses should be updated under the protection of the mutex, or be atomic counters (from the sync/atomic package).
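For example, with counters from sync/atomic (a sketch; only the counters and a hypothetical recordLookup helper are shown):

```go
import "sync/atomic"

type LogprobsProcessor struct {
	// ... other fields ...
	cacheHits   atomic.Int64
	cacheMisses atomic.Int64
}

// recordLookup is a hypothetical helper; the point is that Add is safe
// to call without holding cacheMutex.
func (lp *LogprobsProcessor) recordLookup(hit bool) {
	if hit {
		lp.cacheHits.Add(1)
	} else {
		lp.cacheMisses.Add(1)
	}
}
```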
```go
}

// generateDeterministicLogprobs creates logprobs with deterministic values based on token content
// This follows vLLM's approach of consistent logprobs for the same token in similar contexts
```
@ruivieira can you please explain how context is taken into consideration here? As I understand it, a token should give the same logprob only in the same/similar context. What does context mean here? Isn't it the previous tokens? They are not passed to this function.
```go
for k := range lp.tokenCache {
	delete(lp.tokenCache, k)
	break
}
```
Iteration over a map is not in insertion order; if we want to remove the oldest item, we need to store the insertion time or the latest usage time.
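One way to do this, sketched with hypothetical types and names: keep a last-used timestamp per cache entry and evict the entry with the oldest one.

```go
import "time"

type cacheEntry struct {
	value    any       // the cached logprob data
	lastUsed time.Time // set on insert and refreshed on every cache hit
}

// evictOldest removes the least recently used entry; the caller must hold the write lock.
func evictOldest(cache map[string]cacheEntry) {
	var oldestKey string
	var oldest time.Time
	found := false
	for k, e := range cache {
		if !found || e.lastUsed.Before(oldest) {
			oldestKey, oldest, found = k, e.lastUsed, true
		}
	}
	if found {
		delete(cache, oldestKey)
	}
}
```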
```go
}

// ClearCache clears the logprobs cache
func (lp *LogprobsProcessor) ClearCache() {
```
Not used
```go
}

// GetCacheStats returns cache performance statistics
func (lp *LogprobsProcessor) GetCacheStats() (hits, misses int64, hitRate float64) {
```
Consider not adding a function that is used only in tests to this type; we could put it in an additional file, or choose not to use it in the test.
@ruivieira Please see some more comments I added. General question:
If the purpose is only to support the API without concern for the actual returned content, why not generate random responses (both tokens and probabilities) and keep the implementation minimal, without involving cache, hash calculations, or other overhead?
@mayabar That's a good point. My original implementation did take that approach, but ultimately I pushed this one mainly for reproducibility and consistent (even if simulated) results. I agree that, purely for API coverage, generating random responses is enough and avoids the overhead. I'll open a new PR with a new implementation and, if preferred, we can close this one?
@ruivieira I think continuing with this PR is the better way; you already have a lot of relevant code here. You can just remove all the cache-related code from the logprobs_processor and use random response generation.
Let's continue with your new PR #221.
Replaced by #221.
This PR is a proposal to add log probabilities (logprobs) support to llm-d-inference-sim following vLLM's architecture.
vLLM-Style LogprobsProcessor
Added a new LogprobsProcessor that follows vLLM's implementation (the same tokens always produce the same logprobs).

API coverage

- /v1/completions: logprobs parameter (0-5)
- /v1/chat/completions: logprobs boolean + top_logprobs integer

Refer to #213
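For reference, a sketch of the request parameters listed above, built as JSON payloads; the model name, prompt, and message content are placeholders, and the field names follow the OpenAI-compatible API shape described in this PR:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// /v1/completions: logprobs is an integer in the range 0-5.
	completions := map[string]any{
		"model":    "example-model", // placeholder
		"prompt":   "Hello",
		"logprobs": 3,
	}

	// /v1/chat/completions: logprobs is a boolean, top_logprobs an integer.
	chat := map[string]any{
		"model":        "example-model", // placeholder
		"messages":     []map[string]string{{"role": "user", "content": "Hello"}},
		"logprobs":     true,
		"top_logprobs": 3,
	}

	for _, body := range []map[string]any{completions, chat} {
		b, _ := json.MarshalIndent(body, "", "  ")
		fmt.Println(string(b))
	}
}
```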