feat: Log probabilities support (with cache) #215
@@ -0,0 +1,233 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package llmdinferencesim

import (
    "crypto/md5"
    "fmt"
    "sync"

    openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

// LogprobData represents cached logprob information for a token
type LogprobData struct {
    MainLogprob float64                                 `json:"main_logprob"`
    TopLogprobs []openaiserverapi.ChatCompletionLogProb `json:"top_logprobs"`
}

// LogprobsProcessor handles logprobs generation and caching following vLLM architecture
type LogprobsProcessor struct {
    // tokenCache caches logprobs by token content and topK to avoid recomputation
    tokenCache map[string]*LogprobData
    cacheMutex sync.RWMutex

    // cacheHits and cacheMisses for metrics
    cacheHits int64
Review comment: cacheHits and cacheMisses should be changed under protection of the mutex, or be atomic counters (from the sync/atomic package).
    cacheMisses int64

    // maxCacheSize limits memory usage
    maxCacheSize int
}
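As a purely illustrative sketch of the reviewer's suggestion above (not part of the PR, and assuming Go 1.19+ for atomic.Int64; cacheCounters and hitRate are hypothetical names), the counters could become atomics so they no longer need the mutex when updated:

```go
package llmdinferencesim

import "sync/atomic"

// Sketch only: hit/miss counters as atomics, safe to update outside cacheMutex.
type cacheCounters struct {
    hits   atomic.Int64 // incremented with hits.Add(1) on a cache hit
    misses atomic.Int64 // incremented with misses.Add(1) on a cache miss
}

// hitRate reads both counters without taking any lock.
func (c *cacheCounters) hitRate() float64 {
    h, m := c.hits.Load(), c.misses.Load()
    if h+m == 0 {
        return 0
    }
    return float64(h) / float64(h+m)
}
```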

// NewLogprobsProcessor creates a new LogprobsProcessor following vLLM design patterns
func NewLogprobsProcessor(maxCacheSize int) *LogprobsProcessor {
    if maxCacheSize <= 0 {
        maxCacheSize = 10000 // Default cache size
    }

    return &LogprobsProcessor{
        tokenCache:   make(map[string]*LogprobData),
        maxCacheSize: maxCacheSize,
    }
}

// generateCacheKey creates a deterministic key for caching based on token and topK
func (lp *LogprobsProcessor) generateCacheKey(token string, topK int) string {
    return fmt.Sprintf("%s:%d", token, topK)
}

// generateDeterministicLogprobs creates logprobs with deterministic values based on token content
// This follows vLLM's approach of consistent logprobs for the same token in similar contexts
Review comment: @ruivieira can you please explain how context is taken into consideration here. As I understand it, a token should give the same logprob only in the same/similar context. What does context mean here? Isn't it the previous tokens? But they are not passed to this function.
func (lp *LogprobsProcessor) generateDeterministicLogprobs(token string, topK int) *LogprobData {
    // Use token content to seed deterministic generation (similar to vLLM's approach)
    hash := md5.Sum([]byte(token))
    seed := int64(hash[0])<<24 | int64(hash[1])<<16 | int64(hash[2])<<8 | int64(hash[3])

    // Generate main logprob deterministically based on token
    // Real logprobs are typically negative, with values closer to 0 being more likely
    mainLogprob := -0.1 - (float64(seed%2000) / 1000.0) // Range: -0.1 to -2.1

    if topK <= 0 {
        return &LogprobData{
            MainLogprob: mainLogprob,
            TopLogprobs: nil,
        }
    }

    // Generate top-k alternatives deterministically
    topLogprobs := make([]openaiserverapi.ChatCompletionLogProb, 0, topK)
    for i := 0; i < topK; i++ {
        // Generate deterministic alternative token
        altToken := fmt.Sprintf("alt_%d_%x", i, hash[i%4])

        // Each alternative gets progressively lower probability
        altLogprob := mainLogprob - (float64(i+1) * (0.5 + float64((seed+int64(i))%1500)/1000.0))

        // Convert token to bytes
        bytes := make([]int, len(altToken))
        for j, b := range []byte(altToken) {
            bytes[j] = int(b)
        }

        topLogprobs = append(topLogprobs, openaiserverapi.ChatCompletionLogProb{
            Token:   altToken,
            Logprob: altLogprob,
            Bytes:   bytes,
        })
    }

    return &LogprobData{
        MainLogprob: mainLogprob,
        TopLogprobs: topLogprobs,
    }
}
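To make the question in the review comment above concrete: one way the surrounding context could influence the result would be to hash the preceding tokens together with the current token. This is only a sketch; prevTokens is a hypothetical parameter that the code in this PR does not accept, and contextAwareSeed is not part of the diff:

```go
package llmdinferencesim

import "crypto/md5"

// Sketch only: derive the seed from the previous tokens plus the current token,
// so the same token in a different context yields a different (but still
// deterministic) logprob.
func contextAwareSeed(prevTokens []string, token string) int64 {
    h := md5.New()
    for _, t := range prevTokens {
        h.Write([]byte(t))
        h.Write([]byte{0}) // separator so ["ab","c"] differs from ["a","bc"]
    }
    h.Write([]byte(token))
    sum := h.Sum(nil)
    return int64(sum[0])<<24 | int64(sum[1])<<16 | int64(sum[2])<<8 | int64(sum[3])
}
```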

// GetLogprobs returns logprobs for a token, using cache when possible
func (lp *LogprobsProcessor) GetLogprobs(token string, topK int) (float64, []openaiserverapi.ChatCompletionLogProb) {
    cacheKey := lp.generateCacheKey(token, topK)

    // Check cache first
    lp.cacheMutex.RLock()
    if cached, exists := lp.tokenCache[cacheKey]; exists {
        lp.cacheMutex.RUnlock()
        lp.cacheHits++
        return cached.MainLogprob, cached.TopLogprobs
    }
    lp.cacheMutex.RUnlock()

    // Cache miss - generate new logprobs
    lp.cacheMisses++
    logprobData := lp.generateDeterministicLogprobs(token, topK)

    // Store in cache (with size limit)
    lp.cacheMutex.Lock()
    if len(lp.tokenCache) >= lp.maxCacheSize {
        // Simple eviction: remove oldest entry
        // In production, this could use LRU or other strategies
        for k := range lp.tokenCache {
            delete(lp.tokenCache, k)
            break
        }
Review comment (on lines +132 to +135): Iteration over a map is not in the order the objects were added, so this does not remove the oldest entry. If we want to remove the oldest item, we need to store the insertion date/time or the latest-usage date/time.
    }
    lp.tokenCache[cacheKey] = logprobData
    lp.cacheMutex.Unlock()

    return logprobData.MainLogprob, logprobData.TopLogprobs
}
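A minimal sketch of the insertion-order eviction described in the review comment above, assuming a slice of keys is kept alongside the map (illustrative only; orderedCache is a hypothetical type, and a real implementation might prefer a proper LRU):

```go
package llmdinferencesim

// Sketch only: track insertion order so the oldest key can be evicted.
type orderedCache struct {
    entries map[string]*LogprobData
    order   []string // keys in insertion order, oldest first
    max     int
}

func (c *orderedCache) put(key string, value *LogprobData) {
    if _, exists := c.entries[key]; !exists {
        if len(c.entries) >= c.max {
            oldest := c.order[0]
            c.order = c.order[1:]
            delete(c.entries, oldest)
        }
        c.order = append(c.order, key)
    }
    c.entries[key] = value
}
```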

// ProcessChatLogprobs creates logprobs data for chat completions following vLLM patterns
func (lp *LogprobsProcessor) ProcessChatLogprobs(tokens []string, topK int) *openaiserverapi.ChatCompletionLogProbs {
    if len(tokens) == 0 {
        return nil
    }

    logprobs := &openaiserverapi.ChatCompletionLogProbs{
        Content: make([]openaiserverapi.ChatCompletionLogProbsContent, 0, len(tokens)),
    }

    for _, token := range tokens {
        mainLogprob, topLps := lp.GetLogprobs(token, topK)

        // Convert token to bytes
        bytes := make([]int, len(token))
        for i, b := range []byte(token) {
            bytes[i] = int(b)
        }

        logprobs.Content = append(logprobs.Content, openaiserverapi.ChatCompletionLogProbsContent{
            Token:       token,
            Logprob:     mainLogprob,
            Bytes:       bytes,
            TopLogprobs: topLps,
        })
    }

    return logprobs
}

// ProcessTextLogprobs creates logprobs data for text completions following vLLM patterns
func (lp *LogprobsProcessor) ProcessTextLogprobs(tokens []string, topK int) *openaiserverapi.CompletionLogProbs {
    if len(tokens) == 0 {
        return nil
    }

    logprobs := &openaiserverapi.CompletionLogProbs{
        TextOffset:    make([]int, 0, len(tokens)),
        TokenLogprobs: make([]float64, 0, len(tokens)),
        Tokens:        make([]string, 0, len(tokens)),
    }

    if topK > 0 {
        logprobs.TopLogprobs = make([]map[string]float64, 0, len(tokens))
    }

    textOffset := 0
    for _, token := range tokens {
        mainLogprob, topLps := lp.GetLogprobs(token, topK)

        logprobs.TextOffset = append(logprobs.TextOffset, textOffset)
        logprobs.TokenLogprobs = append(logprobs.TokenLogprobs, mainLogprob)
        logprobs.Tokens = append(logprobs.Tokens, token)

        if topK > 0 {
            topMap := make(map[string]float64, len(topLps))
            for _, lp := range topLps {
                topMap[lp.Token] = lp.Logprob
            }
            logprobs.TopLogprobs = append(logprobs.TopLogprobs, topMap)
        }

        textOffset += len(token)
    }

    return logprobs
}

// GetCacheStats returns cache performance statistics
func (lp *LogprobsProcessor) GetCacheStats() (hits, misses int64, hitRate float64) {
Review comment: consider not adding a function that is only used in tests to the class; we could add it in an additional file, or choose not to use it in the tests.
    lp.cacheMutex.RLock()
    defer lp.cacheMutex.RUnlock()

    total := lp.cacheHits + lp.cacheMisses
    hitRate = 0.0
    if total > 0 {
        hitRate = float64(lp.cacheHits) / float64(total)
    }

    return lp.cacheHits, lp.cacheMisses, hitRate
}

// ClearCache clears the logprobs cache
func (lp *LogprobsProcessor) ClearCache() {
Review comment: Not used.
    lp.cacheMutex.Lock()
    defer lp.cacheMutex.Unlock()

    lp.tokenCache = make(map[string]*LogprobData)
    lp.cacheHits = 0
    lp.cacheMisses = 0
}
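For reference, a minimal usage sketch of the API introduced in the file above (the token slice and top-logprobs value are just illustrative inputs, and exampleChatLogprobsUsage is not part of the PR):

```go
package llmdinferencesim

import "fmt"

func exampleChatLogprobsUsage() {
    processor := NewLogprobsProcessor(1000)

    // Build chat logprobs for three tokens with two alternatives each.
    chatLogprobs := processor.ProcessChatLogprobs([]string{"Hello", " world", "!"}, 2)
    for _, c := range chatLogprobs.Content {
        fmt.Printf("token=%q logprob=%.3f alternatives=%d\n",
            c.Token, c.Logprob, len(c.TopLogprobs))
    }
}
```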
@@ -0,0 +1,162 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package llmdinferencesim

import (
    "testing"

    . "github.com/onsi/gomega"
)

func TestLogprobsProcessor_Caching(t *testing.T) {
    RegisterTestingT(t)
    processor := NewLogprobsProcessor(100)

    // Test that same token generates same logprobs (deterministic)
    token := "hello"
    topK := 3

    logprob1, topLogprobs1 := processor.GetLogprobs(token, topK)
    logprob2, topLogprobs2 := processor.GetLogprobs(token, topK)

    // Should be identical (deterministic)
    Expect(logprob1).To(Equal(logprob2))
    Expect(len(topLogprobs1)).To(Equal(len(topLogprobs2)))
    Expect(len(topLogprobs1)).To(Equal(topK))

    // Check cache stats
    hits, misses, hitRate := processor.GetCacheStats()
    Expect(hits).To(Equal(int64(1)))
    Expect(misses).To(Equal(int64(1)))
    Expect(hitRate).To(Equal(0.5))
}

func TestLogprobsProcessor_DifferentTokens(t *testing.T) {
    RegisterTestingT(t)
    processor := NewLogprobsProcessor(100)

    // Test that different tokens generate different logprobs
    logprob1, _ := processor.GetLogprobs("hello", 2)
    logprob2, _ := processor.GetLogprobs("world", 2)

    Expect(logprob1).NotTo(Equal(logprob2))
}

func TestLogprobsProcessor_DifferentTopK(t *testing.T) {
    RegisterTestingT(t)
    processor := NewLogprobsProcessor(100)

    // Test that same token with different topK generates different results
    token := "test"

    _, topLogprobs1 := processor.GetLogprobs(token, 2)
    _, topLogprobs2 := processor.GetLogprobs(token, 5)

    Expect(len(topLogprobs1)).To(Equal(2))
    Expect(len(topLogprobs2)).To(Equal(5))
}

func TestLogprobsProcessor_ChatLogprobs(t *testing.T) {
    RegisterTestingT(t)
    processor := NewLogprobsProcessor(100)

    tokens := []string{"Hello", "world", "!"}
    topK := 3

    logprobs := processor.ProcessChatLogprobs(tokens, topK)

    Expect(logprobs).NotTo(BeNil())
    Expect(len(logprobs.Content)).To(Equal(len(tokens)))

    for i, content := range logprobs.Content {
        Expect(content.Token).To(Equal(tokens[i]))
        Expect(content.Logprob).To(BeNumerically("<", 0))
        Expect(len(content.TopLogprobs)).To(Equal(topK))
        Expect(content.Bytes).NotTo(BeNil())
    }
}

func TestLogprobsProcessor_TextLogprobs(t *testing.T) {
    RegisterTestingT(t)
    processor := NewLogprobsProcessor(100)

    tokens := []string{"Hello", "world"}
    topK := 2

    logprobs := processor.ProcessTextLogprobs(tokens, topK)

    Expect(logprobs).NotTo(BeNil())
    Expect(len(logprobs.Tokens)).To(Equal(len(tokens)))
    Expect(len(logprobs.TokenLogprobs)).To(Equal(len(tokens)))
    Expect(len(logprobs.TextOffset)).To(Equal(len(tokens)))
    Expect(len(logprobs.TopLogprobs)).To(Equal(len(tokens)))

    // Check text offsets are cumulative
    expectedOffset := 0
    for i, token := range tokens {
        Expect(logprobs.TextOffset[i]).To(Equal(expectedOffset))
        Expect(logprobs.Tokens[i]).To(Equal(token))
        Expect(logprobs.TokenLogprobs[i]).To(BeNumerically("<", 0))
        Expect(len(logprobs.TopLogprobs[i])).To(Equal(topK))
        expectedOffset += len(token)
    }
}

func TestLogprobsProcessor_EmptyTokens(t *testing.T) {
    RegisterTestingT(t)
    processor := NewLogprobsProcessor(100)

    // Test empty token lists
    chatLogprobs := processor.ProcessChatLogprobs([]string{}, 3)
    textLogprobs := processor.ProcessTextLogprobs([]string{}, 3)

    Expect(chatLogprobs).To(BeNil())
    Expect(textLogprobs).To(BeNil())
}

func TestLogprobsProcessor_ZeroTopK(t *testing.T) {
    RegisterTestingT(t)
    processor := NewLogprobsProcessor(100)

    logprob, topLogprobs := processor.GetLogprobs("test", 0)

    Expect(logprob).To(BeNumerically("<", 0))
    Expect(topLogprobs).To(BeNil())
}

func TestLogprobsProcessor_CacheEviction(t *testing.T) {
    RegisterTestingT(t)
    // Test with very small cache size to trigger eviction
    processor := NewLogprobsProcessor(2)

    // Fill cache beyond capacity
    processor.GetLogprobs("token1", 1)
    processor.GetLogprobs("token2", 1)
    processor.GetLogprobs("token3", 1) // Should trigger eviction

    hits, misses, _ := processor.GetCacheStats()
    Expect(hits).To(Equal(int64(0)))
    Expect(misses).To(Equal(int64(3)))

    // Access one of the earlier tokens - may or may not be in cache due to eviction
    processor.GetLogprobs("token1", 1)

    // Cache should be working (some entries may have been evicted)
    hits2, misses2, _ := processor.GetCacheStats()
    Expect(hits2).To(BeNumerically(">=", 0))
    Expect(misses2).To(BeNumerically(">=", 3))
}
Review comment: Could you please add tests for the simulator with logprobs:

Review comment: Why do the struct and the fields have to be public? And why are the json tags needed?
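To illustrate the last question: if LogprobData is only used inside this package and never marshalled directly (which the code in this diff does not appear to do), a sketch of the unexported form without json tags might look like this (assumption: nothing outside the package needs the type):

```go
package llmdinferencesim

import openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"

// logprobData is the package-internal cached entry; json tags are only needed
// if the value is serialized directly.
type logprobData struct {
    mainLogprob float64
    topLogprobs []openaiserverapi.ChatCompletionLogProb
}
```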