Skip to content

Add InMemoryIndex unit tests #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/kvcache/indexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ func (k *Indexer) GetPodScores(ctx context.Context, prompt, modelName string,
traceLogger.Info("found tokens", "tokens", tokens, "block-keys", blockKeys)

// 3. query kvblock indexer for pods
strBlockKeys, keyToPods, err := k.kvBlockIndex.Lookup(ctx, blockKeys, sets.New(podIdentifiers...))
keyToPods, err := k.kvBlockIndex.Lookup(ctx, blockKeys, sets.New(podIdentifiers...))
if err != nil {
return nil, fmt.Errorf("failed to query kvblock indexer: %w", err)
}
traceLogger.Info("found block keys", "block-keys", blockKeys,
"pods", podsPerKeyPrintHelper(keyToPods))

// 4. score pods
podScores, err := k.kvBlockScorer.Score(strBlockKeys, keyToPods)
podScores, err := k.kvBlockScorer.Score(blockKeys, keyToPods)
if err != nil {
return nil, fmt.Errorf("failed to query kvblock scorer: %w", err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/kvcache/kvblock/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ func testAddBasic(t *testing.T, index kvblock.Index) {
assert.NoError(t, err)

// Lookup after add
hitKeys, podsPerKey, err := index.Lookup(t.Context(), []kvblock.Key{key}, sets.Set[string]{})
podsPerKey, err := index.Lookup(t.Context(), []kvblock.Key{key}, sets.Set[string]{})
assert.NoError(t, err)
assert.Len(t, hitKeys, 1)
assert.Equal(t, key, hitKeys[0])
assert.Len(t, podsPerKey, 1)
assert.Contains(t, podsPerKey, key)
assert.Equal(t, podsPerKey[key], []string{"10.0.0.1", "10.0.0.2"})
}
13 changes: 6 additions & 7 deletions pkg/kvcache/kvblock/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,13 @@ type PodCache struct {
// If the podIdentifierSet is empty, all pods are returned.
//
// It returns:
// 1. A slice of the hit keys.
// 2. A map where the keys are those in (1) and the values are pod-identifiers.
// 3. An error if any occurred during the operation.
// 1. A map where the keys are the hit keys and the values are pod-identifiers.
// 2. An error if any occurred during the operation.
func (m *InMemoryIndex) Lookup(ctx context.Context, keys []Key,
podIdentifierSet sets.Set[string],
) ([]Key, map[Key][]string, error) {
) (map[Key][]string, error) {
if len(keys) == 0 {
return nil, nil, fmt.Errorf("no keys provided for lookup")
return nil, fmt.Errorf("no keys provided for lookup")
}

traceLogger := klog.FromContext(ctx).V(logging.TRACE).WithName("kvblock.InMemoryIndex.Lookup")
Expand All @@ -108,7 +107,7 @@ func (m *InMemoryIndex) Lookup(ctx context.Context, keys []Key,
if pods, found := m.data.Get(key); found { //nolint:nestif // TODO: can this be optimized?
if pods == nil || pods.cache.Len() == 0 {
traceLogger.Info("no pods found for key, cutting search", "key", key)
return keys[:idx], podsPerKey, nil // early stop since prefix-chain breaks here
return podsPerKey, nil // early stop since prefix-chain breaks here
}

highestHitIdx = idx
Expand All @@ -135,7 +134,7 @@ func (m *InMemoryIndex) Lookup(ctx context.Context, keys []Key,
traceLogger.Info("lookup completed", "highest-hit-index", highestHitIdx,
"pods-per-key", podsPerKeyPrintHelper(podsPerKey))

return keys[:highestHitIdx+1], podsPerKey, nil
return podsPerKey, nil
}

// Add adds a set of keys and their associated pod entries to the index backend.
Expand Down
70 changes: 70 additions & 0 deletions pkg/kvcache/kvblock/in_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,73 @@ func TestInMemoryAddBasic(t *testing.T) {
assert.NoError(t, err)
testAddBasic(t, index)
}

// TestInMemoryIndexSize checks that the index enforces its configured Size
// limit: three keys are added to an index sized for two, and a lookup of all
// three is expected to return only the two most recently added.
func TestInMemoryIndexSize(t *testing.T) {
	// Test with small size to verify eviction
	cfg := &kvblock.InMemoryIndexConfig{
		Size:         2, // Only 2 keys max
		PodCacheSize: 1, // Pod cache size doesn't matter for this test
	}

	index, err := kvblock.NewInMemoryIndex(cfg)
	if !assert.NoError(t, err) {
		// Without a usable index every later call would fail or panic, so
		// stop here with a clean failure instead of continuing.
		t.FailNow()
	}

	ctx := t.Context()

	// Add first key
	key1 := kvblock.Key{ModelName: "test-model", ChunkHash: 111}
	err = index.Add(ctx, []kvblock.Key{key1}, []kvblock.PodEntry{{PodIdentifier: "pod1", DeviceTier: "gpu"}})
	assert.NoError(t, err)

	// Add second key
	key2 := kvblock.Key{ModelName: "test-model", ChunkHash: 222}
	err = index.Add(ctx, []kvblock.Key{key2}, []kvblock.PodEntry{{PodIdentifier: "pod2", DeviceTier: "gpu"}})
	assert.NoError(t, err)

	// Add third key - should evict the first one due to LRU
	key3 := kvblock.Key{ModelName: "test-model", ChunkHash: 333}
	err = index.Add(ctx, []kvblock.Key{key3}, []kvblock.PodEntry{{PodIdentifier: "pod3", DeviceTier: "cpu"}})
	assert.NoError(t, err)

	// Lookup should only return the last two keys
	podsPerKey, err := index.Lookup(ctx, []kvblock.Key{key1, key2, key3}, nil)
	assert.NoError(t, err)

	assert.Len(t, podsPerKey, 2) // Only key2 and key3 should be present
	// State the eviction expectation directly so a failure names the stale key.
	assert.NotContains(t, podsPerKey, key1, "key1 should have been evicted")
	assert.Len(t, podsPerKey[key2], 1)
	assert.Len(t, podsPerKey[key3], 1)
	assert.Contains(t, podsPerKey[key2], "pod2")
	assert.Contains(t, podsPerKey[key3], "pod3")
}

// TestInMemoryIndexPodCacheSize checks that the per-key pod cache enforces the
// configured PodCacheSize limit: three pod entries are added for one key with
// a limit of two, and a lookup is expected to return only the last two.
func TestInMemoryIndexPodCacheSize(t *testing.T) {
	// Test with small limits to verify enforcement
	cfg := &kvblock.InMemoryIndexConfig{
		Size:         1, // Only 1 key max
		PodCacheSize: 2, // Only 2 pods per key
	}

	index, err := kvblock.NewInMemoryIndex(cfg)
	if !assert.NoError(t, err) {
		// Without a usable index every later call would fail or panic, so
		// stop here with a clean failure instead of continuing.
		t.FailNow()
	}

	// Test PodCacheSize limit: add more pods than the limit for one key
	key := kvblock.Key{ModelName: "test-model", ChunkHash: 111}
	pods := []kvblock.PodEntry{
		{PodIdentifier: "pod1", DeviceTier: "gpu"},
		{PodIdentifier: "pod2", DeviceTier: "gpu"},
		{PodIdentifier: "pod3", DeviceTier: "cpu"}, // This should evict pod1 due to LRU
	}

	ctx := t.Context()

	err = index.Add(ctx, []kvblock.Key{key}, pods)
	assert.NoError(t, err)

	// Lookup should only return 2 pods (pod2 and pod3), pod1 should be evicted
	podsPerKey, err := index.Lookup(ctx, []kvblock.Key{key}, nil)
	assert.NoError(t, err)
	assert.Len(t, podsPerKey, 1)
	assert.Len(t, podsPerKey[key], 2, "Should only have 2 pods due to PodCacheSize limit")
	// State the eviction expectation directly so a failure names the stale pod.
	assert.NotContains(t, podsPerKey[key], "pod1", "pod1 should have been evicted")
	assert.Contains(t, podsPerKey[key], "pod2")
	assert.Contains(t, podsPerKey[key], "pod3")
}
7 changes: 3 additions & 4 deletions pkg/kvcache/kvblock/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,9 @@ type Index interface {
// If the podIdentifierSet is empty, all pods are returned.
//
// It returns:
// 1. A slice of the hit keys.
// 2. A map where the keys are those in (1) and the values are pod-identifiers.
// 3. An error if any occurred during the operation.
Lookup(ctx context.Context, keys []Key, podIdentifierSet sets.Set[string]) ([]Key, map[Key][]string, error)
// 1. A map where the keys are the hit keys and the values are pod-identifiers.
// 2. An error if any occurred during the operation.
Lookup(ctx context.Context, keys []Key, podIdentifierSet sets.Set[string]) (map[Key][]string, error)
// Add adds a set of keys and their associated pod entries to the index backend.
Add(ctx context.Context, keys []Key, entries []PodEntry) error
// Evict removes a key and its associated pod entries from the index backend.
Expand Down
7 changes: 3 additions & 4 deletions pkg/kvcache/kvblock/instrumented_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,13 @@ func (m *instrumentedIndex) Lookup(
ctx context.Context,
keys []Key,
podIdentifierSet sets.Set[string],
) ([]Key, map[Key][]string, error) {
) (map[Key][]string, error) {
timer := prometheus.NewTimer(metrics.LookupLatency)
defer timer.ObserveDuration()

metrics.LookupRequests.Inc()

hitKeys, pods, err := m.next.Lookup(ctx, keys, podIdentifierSet)
metrics.LookupHits.Add(float64(len(hitKeys)))
pods, err := m.next.Lookup(ctx, keys, podIdentifierSet)

return hitKeys, pods, err
return pods, err
}
19 changes: 8 additions & 11 deletions pkg/kvcache/kvblock/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,13 @@ var _ Index = &RedisIndex{}
// If the podIdentifierSet is empty, all pods are returned.
//
// It returns:
// 1. A slice of the hit keys.
// 2. A map where the keys are those in (1) and the values are pod-identifiers.
// 3. An error if any occurred during the operation.
// 1. A map where the keys are the hit keys and the values are pod-identifiers.
// 2. An error if any occurred during the operation.
func (r *RedisIndex) Lookup(ctx context.Context, keys []Key,
podIdentifierSet sets.Set[string],
) ([]Key, map[Key][]string, error) {
) (map[Key][]string, error) {
if len(keys) == 0 {
return nil, nil, nil
return make(map[Key][]string), nil
}

logger := klog.FromContext(ctx).WithName("kvblock.RedisIndex.Lookup")
Expand All @@ -105,11 +104,10 @@ func (r *RedisIndex) Lookup(ctx context.Context, keys []Key,

_, execErr := pipe.Exec(ctx)
if execErr != nil {
return nil, nil, fmt.Errorf("redis pipeline execution failed: %w", execErr)
return nil, fmt.Errorf("redis pipeline execution failed: %w", execErr)
}

filterPods := len(podIdentifierSet) > 0 // predicate for filtering
highestHitIdx := 0

for idx, cmd := range results {
key := keys[idx]
Expand All @@ -121,7 +119,7 @@ func (r *RedisIndex) Lookup(ctx context.Context, keys []Key,
logger.Error(cmdErr, "failed to get pods for key", "key", key)
}

return keys[:idx], podsPerKey, nil // early stop since prefix-chain breaks here
return podsPerKey, nil // early stop since prefix-chain breaks here
}

var filteredPods []string
Expand All @@ -134,14 +132,13 @@ func (r *RedisIndex) Lookup(ctx context.Context, keys []Key,

if len(filteredPods) == 0 {
logger.Info("no pods found for key, cutting search", "key", key)
return keys[:idx], podsPerKey, nil // early stop since prefix-chain breaks here
return podsPerKey, nil // early stop since prefix-chain breaks here
}

highestHitIdx = idx
podsPerKey[key] = filteredPods
}

return keys[:highestHitIdx+1], podsPerKey, nil
return podsPerKey, nil
}

// Add adds a set of keys and their associated pod entries to the index backend.
Expand Down