233 changes: 233 additions & 0 deletions pkg/llm-d-inference-sim/logprobs_processor.go
@@ -0,0 +1,233 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package llmdinferencesim

import (
	"crypto/md5"
	"fmt"
	"sync"

	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

// LogprobData represents cached logprob information for a token
type LogprobData struct {
	MainLogprob float64                                 `json:"main_logprob"`
	TopLogprobs []openaiserverapi.ChatCompletionLogProb `json:"top_logprobs"`
}
Review comment (Collaborator): Why do the struct and the fields have to be public? And why are the json tags needed?
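For reference, a minimal sketch of the unexported form the comment suggests, assuming nothing outside the package needs the type; since it is never serialized, the tags can go:

	// Sketch: unexported equivalent without JSON tags, per the review comment.
	type logprobData struct {
		mainLogprob float64
		topLogprobs []openaiserverapi.ChatCompletionLogProb
	}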


// LogprobsProcessor handles logprobs generation and caching following vLLM architecture
type LogprobsProcessor struct {
	// tokenCache caches logprobs by token content and topK to avoid recomputation
	tokenCache map[string]*LogprobData
	cacheMutex sync.RWMutex

	// cacheHits and cacheMisses for metrics
	cacheHits int64
Review comment (Collaborator): cacheHits and cacheMisses should be changed under protection of the mutex, or be atomic counters (from the sync/atomic package).
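A minimal sketch of the atomic-counter option, assuming Go 1.19+ (which added atomic.Int64); the wrapper type and method names here are illustrative, not part of the PR:

	import "sync/atomic"

	// Sketch: counters that are safe to update without holding cacheMutex.
	type counters struct {
		cacheHits   atomic.Int64
		cacheMisses atomic.Int64
	}

	func (c *counters) recordHit()  { c.cacheHits.Add(1) }
	func (c *counters) recordMiss() { c.cacheMisses.Add(1) }

	// stats reads both counters without any lock.
	func (c *counters) stats() (hits, misses int64) {
		return c.cacheHits.Load(), c.cacheMisses.Load()
	}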

	cacheMisses int64

	// maxCacheSize limits memory usage
	maxCacheSize int
}

// NewLogprobsProcessor creates a new LogprobsProcessor following vLLM design patterns
func NewLogprobsProcessor(maxCacheSize int) *LogprobsProcessor {
	if maxCacheSize <= 0 {
		maxCacheSize = 10000 // Default cache size
	}

	return &LogprobsProcessor{
		tokenCache:   make(map[string]*LogprobData),
		maxCacheSize: maxCacheSize,
	}
}

// generateCacheKey creates a deterministic key for caching based on token and topK
func (lp *LogprobsProcessor) generateCacheKey(token string, topK int) string {
	return fmt.Sprintf("%s:%d", token, topK)
}

// generateDeterministicLogprobs creates logprobs with deterministic values based on token content
// This follows vLLM's approach of consistent logprobs for the same token in similar contexts
Review comment (Collaborator): @ruivieira, can you please explain how context is taken into consideration here? As I understand it, a token should give the same logprob only in the same or a similar context. What does context mean here? Isn't it the previous tokens? But they are not passed to this function.
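The PR as written keys only on (token, topK). A hedged sketch of one way the preceding tokens could be folded into the key, if the simulator ever threads them through; the helper name and window size are illustrative only:

	// Hypothetical variant: hash a small window of previous tokens into the
	// cache key so the same token in different contexts gets different logprobs.
	func contextCacheKey(prevTokens []string, token string, topK int) string {
		window := prevTokens
		if len(window) > 4 { // fixed context window; 4 is an arbitrary choice
			window = window[len(window)-4:]
		}
		ctx := md5.Sum([]byte(strings.Join(window, "\x00") + "\x00" + token))
		return fmt.Sprintf("%x:%d", ctx, topK)
	}

(Requires "strings" in addition to the imports above.)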

func (lp *LogprobsProcessor) generateDeterministicLogprobs(token string, topK int) *LogprobData {
	// Use token content to seed deterministic generation (similar to vLLM's approach)
	hash := md5.Sum([]byte(token))
	seed := int64(hash[0])<<24 | int64(hash[1])<<16 | int64(hash[2])<<8 | int64(hash[3])

	// Generate main logprob deterministically based on token
	// Real logprobs are typically negative, with values closer to 0 being more likely
	mainLogprob := -0.1 - (float64(seed%2000) / 1000.0) // Range: -0.1 to -2.1

	if topK <= 0 {
		return &LogprobData{
			MainLogprob: mainLogprob,
			TopLogprobs: nil,
		}
	}

	// Generate top-k alternatives deterministically
	topLogprobs := make([]openaiserverapi.ChatCompletionLogProb, 0, topK)
	for i := 0; i < topK; i++ {
		// Generate deterministic alternative token
		altToken := fmt.Sprintf("alt_%d_%x", i, hash[i%4])

		// Each alternative gets progressively lower probability
		altLogprob := mainLogprob - (float64(i+1) * (0.5 + float64((seed+int64(i))%1500)/1000.0))

		// Convert token to bytes
		bytes := make([]int, len(altToken))
		for j, b := range []byte(altToken) {
			bytes[j] = int(b)
		}

		topLogprobs = append(topLogprobs, openaiserverapi.ChatCompletionLogProb{
			Token:   altToken,
			Logprob: altLogprob,
			Bytes:   bytes,
		})
	}

	return &LogprobData{
		MainLogprob: mainLogprob,
		TopLogprobs: topLogprobs,
	}
}

// GetLogprobs returns logprobs for a token, using cache when possible
func (lp *LogprobsProcessor) GetLogprobs(token string, topK int) (float64, []openaiserverapi.ChatCompletionLogProb) {
	cacheKey := lp.generateCacheKey(token, topK)

	// Check cache first
	lp.cacheMutex.RLock()
	if cached, exists := lp.tokenCache[cacheKey]; exists {
		lp.cacheMutex.RUnlock()
		lp.cacheHits++
		return cached.MainLogprob, cached.TopLogprobs
	}
	lp.cacheMutex.RUnlock()

	// Cache miss - generate new logprobs
	lp.cacheMisses++
	logprobData := lp.generateDeterministicLogprobs(token, topK)

	// Store in cache (with size limit)
	lp.cacheMutex.Lock()
	if len(lp.tokenCache) >= lp.maxCacheSize {
		// Simple eviction: remove oldest entry
		// In production, this could use LRU or other strategies
		for k := range lp.tokenCache {
			delete(lp.tokenCache, k)
			break
		}
Review comment (Collaborator), on lines +132 to +135: Iteration over a map is not in the order objects were added to it; if we want to remove the oldest item, we need to store the insertion date/time or the latest-usage date/time.
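A minimal sketch of the fix this comment points toward: FIFO eviction with an explicit insertion-order record. The orderedCache type and keyOrder slice are illustrative additions, not part of the PR:

	// Sketch: evict the oldest key first by remembering insertion order.
	type orderedCache struct {
		entries  map[string]*LogprobData
		keyOrder []string // cache keys in insertion order
		maxSize  int
	}

	func (c *orderedCache) put(key string, v *LogprobData) {
		if _, exists := c.entries[key]; !exists {
			if len(c.entries) >= c.maxSize {
				oldest := c.keyOrder[0]
				c.keyOrder = c.keyOrder[1:]
				delete(c.entries, oldest)
			}
			c.keyOrder = append(c.keyOrder, key)
		}
		c.entries[key] = v
	}

A true LRU would additionally move a key to the back of keyOrder on every cache hit (container/list makes that O(1)); either way, the eviction order becomes well defined.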

	}
	lp.tokenCache[cacheKey] = logprobData
	lp.cacheMutex.Unlock()

	return logprobData.MainLogprob, logprobData.TopLogprobs
}

// ProcessChatLogprobs creates logprobs data for chat completions following vLLM patterns
func (lp *LogprobsProcessor) ProcessChatLogprobs(tokens []string, topK int) *openaiserverapi.ChatCompletionLogProbs {
	if len(tokens) == 0 {
		return nil
	}

	logprobs := &openaiserverapi.ChatCompletionLogProbs{
		Content: make([]openaiserverapi.ChatCompletionLogProbsContent, 0, len(tokens)),
	}

	for _, token := range tokens {
		mainLogprob, topLps := lp.GetLogprobs(token, topK)

		// Convert token to bytes
		bytes := make([]int, len(token))
		for i, b := range []byte(token) {
			bytes[i] = int(b)
		}

		logprobs.Content = append(logprobs.Content, openaiserverapi.ChatCompletionLogProbsContent{
			Token:       token,
			Logprob:     mainLogprob,
			Bytes:       bytes,
			TopLogprobs: topLps,
		})
	}

	return logprobs
}

// ProcessTextLogprobs creates logprobs data for text completions following vLLM patterns
func (lp *LogprobsProcessor) ProcessTextLogprobs(tokens []string, topK int) *openaiserverapi.CompletionLogProbs {
	if len(tokens) == 0 {
		return nil
	}

	logprobs := &openaiserverapi.CompletionLogProbs{
		TextOffset:    make([]int, 0, len(tokens)),
		TokenLogprobs: make([]float64, 0, len(tokens)),
		Tokens:        make([]string, 0, len(tokens)),
	}

	if topK > 0 {
		logprobs.TopLogprobs = make([]map[string]float64, 0, len(tokens))
	}

	textOffset := 0
	for _, token := range tokens {
		mainLogprob, topLps := lp.GetLogprobs(token, topK)

		logprobs.TextOffset = append(logprobs.TextOffset, textOffset)
		logprobs.TokenLogprobs = append(logprobs.TokenLogprobs, mainLogprob)
		logprobs.Tokens = append(logprobs.Tokens, token)

		if topK > 0 {
			topMap := make(map[string]float64, len(topLps))
			for _, lp := range topLps {
				topMap[lp.Token] = lp.Logprob
			}
			logprobs.TopLogprobs = append(logprobs.TopLogprobs, topMap)
		}

		textOffset += len(token)
	}

	return logprobs
}

// GetCacheStats returns cache performance statistics
func (lp *LogprobsProcessor) GetCacheStats() (hits, misses int64, hitRate float64) {
Review comment (Collaborator): Consider not adding a function to the class that is used only in tests; we could put it in an additional file, or choose not to use it in the tests.
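One conventional way to do that in Go, sketched as an assumption rather than a prescription: move the test-only accessor into an export_test.go file, which is compiled only together with the package's tests, so production builds never include it.

	// pkg/llm-d-inference-sim/export_test.go (hypothetical file)
	package llmdinferencesim

	// cacheStats exposes the counters to tests only; files ending in _test.go
	// are excluded from regular builds.
	func (lp *LogprobsProcessor) cacheStats() (hits, misses int64) {
		lp.cacheMutex.RLock()
		defer lp.cacheMutex.RUnlock()
		return lp.cacheHits, lp.cacheMisses
	}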

	lp.cacheMutex.RLock()
	defer lp.cacheMutex.RUnlock()

	total := lp.cacheHits + lp.cacheMisses
	hitRate = 0.0
	if total > 0 {
		hitRate = float64(lp.cacheHits) / float64(total)
	}

	return lp.cacheHits, lp.cacheMisses, hitRate
}

// ClearCache clears the logprobs cache
func (lp *LogprobsProcessor) ClearCache() {
Review comment (Collaborator): Not used.

	lp.cacheMutex.Lock()
	defer lp.cacheMutex.Unlock()

	lp.tokenCache = make(map[string]*LogprobData)
	lp.cacheHits = 0
	lp.cacheMisses = 0
}
162 changes: 162 additions & 0 deletions pkg/llm-d-inference-sim/logprobs_processor_test.go
@@ -0,0 +1,162 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package llmdinferencesim

import (
	"testing"

	. "github.com/onsi/gomega"
)

func TestLogprobsProcessor_Caching(t *testing.T) {
	RegisterTestingT(t)
	processor := NewLogprobsProcessor(100)

	// Test that same token generates same logprobs (deterministic)
	token := "hello"
	topK := 3

	logprob1, topLogprobs1 := processor.GetLogprobs(token, topK)
	logprob2, topLogprobs2 := processor.GetLogprobs(token, topK)

	// Should be identical (deterministic)
	Expect(logprob1).To(Equal(logprob2))
	Expect(len(topLogprobs1)).To(Equal(len(topLogprobs2)))

CI check failure (line 38, ginkgo-linter): wrong length assertion. Consider using `Expect(topLogprobs1).To(HaveLen(len(topLogprobs2)))` instead.

	Expect(len(topLogprobs1)).To(Equal(topK))

CI check failure (line 39, ginkgo-linter): wrong length assertion. Consider using `Expect(topLogprobs1).To(HaveLen(topK))` instead.

	// Check cache stats
	hits, misses, hitRate := processor.GetCacheStats()
	Expect(hits).To(Equal(int64(1)))
	Expect(misses).To(Equal(int64(1)))
	Expect(hitRate).To(Equal(0.5))
}

func TestLogprobsProcessor_DifferentTokens(t *testing.T) {
	RegisterTestingT(t)
	processor := NewLogprobsProcessor(100)

	// Test that different tokens generate different logprobs
	logprob1, _ := processor.GetLogprobs("hello", 2)
	logprob2, _ := processor.GetLogprobs("world", 2)

	Expect(logprob1).NotTo(Equal(logprob2))
}

func TestLogprobsProcessor_DifferentTopK(t *testing.T) {
	RegisterTestingT(t)
	processor := NewLogprobsProcessor(100)

	// Test that same token with different topK generates different results
	token := "test"

	_, topLogprobs1 := processor.GetLogprobs(token, 2)
	_, topLogprobs2 := processor.GetLogprobs(token, 5)

	Expect(len(topLogprobs1)).To(Equal(2))

CI check failure (line 69, ginkgo-linter): wrong length assertion. Consider using `Expect(topLogprobs1).To(HaveLen(2))` instead.

	Expect(len(topLogprobs2)).To(Equal(5))

CI check failure (line 70, ginkgo-linter): wrong length assertion. Consider using `Expect(topLogprobs2).To(HaveLen(5))` instead.
}

func TestLogprobsProcessor_ChatLogprobs(t *testing.T) {
	RegisterTestingT(t)
	processor := NewLogprobsProcessor(100)

	tokens := []string{"Hello", "world", "!"}
	topK := 3

	logprobs := processor.ProcessChatLogprobs(tokens, topK)

	Expect(logprobs).NotTo(BeNil())
	Expect(len(logprobs.Content)).To(Equal(len(tokens)))

CI check failure (line 83, ginkgo-linter): wrong length assertion. Consider using `Expect(logprobs.Content).To(HaveLen(len(tokens)))` instead.

	for i, content := range logprobs.Content {
		Expect(content.Token).To(Equal(tokens[i]))
		Expect(content.Logprob).To(BeNumerically("<", 0))
		Expect(len(content.TopLogprobs)).To(Equal(topK))

CI check failure (line 88, ginkgo-linter): wrong length assertion. Consider using `Expect(content.TopLogprobs).To(HaveLen(topK))` instead.

		Expect(content.Bytes).NotTo(BeNil())
	}
}

func TestLogprobsProcessor_TextLogprobs(t *testing.T) {
	RegisterTestingT(t)
	processor := NewLogprobsProcessor(100)

	tokens := []string{"Hello", "world"}
	topK := 2

	logprobs := processor.ProcessTextLogprobs(tokens, topK)

	Expect(logprobs).NotTo(BeNil())
	Expect(len(logprobs.Tokens)).To(Equal(len(tokens)))

CI check failure (line 103, ginkgo-linter): wrong length assertion. Consider using `Expect(logprobs.Tokens).To(HaveLen(len(tokens)))` instead.

	Expect(len(logprobs.TokenLogprobs)).To(Equal(len(tokens)))

CI check failure (line 104, ginkgo-linter): wrong length assertion. Consider using `Expect(logprobs.TokenLogprobs).To(HaveLen(len(tokens)))` instead.

	Expect(len(logprobs.TextOffset)).To(Equal(len(tokens)))

CI check failure (line 105, ginkgo-linter): wrong length assertion. Consider using `Expect(logprobs.TextOffset).To(HaveLen(len(tokens)))` instead.

	Expect(len(logprobs.TopLogprobs)).To(Equal(len(tokens)))

CI check failure (line 106, ginkgo-linter): wrong length assertion. Consider using `Expect(logprobs.TopLogprobs).To(HaveLen(len(tokens)))` instead.

	// Check text offsets are cumulative
	expectedOffset := 0
	for i, token := range tokens {
		Expect(logprobs.TextOffset[i]).To(Equal(expectedOffset))
		Expect(logprobs.Tokens[i]).To(Equal(token))
		Expect(logprobs.TokenLogprobs[i]).To(BeNumerically("<", 0))
		Expect(len(logprobs.TopLogprobs[i])).To(Equal(topK))
		expectedOffset += len(token)
	}
}

func TestLogprobsProcessor_EmptyTokens(t *testing.T) {
	RegisterTestingT(t)
	processor := NewLogprobsProcessor(100)

	// Test empty token lists
	chatLogprobs := processor.ProcessChatLogprobs([]string{}, 3)
	textLogprobs := processor.ProcessTextLogprobs([]string{}, 3)

	Expect(chatLogprobs).To(BeNil())
	Expect(textLogprobs).To(BeNil())
}

func TestLogprobsProcessor_ZeroTopK(t *testing.T) {
	RegisterTestingT(t)
	processor := NewLogprobsProcessor(100)

	logprob, topLogprobs := processor.GetLogprobs("test", 0)

	Expect(logprob).To(BeNumerically("<", 0))
	Expect(topLogprobs).To(BeNil())
}

func TestLogprobsProcessor_CacheEviction(t *testing.T) {
	RegisterTestingT(t)
	// Test with very small cache size to trigger eviction
	processor := NewLogprobsProcessor(2)

	// Fill cache beyond capacity
	processor.GetLogprobs("token1", 1)
	processor.GetLogprobs("token2", 1)
	processor.GetLogprobs("token3", 1) // Should trigger eviction

	hits, misses, _ := processor.GetCacheStats()
	Expect(hits).To(Equal(int64(0)))
	Expect(misses).To(Equal(int64(3)))

	// Access one of the earlier tokens - may or may not be in cache due to eviction
	processor.GetLogprobs("token1", 1)

	// Cache should be working (some entries may have been evicted)
	hits2, misses2, _ := processor.GetCacheStats()
	Expect(hits2).To(BeNumerically(">=", 0))
	Expect(misses2).To(BeNumerically(">=", 3))
}
Review comment (Collaborator): Could you please add tests for the simulator with logprobs: completions and chat completions requests with logprobs, with and without streaming, and check the responses?
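A rough sketch of the requested non-streaming chat completions case, under heavy assumptions: startTestServer is a hypothetical helper that boots the simulator and returns its base URL, and the request/response shapes follow the OpenAI-style API the simulator exposes (the response struct below just mirrors the logprobs types used earlier in this PR). Streaming variants would read SSE chunks instead.

	// logprobs_sim_test.go (sketch only)
	package llmdinferencesim

	import (
		"encoding/json"
		"net/http"
		"strings"
		"testing"

		. "github.com/onsi/gomega"

		openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
	)

	func TestSimulator_ChatCompletionLogprobs(t *testing.T) {
		RegisterTestingT(t)
		baseURL := startTestServer(t) // hypothetical helper

		body := `{"model":"test-model","messages":[{"role":"user","content":"hi"}],` +
			`"logprobs":true,"top_logprobs":3}`
		resp, err := http.Post(baseURL+"/v1/chat/completions", "application/json",
			strings.NewReader(body))
		Expect(err).NotTo(HaveOccurred())
		defer resp.Body.Close()

		var parsed struct {
			Choices []struct {
				Logprobs *openaiserverapi.ChatCompletionLogProbs `json:"logprobs"`
			} `json:"choices"`
		}
		Expect(json.NewDecoder(resp.Body).Decode(&parsed)).To(Succeed())
		Expect(parsed.Choices).NotTo(BeEmpty())
		Expect(parsed.Choices[0].Logprobs).NotTo(BeNil())
		for _, c := range parsed.Choices[0].Logprobs.Content {
			Expect(c.Logprob).To(BeNumerically("<", 0))
			Expect(c.TopLogprobs).To(HaveLen(3)) // assumes top_logprobs is honored
		}
	}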
