From 21372357bebd619a7b75009f4aa0f86b2a436a19 Mon Sep 17 00:00:00 2001 From: Rui Vieira Date: Thu, 16 Oct 2025 11:12:52 +0100 Subject: [PATCH 1/4] feat: Log probabilities support Signed-off-by: Rui Vieira --- pkg/llm-d-inference-sim/logprobs_processor.go | 233 ++++++++++++++++++ .../logprobs_processor_test.go | 216 ++++++++++++++++ pkg/llm-d-inference-sim/simulator.go | 67 +++-- pkg/llm-d-inference-sim/streaming.go | 39 ++- pkg/openai-server-api/request.go | 36 +++ pkg/openai-server-api/response.go | 46 ++++ 6 files changed, 619 insertions(+), 18 deletions(-) create mode 100644 pkg/llm-d-inference-sim/logprobs_processor.go create mode 100644 pkg/llm-d-inference-sim/logprobs_processor_test.go diff --git a/pkg/llm-d-inference-sim/logprobs_processor.go b/pkg/llm-d-inference-sim/logprobs_processor.go new file mode 100644 index 00000000..df011fe0 --- /dev/null +++ b/pkg/llm-d-inference-sim/logprobs_processor.go @@ -0,0 +1,233 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "crypto/md5" + "fmt" + "sync" + + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" +) + +// LogprobData represents cached logprob information for a token +type LogprobData struct { + MainLogprob float64 `json:"main_logprob"` + TopLogprobs []openaiserverapi.ChatCompletionLogProb `json:"top_logprobs"` +} + +// LogprobsProcessor handles logprobs generation and caching following vLLM architecture +type LogprobsProcessor struct { + // tokenCache caches logprobs by token content and topK to avoid recomputation + tokenCache map[string]*LogprobData + cacheMutex sync.RWMutex + + // cacheHits and cacheMisses for metrics + cacheHits int64 + cacheMisses int64 + + // maxCacheSize limits memory usage + maxCacheSize int +} + +// NewLogprobsProcessor creates a new LogprobsProcessor following vLLM design patterns +func NewLogprobsProcessor(maxCacheSize int) *LogprobsProcessor { + if maxCacheSize <= 0 { + maxCacheSize = 10000 // Default cache size + } + + return &LogprobsProcessor{ + tokenCache: make(map[string]*LogprobData), + maxCacheSize: maxCacheSize, + } +} + +// generateCacheKey creates a deterministic key for caching based on token and topK +func (lp *LogprobsProcessor) generateCacheKey(token string, topK int) string { + return fmt.Sprintf("%s:%d", token, topK) +} + +// generateDeterministicLogprobs creates logprobs with deterministic values based on token content +// This follows vLLM's approach of consistent logprobs for the same token in similar contexts +func (lp *LogprobsProcessor) generateDeterministicLogprobs(token string, topK int) *LogprobData { + // Use token content to seed deterministic generation (similar to vLLM's approach) + hash := md5.Sum([]byte(token)) + seed := int64(hash[0])<<24 | int64(hash[1])<<16 | int64(hash[2])<<8 | int64(hash[3]) + + // Generate main logprob deterministically based on token + // Real logprobs are typically negative, with values closer to 0 being more likely + mainLogprob := -0.1 - 
(float64(seed%2000) / 1000.0) // Range: -0.1 to -2.1 + + if topK <= 0 { + return &LogprobData{ + MainLogprob: mainLogprob, + TopLogprobs: nil, + } + } + + // Generate top-k alternatives deterministically + topLogprobs := make([]openaiserverapi.ChatCompletionLogProb, 0, topK) + for i := 0; i < topK; i++ { + // Generate deterministic alternative token + altToken := fmt.Sprintf("alt_%d_%x", i, hash[i%4]) + + // Each alternative gets progressively lower probability + altLogprob := mainLogprob - (float64(i+1) * (0.5 + float64((seed+int64(i))%1500)/1000.0)) + + // Convert token to bytes + bytes := make([]int, len(altToken)) + for j, b := range []byte(altToken) { + bytes[j] = int(b) + } + + topLogprobs = append(topLogprobs, openaiserverapi.ChatCompletionLogProb{ + Token: altToken, + Logprob: altLogprob, + Bytes: bytes, + }) + } + + return &LogprobData{ + MainLogprob: mainLogprob, + TopLogprobs: topLogprobs, + } +} + +// GetLogprobs returns logprobs for a token, using cache when possible +func (lp *LogprobsProcessor) GetLogprobs(token string, topK int) (float64, []openaiserverapi.ChatCompletionLogProb) { + cacheKey := lp.generateCacheKey(token, topK) + + // Check cache first + lp.cacheMutex.RLock() + if cached, exists := lp.tokenCache[cacheKey]; exists { + lp.cacheMutex.RUnlock() + lp.cacheHits++ + return cached.MainLogprob, cached.TopLogprobs + } + lp.cacheMutex.RUnlock() + + // Cache miss - generate new logprobs + lp.cacheMisses++ + logprobData := lp.generateDeterministicLogprobs(token, topK) + + // Store in cache (with size limit) + lp.cacheMutex.Lock() + if len(lp.tokenCache) >= lp.maxCacheSize { + // Simple eviction: remove oldest entry + // In production, this could use LRU or other strategies + for k := range lp.tokenCache { + delete(lp.tokenCache, k) + break + } + } + lp.tokenCache[cacheKey] = logprobData + lp.cacheMutex.Unlock() + + return logprobData.MainLogprob, logprobData.TopLogprobs +} + +// ProcessChatLogprobs creates logprobs data for chat completions following vLLM patterns +func (lp *LogprobsProcessor) ProcessChatLogprobs(tokens []string, topK int) *openaiserverapi.ChatCompletionLogProbs { + if len(tokens) == 0 { + return nil + } + + logprobs := &openaiserverapi.ChatCompletionLogProbs{ + Content: make([]openaiserverapi.ChatCompletionLogProbsContent, 0, len(tokens)), + } + + for _, token := range tokens { + mainLogprob, topLps := lp.GetLogprobs(token, topK) + + // Convert token to bytes + bytes := make([]int, len(token)) + for i, b := range []byte(token) { + bytes[i] = int(b) + } + + logprobs.Content = append(logprobs.Content, openaiserverapi.ChatCompletionLogProbsContent{ + Token: token, + Logprob: mainLogprob, + Bytes: bytes, + TopLogprobs: topLps, + }) + } + + return logprobs +} + +// ProcessTextLogprobs creates logprobs data for text completions following vLLM patterns +func (lp *LogprobsProcessor) ProcessTextLogprobs(tokens []string, topK int) *openaiserverapi.CompletionLogProbs { + if len(tokens) == 0 { + return nil + } + + logprobs := &openaiserverapi.CompletionLogProbs{ + TextOffset: make([]int, 0, len(tokens)), + TokenLogprobs: make([]float64, 0, len(tokens)), + Tokens: make([]string, 0, len(tokens)), + } + + if topK > 0 { + logprobs.TopLogprobs = make([]map[string]float64, 0, len(tokens)) + } + + textOffset := 0 + for _, token := range tokens { + mainLogprob, topLps := lp.GetLogprobs(token, topK) + + logprobs.TextOffset = append(logprobs.TextOffset, textOffset) + logprobs.TokenLogprobs = append(logprobs.TokenLogprobs, mainLogprob) + logprobs.Tokens = 
append(logprobs.Tokens, token) + + if topK > 0 { + topMap := make(map[string]float64, len(topLps)) + for _, lp := range topLps { + topMap[lp.Token] = lp.Logprob + } + logprobs.TopLogprobs = append(logprobs.TopLogprobs, topMap) + } + + textOffset += len(token) + } + + return logprobs +} + +// GetCacheStats returns cache performance statistics +func (lp *LogprobsProcessor) GetCacheStats() (hits, misses int64, hitRate float64) { + lp.cacheMutex.RLock() + defer lp.cacheMutex.RUnlock() + + total := lp.cacheHits + lp.cacheMisses + hitRate = 0.0 + if total > 0 { + hitRate = float64(lp.cacheHits) / float64(total) + } + + return lp.cacheHits, lp.cacheMisses, hitRate +} + +// ClearCache clears the logprobs cache +func (lp *LogprobsProcessor) ClearCache() { + lp.cacheMutex.Lock() + defer lp.cacheMutex.Unlock() + + lp.tokenCache = make(map[string]*LogprobData) + lp.cacheHits = 0 + lp.cacheMisses = 0 +} \ No newline at end of file diff --git a/pkg/llm-d-inference-sim/logprobs_processor_test.go b/pkg/llm-d-inference-sim/logprobs_processor_test.go new file mode 100644 index 00000000..e1438b34 --- /dev/null +++ b/pkg/llm-d-inference-sim/logprobs_processor_test.go @@ -0,0 +1,216 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "testing" +) + +func TestLogprobsProcessor_Caching(t *testing.T) { + processor := NewLogprobsProcessor(100) + + // Test that same token generates same logprobs (deterministic) + token := "hello" + topK := 3 + + logprob1, topLogprobs1 := processor.GetLogprobs(token, topK) + logprob2, topLogprobs2 := processor.GetLogprobs(token, topK) + + // Should be identical (deterministic) + if logprob1 != logprob2 { + t.Errorf("Expected same logprob for same token, got %.4f and %.4f", logprob1, logprob2) + } + if len(topLogprobs1) != len(topLogprobs2) { + t.Errorf("Expected same topLogprobs length, got %d and %d", len(topLogprobs1), len(topLogprobs2)) + } + if len(topLogprobs1) != topK { + t.Errorf("Expected topLogprobs length %d, got %d", topK, len(topLogprobs1)) + } + + // Check cache stats + hits, misses, hitRate := processor.GetCacheStats() + if hits != 1 { + t.Errorf("Expected 1 cache hit, got %d", hits) + } + if misses != 1 { + t.Errorf("Expected 1 cache miss, got %d", misses) + } + if hitRate != 0.5 { + t.Errorf("Expected 50%% hit rate, got %.2f", hitRate) + } +} + +func TestLogprobsProcessor_DifferentTokens(t *testing.T) { + processor := NewLogprobsProcessor(100) + + // Test that different tokens generate different logprobs + logprob1, _ := processor.GetLogprobs("hello", 2) + logprob2, _ := processor.GetLogprobs("world", 2) + + if logprob1 == logprob2 { + t.Errorf("Different tokens should have different logprobs, both got %.4f", logprob1) + } +} + +func TestLogprobsProcessor_DifferentTopK(t *testing.T) { + processor := NewLogprobsProcessor(100) + + // Test that same token with different topK generates different results + token := "test" + + _, topLogprobs1 := processor.GetLogprobs(token, 2) + _, topLogprobs2 := 
processor.GetLogprobs(token, 5) + + if len(topLogprobs1) != 2 { + t.Errorf("Expected 2 top logprobs, got %d", len(topLogprobs1)) + } + if len(topLogprobs2) != 5 { + t.Errorf("Expected 5 top logprobs, got %d", len(topLogprobs2)) + } +} + +func TestLogprobsProcessor_ChatLogprobs(t *testing.T) { + processor := NewLogprobsProcessor(100) + + tokens := []string{"Hello", "world", "!"} + topK := 3 + + logprobs := processor.ProcessChatLogprobs(tokens, topK) + + if logprobs == nil { + t.Fatal("Expected non-nil chat logprobs") + } + if len(logprobs.Content) != len(tokens) { + t.Errorf("Expected %d content items, got %d", len(tokens), len(logprobs.Content)) + } + + for i, content := range logprobs.Content { + if content.Token != tokens[i] { + t.Errorf("Expected token %s at index %d, got %s", tokens[i], i, content.Token) + } + if content.Logprob >= 0 { + t.Errorf("Expected negative logprob, got %.4f", content.Logprob) + } + if len(content.TopLogprobs) != topK { + t.Errorf("Expected %d top logprobs, got %d", topK, len(content.TopLogprobs)) + } + if content.Bytes == nil { + t.Error("Expected non-nil bytes") + } + } +} + +func TestLogprobsProcessor_TextLogprobs(t *testing.T) { + processor := NewLogprobsProcessor(100) + + tokens := []string{"Hello", "world"} + topK := 2 + + logprobs := processor.ProcessTextLogprobs(tokens, topK) + + if logprobs == nil { + t.Fatal("Expected non-nil text logprobs") + } + if len(logprobs.Tokens) != len(tokens) { + t.Errorf("Expected %d tokens, got %d", len(tokens), len(logprobs.Tokens)) + } + if len(logprobs.TokenLogprobs) != len(tokens) { + t.Errorf("Expected %d token logprobs, got %d", len(tokens), len(logprobs.TokenLogprobs)) + } + if len(logprobs.TextOffset) != len(tokens) { + t.Errorf("Expected %d text offsets, got %d", len(tokens), len(logprobs.TextOffset)) + } + if len(logprobs.TopLogprobs) != len(tokens) { + t.Errorf("Expected %d top logprobs arrays, got %d", len(tokens), len(logprobs.TopLogprobs)) + } + + // Check text offsets are cumulative + expectedOffset := 0 + for i, token := range tokens { + if logprobs.TextOffset[i] != expectedOffset { + t.Errorf("Expected offset %d at index %d, got %d", expectedOffset, i, logprobs.TextOffset[i]) + } + if logprobs.Tokens[i] != token { + t.Errorf("Expected token %s at index %d, got %s", token, i, logprobs.Tokens[i]) + } + if logprobs.TokenLogprobs[i] >= 0 { + t.Errorf("Expected negative logprob at index %d, got %.4f", i, logprobs.TokenLogprobs[i]) + } + if len(logprobs.TopLogprobs[i]) != topK { + t.Errorf("Expected %d top logprobs at index %d, got %d", topK, i, len(logprobs.TopLogprobs[i])) + } + expectedOffset += len(token) + } +} + +func TestLogprobsProcessor_EmptyTokens(t *testing.T) { + processor := NewLogprobsProcessor(100) + + // Test empty token lists + chatLogprobs := processor.ProcessChatLogprobs([]string{}, 3) + textLogprobs := processor.ProcessTextLogprobs([]string{}, 3) + + if chatLogprobs != nil { + t.Error("Expected nil chat logprobs for empty tokens") + } + if textLogprobs != nil { + t.Error("Expected nil text logprobs for empty tokens") + } +} + +func TestLogprobsProcessor_ZeroTopK(t *testing.T) { + processor := NewLogprobsProcessor(100) + + logprob, topLogprobs := processor.GetLogprobs("test", 0) + + if logprob >= 0 { + t.Errorf("Expected negative logprob, got %.4f", logprob) + } + if topLogprobs != nil { + t.Error("Expected nil top logprobs for topK=0") + } +} + +func TestLogprobsProcessor_CacheEviction(t *testing.T) { + // Test with very small cache size to trigger eviction + processor := 
NewLogprobsProcessor(2) + + // Fill cache beyond capacity + processor.GetLogprobs("token1", 1) + processor.GetLogprobs("token2", 1) + processor.GetLogprobs("token3", 1) // Should trigger eviction + + hits, misses, _ := processor.GetCacheStats() + if hits != 0 { + t.Errorf("Expected 0 cache hits, got %d", hits) + } + if misses != 3 { + t.Errorf("Expected 3 cache misses, got %d", misses) + } + + // Access one of the earlier tokens - may or may not be in cache due to eviction + processor.GetLogprobs("token1", 1) + + // Cache should be working (some entries may have been evicted) + hits2, misses2, _ := processor.GetCacheStats() + if hits2 < 0 { + t.Errorf("Expected non-negative cache hits, got %d", hits2) + } + if misses2 < 3 { + t.Errorf("Expected at least 3 cache misses, got %d", misses2) + } +} \ No newline at end of file diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index e5d70ede..8604abd8 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -118,6 +118,8 @@ type VllmSimulator struct { tokenizer tokenization.Tokenizer // dataset is used for token generation in responses dataset dataset.Dataset + // logprobsProcessor handles logprobs generation and caching following vLLM architecture + logprobsProcessor *LogprobsProcessor } // New creates a new VllmSimulator instance with the given logger @@ -128,16 +130,17 @@ func New(logger logr.Logger) (*VllmSimulator, error) { } return &VllmSimulator{ - logger: logger, - reqChan: make(chan *openaiserverapi.CompletionReqCtx, maxNumberOfRequests), - toolsValidator: toolsValidator, - kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration - namespace: os.Getenv(podNsEnv), - pod: os.Getenv(podNameEnv), - runReqChan: make(chan int64, maxNumberOfRequests), - waitingReqChan: make(chan int64, maxNumberOfRequests), - lorasChan: make(chan loraUsage, maxNumberOfRequests), - kvCacheUsageChan: make(chan float64, maxNumberOfRequests), + logger: logger, + reqChan: make(chan *openaiserverapi.CompletionReqCtx, maxNumberOfRequests), + toolsValidator: toolsValidator, + kvcacheHelper: nil, // kvcache helper will be created only if required after reading configuration + namespace: os.Getenv(podNsEnv), + pod: os.Getenv(podNameEnv), + runReqChan: make(chan int64, maxNumberOfRequests), + waitingReqChan: make(chan int64, maxNumberOfRequests), + lorasChan: make(chan loraUsage, maxNumberOfRequests), + kvCacheUsageChan: make(chan float64, maxNumberOfRequests), + logprobsProcessor: NewLogprobsProcessor(10000), // Initialize with 10k cache size following vLLM patterns }, nil } @@ -393,6 +396,8 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) { doRemotePrefill: req.IsDoRemotePrefill(), nPromptTokens: usageData.PromptTokens, nCachedPromptTokens: reqCtx.CompletionReq.GetNumberOfCachedPromptTokens(), + requestID: req.GetRequestID(), + request: req, }, responseTokens, toolCalls, finishReason, usageDataToSend, ) @@ -427,15 +432,35 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool } } +// generateSimulatedLogprobs creates synthetic but realistic-looking logprobs for a token +// Delegates to LogprobsProcessor following vLLM architecture patterns +// Returns the main token's logprob and a list of top-k alternative tokens with their logprobs +func (s *VllmSimulator) generateSimulatedLogprobs(token string, topK int) (float64, []openaiserverapi.ChatCompletionLogProb) { + return 
s.logprobsProcessor.GetLogprobs(token, topK) +} + +// generateChatLogprobs creates logprobs data for chat completions +// Delegates to LogprobsProcessor following vLLM architecture patterns +func (s *VllmSimulator) generateChatLogprobs(tokens []string, topK int) *openaiserverapi.ChatCompletionLogProbs { + return s.logprobsProcessor.ProcessChatLogprobs(tokens, topK) +} + +// generateTextLogprobs creates logprobs data for text completions +// Delegates to LogprobsProcessor following vLLM architecture patterns +func (s *VllmSimulator) generateTextLogprobs(tokens []string, topK int) *openaiserverapi.CompletionLogProbs { + return s.logprobsProcessor.ProcessTextLogprobs(tokens, topK) +} + // createCompletionResponse creates the response for completion requests, supports both completion request types (text and chat) // as defined by isChatCompletion +// req - the original completion request // respTokens - tokenized content to be sent in the response // toolCalls - tool calls to be sent in the response // finishReason - a pointer to string that represents finish reason, can be nil or stop or length, ... // usageData - usage (tokens statistics) for this response // modelName - display name returned to the client and used in metrics. It is either the first alias // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request). -func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, +func (s *VllmSimulator) createCompletionResponse(req openaiserverapi.CompletionRequest, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse { baseResp := openaiserverapi.BaseCompletionResponse{ ID: chatComplIDPrefix + common.GenerateUUIDString(), @@ -467,16 +492,30 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke } else { message.Content = openaiserverapi.Content{Raw: respText} } + + // Generate logprobs if requested + var logprobs *openaiserverapi.ChatCompletionLogProbs + if req.ShouldIncludeLogprobs() && len(respTokens) > 0 { + logprobs = s.generateChatLogprobs(respTokens, req.GetTopLogprobs()) + } + return &openaiserverapi.ChatCompletionResponse{ BaseCompletionResponse: baseResp, - Choices: []openaiserverapi.ChatRespChoice{{Message: message, BaseResponseChoice: baseChoice}}, + Choices: []openaiserverapi.ChatRespChoice{{Message: message, BaseResponseChoice: baseChoice, Logprobs: logprobs}}, } } baseResp.Object = textCompletionObject + + // Generate logprobs if requested (for text completion) + var logprobs *openaiserverapi.CompletionLogProbs + if req.ShouldIncludeLogprobs() && len(respTokens) > 0 { + logprobs = s.generateTextLogprobs(respTokens, req.GetTopLogprobs()) + } + return &openaiserverapi.TextCompletionResponse{ BaseCompletionResponse: baseResp, - Choices: []openaiserverapi.TextRespChoice{{BaseResponseChoice: baseChoice, Text: respText}}, + Choices: []openaiserverapi.TextRespChoice{{BaseResponseChoice: baseChoice, Text: respText, Logprobs: logprobs}}, } } @@ -490,7 +529,7 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke // usageData - usage (tokens statistics) for this response func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall, modelName string, finishReason string, usageData 
*openaiserverapi.Usage) { - resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, + resp := s.createCompletionResponse(reqCtx.CompletionReq, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, reqCtx.CompletionReq.IsDoRemoteDecode()) // calculate how long to wait before returning the response, time is based on number of tokens diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index c64affc8..2d066880 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -37,6 +37,7 @@ type streamingContext struct { nPromptTokens int nCachedPromptTokens int requestID string + request openaiserverapi.CompletionRequest } // sendStreamingResponse creates and sends a streaming response for completion requests of both types (text and chat) @@ -62,7 +63,7 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons if len(responseTokens) > 0 || len(toolCalls) > 0 { if context.isChatCompletion { // in chat completion first chunk contains the role - chunk := s.createChatCompletionChunk(context, "", nil, openaiserverapi.RoleAssistant, nil) + chunk := s.createChatCompletionChunk(context, "", nil, openaiserverapi.RoleAssistant, nil, false, 0) if err := s.sendChunk(w, chunk, ""); err != nil { context.ctx.Error("Sending stream first chunk failed, "+err.Error(), fasthttp.StatusInternalServerError) return @@ -128,8 +129,16 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ if i == len(genTokens)-1 && (finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason) { finishReasonToSend = &finishReason } + + // Determine if we should include logprobs for this token + shouldIncludeLogprobs := context.request != nil && context.request.ShouldIncludeLogprobs() + topK := 0 + if shouldIncludeLogprobs && context.request != nil { + topK = context.request.GetTopLogprobs() + } + if context.isChatCompletion { - chunk = s.createChatCompletionChunk(context, token, toolChunkInsert, "", finishReasonToSend) + chunk = s.createChatCompletionChunk(context, token, toolChunkInsert, "", finishReasonToSend, shouldIncludeLogprobs, topK) } else { chunk = s.createTextCompletionChunk(context, token, finishReasonToSend) } @@ -144,7 +153,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ var chunk openaiserverapi.CompletionRespChunk if finishReason == dataset.StopFinishReason { if context.isChatCompletion { - chunk = s.createChatCompletionChunk(context, "", nil, "", &finishReason) + chunk = s.createChatCompletionChunk(context, "", nil, "", &finishReason, false, 0) } else { chunk = s.createTextCompletionChunk(context, "", &finishReason) } @@ -201,7 +210,7 @@ func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, tok // createChatCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion // API response, for chat completion. It sets either role, or token, or tool call info in the message. 
func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall, - role string, finishReason *string) openaiserverapi.CompletionRespChunk { + role string, finishReason *string, includeLogprobs bool, topK int) openaiserverapi.CompletionRespChunk { chunk := openaiserverapi.ChatCompletionRespChunk{ BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ ID: chatComplIDPrefix + common.GenerateUUIDString(), @@ -224,6 +233,28 @@ func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, tok chunk.Choices[0].Delta.ToolCalls = []openaiserverapi.ToolCall{*tool} } else if len(token) > 0 { chunk.Choices[0].Delta.Content.Raw = token + + // Generate logprobs for this token if requested + if includeLogprobs { + mainLogprob, topLps := s.generateSimulatedLogprobs(token, topK) + + // Convert token to bytes + bytes := make([]int, len(token)) + for i, b := range []byte(token) { + bytes[i] = int(b) + } + + chunk.Choices[0].Logprobs = &openaiserverapi.ChatCompletionLogProbs{ + Content: []openaiserverapi.ChatCompletionLogProbsContent{ + { + Token: token, + Logprob: mainLogprob, + Bytes: bytes, + TopLogprobs: topLps, + }, + }, + } + } } return &chunk diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 34db0ee6..fc06c03b 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -67,6 +67,10 @@ type CompletionRequest interface { IsDoRemotePrefill() bool // GetFullPrompt returns the full prompt including system and user prompts GetFullPrompt() string + // ShouldIncludeLogprobs returns true if logprobs should be included in the response + ShouldIncludeLogprobs() bool + // GetTopLogprobs returns the number of top logprobs to include (0 if not requested) + GetTopLogprobs() int } // BaseCompletionRequest contains base completion request related information @@ -178,6 +182,13 @@ type ChatCompletionRequest struct { // possible values: none, auto, required. // Sending an object with a specific tool, is currently not supported. ToolChoice string `json:"tool_choice,omitempty"` + + // Logprobs indicates whether to return log probabilities of the output tokens + Logprobs *bool `json:"logprobs,omitempty"` + + // TopLogprobs specifies how many log probability values to return per token (0-20) + // Requires Logprobs to be set to true + TopLogprobs *int `json:"top_logprobs,omitempty"` } // function defines a tool @@ -253,6 +264,17 @@ func (req *ChatCompletionRequest) GetFullPrompt() string { return prompt } +func (c *ChatCompletionRequest) ShouldIncludeLogprobs() bool { + return c.Logprobs != nil && *c.Logprobs +} + +func (c *ChatCompletionRequest) GetTopLogprobs() int { + if c.TopLogprobs != nil { + return *c.TopLogprobs + } + return 0 +} + // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { @@ -266,6 +288,9 @@ type TextCompletionRequest struct { // The token count of your prompt plus `max_tokens` cannot exceed the model's // context length. 
MaxTokens *int64 `json:"max_tokens"` + + // Logprobs specifies how many log probability values to return (0-5) + Logprobs *int `json:"logprobs,omitempty"` } func (t *TextCompletionRequest) GetPrompt() string { @@ -291,3 +316,14 @@ func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 { func (t *TextCompletionRequest) GetFullPrompt() string { return "### user:\n" + t.Prompt + "\n" } + +func (t *TextCompletionRequest) ShouldIncludeLogprobs() bool { + return t.Logprobs != nil && *t.Logprobs > 0 +} + +func (t *TextCompletionRequest) GetTopLogprobs() int { + if t.Logprobs != nil { + return *t.Logprobs + } + return 0 +} diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index d32784e3..d1fb110d 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -170,11 +170,41 @@ type ToolCall struct { Index int `json:"index"` } +// ChatCompletionLogProb represents a single token's log probability information +type ChatCompletionLogProb struct { + // Token is the token string + Token string `json:"token"` + // Logprob is the log probability of this token + Logprob float64 `json:"logprob"` + // Bytes is the list of UTF-8 bytes for the token + Bytes []int `json:"bytes,omitempty"` +} + +// ChatCompletionLogProbsContent represents log probability information for a single output token +type ChatCompletionLogProbsContent struct { + // Token is the selected token + Token string `json:"token"` + // Logprob is the log probability of the selected token + Logprob float64 `json:"logprob"` + // Bytes is the list of UTF-8 bytes for the token + Bytes []int `json:"bytes,omitempty"` + // TopLogprobs is the list of top log probability candidates + TopLogprobs []ChatCompletionLogProb `json:"top_logprobs"` +} + +// ChatCompletionLogProbs represents log probability information for the output +type ChatCompletionLogProbs struct { + // Content is the list of log probabilities for each output token + Content []ChatCompletionLogProbsContent `json:"content,omitempty"` +} + // ChatRespChoice represents a single chat completion response choise type ChatRespChoice struct { BaseResponseChoice // Message contains choice's Message Message Message `json:"message"` + // Logprobs contains log probability information for the output tokens + Logprobs *ChatCompletionLogProbs `json:"logprobs,omitempty"` } // TextCompletionResponse defines structure of /completion response @@ -184,11 +214,25 @@ type TextCompletionResponse struct { Choices []TextRespChoice `json:"choices"` } +// CompletionLogProbs represents log probability information for text completions +type CompletionLogProbs struct { + // TextOffset is the character offset from the start of the completion + TextOffset []int `json:"text_offset"` + // TokenLogprobs is the log probability of each token + TokenLogprobs []float64 `json:"token_logprobs"` + // Tokens is the list of tokens + Tokens []string `json:"tokens"` + // TopLogprobs is the list of top log probability candidates for each token position + TopLogprobs []map[string]float64 `json:"top_logprobs,omitempty"` +} + // TextRespChoice represents a single text completion response choise type TextRespChoice struct { BaseResponseChoice // Text defines request's content Text string `json:"text"` + // Logprobs contains log probability information + Logprobs *CompletionLogProbs `json:"logprobs,omitempty"` } // CompletionRespChunk is an interface that defines a single response chunk @@ -206,6 +250,8 @@ type ChatRespChunkChoice struct { BaseResponseChoice // Delta is a 
content of the chunk Delta Message `json:"delta"` + // Logprobs contains log probability information for the chunk tokens + Logprobs *ChatCompletionLogProbs `json:"logprobs,omitempty"` } // CompletionError defines the simulator's response in case of an error From 871e591b489a6ea3636dd15f9ca6c2e89048a1ea Mon Sep 17 00:00:00 2001 From: Rui Vieira Date: Thu, 16 Oct 2025 13:40:06 +0100 Subject: [PATCH 2/4] chore: Move testing to Gomega Signed-off-by: Rui Vieira --- .../logprobs_processor_test.go | 138 ++++++------------ 1 file changed, 42 insertions(+), 96 deletions(-) diff --git a/pkg/llm-d-inference-sim/logprobs_processor_test.go b/pkg/llm-d-inference-sim/logprobs_processor_test.go index e1438b34..bc2edf29 100644 --- a/pkg/llm-d-inference-sim/logprobs_processor_test.go +++ b/pkg/llm-d-inference-sim/logprobs_processor_test.go @@ -18,9 +18,12 @@ package llmdinferencesim import ( "testing" + + . "github.com/onsi/gomega" ) func TestLogprobsProcessor_Caching(t *testing.T) { + RegisterTestingT(t) processor := NewLogprobsProcessor(100) // Test that same token generates same logprobs (deterministic) @@ -31,42 +34,30 @@ func TestLogprobsProcessor_Caching(t *testing.T) { logprob2, topLogprobs2 := processor.GetLogprobs(token, topK) // Should be identical (deterministic) - if logprob1 != logprob2 { - t.Errorf("Expected same logprob for same token, got %.4f and %.4f", logprob1, logprob2) - } - if len(topLogprobs1) != len(topLogprobs2) { - t.Errorf("Expected same topLogprobs length, got %d and %d", len(topLogprobs1), len(topLogprobs2)) - } - if len(topLogprobs1) != topK { - t.Errorf("Expected topLogprobs length %d, got %d", topK, len(topLogprobs1)) - } + Expect(logprob1).To(Equal(logprob2)) + Expect(len(topLogprobs1)).To(Equal(len(topLogprobs2))) + Expect(len(topLogprobs1)).To(Equal(topK)) // Check cache stats hits, misses, hitRate := processor.GetCacheStats() - if hits != 1 { - t.Errorf("Expected 1 cache hit, got %d", hits) - } - if misses != 1 { - t.Errorf("Expected 1 cache miss, got %d", misses) - } - if hitRate != 0.5 { - t.Errorf("Expected 50%% hit rate, got %.2f", hitRate) - } + Expect(hits).To(Equal(int64(1))) + Expect(misses).To(Equal(int64(1))) + Expect(hitRate).To(Equal(0.5)) } func TestLogprobsProcessor_DifferentTokens(t *testing.T) { + RegisterTestingT(t) processor := NewLogprobsProcessor(100) // Test that different tokens generate different logprobs logprob1, _ := processor.GetLogprobs("hello", 2) logprob2, _ := processor.GetLogprobs("world", 2) - if logprob1 == logprob2 { - t.Errorf("Different tokens should have different logprobs, both got %.4f", logprob1) - } + Expect(logprob1).NotTo(Equal(logprob2)) } func TestLogprobsProcessor_DifferentTopK(t *testing.T) { + RegisterTestingT(t) processor := NewLogprobsProcessor(100) // Test that same token with different topK generates different results @@ -75,15 +66,12 @@ func TestLogprobsProcessor_DifferentTopK(t *testing.T) { _, topLogprobs1 := processor.GetLogprobs(token, 2) _, topLogprobs2 := processor.GetLogprobs(token, 5) - if len(topLogprobs1) != 2 { - t.Errorf("Expected 2 top logprobs, got %d", len(topLogprobs1)) - } - if len(topLogprobs2) != 5 { - t.Errorf("Expected 5 top logprobs, got %d", len(topLogprobs2)) - } + Expect(len(topLogprobs1)).To(Equal(2)) + Expect(len(topLogprobs2)).To(Equal(5)) } func TestLogprobsProcessor_ChatLogprobs(t *testing.T) { + RegisterTestingT(t) processor := NewLogprobsProcessor(100) tokens := []string{"Hello", "world", "!"} @@ -91,30 +79,19 @@ func TestLogprobsProcessor_ChatLogprobs(t *testing.T) { logprobs := 
processor.ProcessChatLogprobs(tokens, topK) - if logprobs == nil { - t.Fatal("Expected non-nil chat logprobs") - } - if len(logprobs.Content) != len(tokens) { - t.Errorf("Expected %d content items, got %d", len(tokens), len(logprobs.Content)) - } + Expect(logprobs).NotTo(BeNil()) + Expect(len(logprobs.Content)).To(Equal(len(tokens))) for i, content := range logprobs.Content { - if content.Token != tokens[i] { - t.Errorf("Expected token %s at index %d, got %s", tokens[i], i, content.Token) - } - if content.Logprob >= 0 { - t.Errorf("Expected negative logprob, got %.4f", content.Logprob) - } - if len(content.TopLogprobs) != topK { - t.Errorf("Expected %d top logprobs, got %d", topK, len(content.TopLogprobs)) - } - if content.Bytes == nil { - t.Error("Expected non-nil bytes") - } + Expect(content.Token).To(Equal(tokens[i])) + Expect(content.Logprob).To(BeNumerically("<", 0)) + Expect(len(content.TopLogprobs)).To(Equal(topK)) + Expect(content.Bytes).NotTo(BeNil()) } } func TestLogprobsProcessor_TextLogprobs(t *testing.T) { + RegisterTestingT(t) processor := NewLogprobsProcessor(100) tokens := []string{"Hello", "world"} @@ -122,70 +99,47 @@ func TestLogprobsProcessor_TextLogprobs(t *testing.T) { logprobs := processor.ProcessTextLogprobs(tokens, topK) - if logprobs == nil { - t.Fatal("Expected non-nil text logprobs") - } - if len(logprobs.Tokens) != len(tokens) { - t.Errorf("Expected %d tokens, got %d", len(tokens), len(logprobs.Tokens)) - } - if len(logprobs.TokenLogprobs) != len(tokens) { - t.Errorf("Expected %d token logprobs, got %d", len(tokens), len(logprobs.TokenLogprobs)) - } - if len(logprobs.TextOffset) != len(tokens) { - t.Errorf("Expected %d text offsets, got %d", len(tokens), len(logprobs.TextOffset)) - } - if len(logprobs.TopLogprobs) != len(tokens) { - t.Errorf("Expected %d top logprobs arrays, got %d", len(tokens), len(logprobs.TopLogprobs)) - } + Expect(logprobs).NotTo(BeNil()) + Expect(len(logprobs.Tokens)).To(Equal(len(tokens))) + Expect(len(logprobs.TokenLogprobs)).To(Equal(len(tokens))) + Expect(len(logprobs.TextOffset)).To(Equal(len(tokens))) + Expect(len(logprobs.TopLogprobs)).To(Equal(len(tokens))) // Check text offsets are cumulative expectedOffset := 0 for i, token := range tokens { - if logprobs.TextOffset[i] != expectedOffset { - t.Errorf("Expected offset %d at index %d, got %d", expectedOffset, i, logprobs.TextOffset[i]) - } - if logprobs.Tokens[i] != token { - t.Errorf("Expected token %s at index %d, got %s", token, i, logprobs.Tokens[i]) - } - if logprobs.TokenLogprobs[i] >= 0 { - t.Errorf("Expected negative logprob at index %d, got %.4f", i, logprobs.TokenLogprobs[i]) - } - if len(logprobs.TopLogprobs[i]) != topK { - t.Errorf("Expected %d top logprobs at index %d, got %d", topK, i, len(logprobs.TopLogprobs[i])) - } + Expect(logprobs.TextOffset[i]).To(Equal(expectedOffset)) + Expect(logprobs.Tokens[i]).To(Equal(token)) + Expect(logprobs.TokenLogprobs[i]).To(BeNumerically("<", 0)) + Expect(len(logprobs.TopLogprobs[i])).To(Equal(topK)) expectedOffset += len(token) } } func TestLogprobsProcessor_EmptyTokens(t *testing.T) { + RegisterTestingT(t) processor := NewLogprobsProcessor(100) // Test empty token lists chatLogprobs := processor.ProcessChatLogprobs([]string{}, 3) textLogprobs := processor.ProcessTextLogprobs([]string{}, 3) - if chatLogprobs != nil { - t.Error("Expected nil chat logprobs for empty tokens") - } - if textLogprobs != nil { - t.Error("Expected nil text logprobs for empty tokens") - } + Expect(chatLogprobs).To(BeNil()) + 
Expect(textLogprobs).To(BeNil())
 }
 
 func TestLogprobsProcessor_ZeroTopK(t *testing.T) {
+	RegisterTestingT(t)
 	processor := NewLogprobsProcessor(100)
 
 	logprob, topLogprobs := processor.GetLogprobs("test", 0)
 
-	if logprob >= 0 {
-		t.Errorf("Expected negative logprob, got %.4f", logprob)
-	}
-	if topLogprobs != nil {
-		t.Error("Expected nil top logprobs for topK=0")
-	}
+	Expect(logprob).To(BeNumerically("<", 0))
+	Expect(topLogprobs).To(BeNil())
 }
 
 func TestLogprobsProcessor_CacheEviction(t *testing.T) {
+	RegisterTestingT(t)
 	// Test with very small cache size to trigger eviction
 	processor := NewLogprobsProcessor(2)
 
@@ -195,22 +149,14 @@ func TestLogprobsProcessor_CacheEviction(t *testing.T) {
 	processor.GetLogprobs("token3", 1) // Should trigger eviction
 
 	hits, misses, _ := processor.GetCacheStats()
-	if hits != 0 {
-		t.Errorf("Expected 0 cache hits, got %d", hits)
-	}
-	if misses != 3 {
-		t.Errorf("Expected 3 cache misses, got %d", misses)
-	}
+	Expect(hits).To(Equal(int64(0)))
+	Expect(misses).To(Equal(int64(3)))
 
 	// Access one of the earlier tokens - may or may not be in cache due to eviction
 	processor.GetLogprobs("token1", 1)
 
 	// Cache should be working (some entries may have been evicted)
 	hits2, misses2, _ := processor.GetCacheStats()
-	if hits2 < 0 {
-		t.Errorf("Expected non-negative cache hits, got %d", hits2)
-	}
-	if misses2 < 3 {
-		t.Errorf("Expected at least 3 cache misses, got %d", misses2)
-	}
+	Expect(hits2).To(BeNumerically(">=", 0))
+	Expect(misses2).To(BeNumerically(">=", 3))
 }
\ No newline at end of file

From 30e0e6ea6f31cb35c702521f9e3971b51bccf36c Mon Sep 17 00:00:00 2001
From: Rui Vieira
Date: Thu, 16 Oct 2025 13:43:49 +0100
Subject: [PATCH 3/4] chore: extract condition to outside loop

Signed-off-by: Rui Vieira
---
 pkg/llm-d-inference-sim/streaming.go | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go
index 2d066880..e7940777 100644
--- a/pkg/llm-d-inference-sim/streaming.go
+++ b/pkg/llm-d-inference-sim/streaming.go
@@ -105,6 +105,9 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 	ttft := s.getWaitTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill)
 	time.Sleep(time.Duration(ttft) * time.Millisecond)
 
+	// Calculate finish reason condition once before the loop
+	shouldSendFinishReason := finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason
+
 	for i, token := range genTokens {
 		if i != 0 {
 			time.Sleep(time.Duration(s.getInterTokenLatency()) * time.Millisecond)
@@ -126,7 +129,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 
 		var chunk openaiserverapi.CompletionRespChunk
 		var finishReasonToSend *string
-		if i == len(genTokens)-1 && (finishReason == dataset.LengthFinishReason || finishReason == dataset.ToolsFinishReason) {
+		if i == len(genTokens)-1 && shouldSendFinishReason {
 			finishReasonToSend = &finishReason
 		}
 

From 821963057e98103fdac31193deb76451ab6ce35a Mon Sep 17 00:00:00 2001
From: Rui Vieira
Date: Thu, 16 Oct 2025 13:45:25 +0100
Subject: [PATCH 4/4] chore: format files

Signed-off-by: Rui Vieira
---
 pkg/llm-d-inference-sim/logprobs_processor.go      | 4 ++--
 pkg/llm-d-inference-sim/logprobs_processor_test.go | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/pkg/llm-d-inference-sim/logprobs_processor.go b/pkg/llm-d-inference-sim/logprobs_processor.go
index df011fe0..7517c82a 100644
--- a/pkg/llm-d-inference-sim/logprobs_processor.go
+++ b/pkg/llm-d-inference-sim/logprobs_processor.go
@@ -26,7 +26,7 @@ import (
 
 // LogprobData represents cached logprob information for a token
 type LogprobData struct {
-	MainLogprob float64 `json:"main_logprob"`
+	MainLogprob float64                                 `json:"main_logprob"`
 	TopLogprobs []openaiserverapi.ChatCompletionLogProb `json:"top_logprobs"`
 }
 
@@ -230,4 +230,4 @@ func (lp *LogprobsProcessor) ClearCache() {
 	lp.tokenCache = make(map[string]*LogprobData)
 	lp.cacheHits = 0
 	lp.cacheMisses = 0
-}
\ No newline at end of file
+}
diff --git a/pkg/llm-d-inference-sim/logprobs_processor_test.go b/pkg/llm-d-inference-sim/logprobs_processor_test.go
index bc2edf29..8366e9eb 100644
--- a/pkg/llm-d-inference-sim/logprobs_processor_test.go
+++ b/pkg/llm-d-inference-sim/logprobs_processor_test.go
@@ -159,4 +159,4 @@ func TestLogprobsProcessor_CacheEviction(t *testing.T) {
 	hits2, misses2, _ := processor.GetCacheStats()
 	Expect(hits2).To(BeNumerically(">=", 0))
 	Expect(misses2).To(BeNumerically(">=", 3))
-}
\ No newline at end of file
+}
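
Reviewer note: the sketch below is a minimal, standalone client for exercising the new logprobs support end to end. It is not part of the patch; the base URL, port, and model name ("my-model") are placeholders for however the simulator is started locally, while the /v1/chat/completions path and the logprobs/top_logprobs request fields follow the OpenAI-compatible API that this patch extends.

// logprobs_demo.go - hand-check the logprobs fields returned by the simulator.
// Assumed setup (not part of this patch): simulator listening on localhost:8000,
// started with a served model name of "my-model".
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Ask for per-token logprobs plus two alternatives per output token.
	reqBody := map[string]any{
		"model":        "my-model",
		"messages":     []map[string]string{{"role": "user", "content": "Hello"}},
		"logprobs":     true,
		"top_logprobs": 2,
	}
	payload, err := json.Marshal(reqBody)
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://localhost:8000/v1/chat/completions", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode only the logprobs portion of the response.
	var out struct {
		Choices []struct {
			Logprobs struct {
				Content []struct {
					Token       string  `json:"token"`
					Logprob     float64 `json:"logprob"`
					TopLogprobs []struct {
						Token   string  `json:"token"`
						Logprob float64 `json:"logprob"`
					} `json:"top_logprobs"`
				} `json:"content"`
			} `json:"logprobs"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}

	// Each token should carry a negative logprob and the requested number of alternatives.
	for _, choice := range out.Choices {
		for _, tok := range choice.Logprobs.Content {
			fmt.Printf("%q logprob=%.4f alternatives=%d\n", tok.Token, tok.Logprob, len(tok.TopLogprobs))
		}
	}
}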