Commit 191e710

kfirtoledo and vMaroon authored
refactor: Replace prefix cache structure with golang-lru (#928)
* refactor: Replace prefix cache structure with golang-lru
* fix: rename prefix scorer parameters and convert test to benchmark test
* feat: Add per server LRU capacity
* fix: Fix typos and error handle
* fix: add safety check for LRUCapacityPerServer

Signed-off-by: Kfir Toledo <[email protected]>
Co-authored-by: Maroon Ayoub <[email protected]>
1 parent 17824ba commit 191e710
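
For context, hashicorp/golang-lru v2 (the library this commit adopts) provides a generic, fixed-capacity LRU cache with an optional eviction callback, which is what the new per-server indexer builds on. A minimal, self-contained sketch of that API, separate from this repository's code:

```go
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

func main() {
	// A capacity-2 LRU keyed by uint64; the callback fires whenever an entry
	// is evicted to make room for a newer one.
	cache, err := lru.NewWithEvict[uint64, struct{}](2, func(key uint64, _ struct{}) {
		fmt.Println("evicted:", key)
	})
	if err != nil {
		panic(err)
	}

	cache.Add(1, struct{}{})
	cache.Add(2, struct{}{})
	cache.Add(3, struct{}{}) // exceeds capacity, so key 1 (least recently used) is evicted

	fmt.Println("len:", cache.Len())         // len: 2
	fmt.Println("has 3:", cache.Contains(3)) // has 3: true
}
```

Adding an existing key also refreshes its recency, which is how repeated prefix hits keep hot blocks resident.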

File tree

8 files changed: +209 -165 lines


cmd/epp/runner/runner.go

Lines changed: 1 addition & 1 deletion
@@ -321,7 +321,7 @@ func loadPrefixCacheConfig() prefix.Config {
 	return prefix.Config{
 		HashBlockSize:          envutil.GetEnvInt("PREFIX_CACHE_HASH_BLOCK_SIZE", prefix.DefaultHashBlockSize, baseLogger),
 		MaxPrefixBlocksToMatch: envutil.GetEnvInt("PREFIX_CACHE_MAX_PREFIX_BLOCKS", prefix.DefaultMaxPrefixBlocks, baseLogger),
-		LRUIndexerCapacity:     envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY", prefix.DefaultLRUIndexerCapacity, baseLogger),
+		LRUCapacityPerServer:   envutil.GetEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", prefix.DefaultLRUCapacityPerServer, baseLogger),
 	}
 }
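
The environment variable is renamed from PREFIX_CACHE_LRU_CAPACITY to PREFIX_CACHE_LRU_CAPACITY_PER_SERVER, reflecting that the capacity now applies per server rather than globally. A rough, stdlib-only sketch of the read-with-default pattern used above (the project's own envutil.GetEnvInt also takes a logger; the default value below is a placeholder, not the plugin's actual default):

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// getEnvInt is an illustrative stand-in for envutil.GetEnvInt: read an integer
// from the environment and fall back to a default when unset or malformed.
func getEnvInt(key string, def int) int {
	if raw, ok := os.LookupEnv(key); ok {
		if v, err := strconv.Atoi(raw); err == nil {
			return v
		}
	}
	return def
}

func main() {
	// 10000 is a placeholder default; the real fallback is prefix.DefaultLRUCapacityPerServer.
	capacity := getEnvInt("PREFIX_CACHE_LRU_CAPACITY_PER_SERVER", 10000)
	fmt.Println("per-server LRU capacity:", capacity)
}
```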

go.mod

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ require (
 	github.com/go-logr/logr v1.4.3
 	github.com/google/go-cmp v0.7.0
 	github.com/google/uuid v1.6.0
+	github.com/hashicorp/golang-lru/v2 v2.0.7
 	github.com/onsi/ginkgo/v2 v2.23.4
 	github.com/onsi/gomega v1.37.0
 	github.com/prometheus/client_golang v1.22.0

go.sum

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@ github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5T
 github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA=
 github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0 h1:TmHmbvxPmaegwhDubVz0lICL0J5Ka2vwTzhoePEXsGE=
 github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKAPIS5qsmQDqZna/PgVt4rWtI=
+github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
+github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
 github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4=
 github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
 github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4=

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer.go

Lines changed: 92 additions & 116 deletions
@@ -20,154 +20,130 @@ import (
 	"context"
 	"sync"
 	"time"
-	"unsafe"
-
-	"container/list"
 
+	lru "github.com/hashicorp/golang-lru/v2"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-func newIndexer(maxCacheSize int) *indexer {
-	t := &indexer{
-		maxCacheSize: maxCacheSize,
-		table:        make(map[BlockHash]map[ServerID]*list.Element),
-		ll:           list.New(),
-	}
-	go t.ReportCacheSize(time.Second)
-	return t
-}
-
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
-// prefix cached .
+// prefix cached.
 type indexer struct {
-	mu           sync.RWMutex
-	maxCacheSize int
-	table        map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
-	ll           *list.List                                // LinkedList to keep track of the order of entries
+	mu         sync.RWMutex
+	hashToPods map[BlockHash]podSet                         // the lookup data structure to find pods that have the BlockHash cached
+	podToLRU   map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
+	maxLRUSize int
 }
 
-// value is the value stored in the linked list.
-type value struct {
-	server ServerID
-	hash   BlockHash
-}
-
-// Get returns the set of servers that have the given prefix hash cached.
-func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
-	i.mu.RLock()
-	defer i.mu.RUnlock()
-	res := map[ServerID]bool{}
-	for server := range i.table[hash] {
-		res[server] = true
+// newIndexer initializes an indexer with size limits and starts cache size reporting.
+func newIndexer(maxLRUSize int) *indexer {
+	ix := &indexer{
+		hashToPods: make(map[BlockHash]podSet),
+		podToLRU:   make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
+		maxLRUSize: maxLRUSize,
 	}
-	return res
+
+	go ix.ReportLRUSize(time.Second)
+	return ix
 }
 
-// Add adds a list of prefix hashes of a single request to the server the request was sent to.
-// The intuition is that this server is likely to have the prefix cached, so next time a request
-// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
-func (i *indexer) Add(hashes []BlockHash, server ServerID) {
+// Add adds a list of prefix hashes to the cache, tied to the server.
+func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
 	i.mu.Lock()
-	defer i.mu.Unlock()
-	for _, hash := range hashes {
-		i.add(hash, server)
+	// Check if the LRU pod exist
+	lruForPod, exists := i.podToLRU[pod]
+	if !exists {
+		newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
+		i.podToLRU[pod] = newLRU
+		lruForPod = newLRU
 	}
-}
 
-func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
-	servers, ok := i.table[hash]
-	if !ok {
-		return nil, false
+	i.mu.Unlock()
+
+	// Add to LRU (may evict)
+	for _, hash := range hashes {
+		lruForPod.Add(hash, struct{}{})
 	}
-	e, ok := servers[server]
-	return e, ok
-}
 
-func (i *indexer) add(hash BlockHash, server ServerID) {
-	e, exists := i.check(hash, server)
-	if exists {
-		i.ll.MoveToBack(e)
-	} else {
-		i.create(hash, server)
+	// Update hashToPods once under lock
+	i.mu.Lock()
+	for _, hash := range hashes {
+		pods := i.hashToPods[hash]
+		if pods == nil {
+			pods = make(podSet)
+		}
+		pods[pod] = struct{}{}
+		i.hashToPods[hash] = pods
 	}
+
+	i.mu.Unlock()
 }
 
-func (i *indexer) create(hash BlockHash, server ServerID) {
-	for i.ll.Len() >= i.maxCacheSize {
-		// Evict the least recently used entry if we've exceeded the max cache size
-		i.evict()
-	}
+// Get returns a set of servers that have the given prefix hash cached.
+func (i *indexer) Get(hash BlockHash) podSet {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
 
-	if _, ok := i.table[hash]; !ok {
-		i.table[hash] = make(map[ServerID]*list.Element)
-	}
-	v := &value{
-		server: server,
-		hash:   hash,
+	res := podSet{}
+	pods, ok := i.hashToPods[hash]
+	if !ok {
+		return res
 	}
-	e := i.ll.PushBack(v)
-	i.table[hash][server] = e
+
+	return pods
 }
 
-// evict removes the least recently used entry from the cache
-func (i *indexer) evict() {
-	oldestNode := i.ll.Front()
-	if oldestNode == nil {
-		return
+// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction.
+func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) {
+	return func(hash BlockHash, _ struct{}) {
+		i.mu.Lock()
+		defer i.mu.Unlock()
+		// Remove the pod from the hash→pods map
+		if podSet, ok := i.hashToPods[hash]; ok {
+			delete(podSet, pod)
+			if len(podSet) == 0 {
+				delete(i.hashToPods, hash)
+			}
+		}
 	}
-	i.ll.Remove(oldestNode)
-
-	v := oldestNode.Value.(*value)
-	hash := v.hash
-	server := v.server
-	// Remove from the hash map
-	serverMap := i.table[hash]
-	delete(serverMap, server)
-
-	// If this was the last server for this hash, remove the hash entry entirely
-	if len(serverMap) == 0 {
-		delete(i.table, hash)
-	}
-
-	log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
 }
 
-// ReportCacheSize starts a goroutine that periodically reports the cache size metric
-func (i *indexer) ReportCacheSize(interval time.Duration) {
+// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric.
+func (i *indexer) ReportLRUSize(interval time.Duration) {
 	ticker := time.NewTicker(interval)
 	defer ticker.Stop()
 	for range ticker.C {
 		i.mu.RLock()
-		metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
-		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
+		totalEntries := 0
+		maxPodEntries := 0
+		maxPodName := ServerID{}
+
+		for pod, lruCache := range i.podToLRU {
+			size := lruCache.Len()
+			totalEntries += size
+			if size > maxPodEntries {
+				maxPodEntries = size
+				maxPodName = pod
+			}
+		}
+
+		numPods := len(i.podToLRU)
+		avg := 0.0
+		if numPods > 0 {
+			avg = float64(totalEntries) / float64(numPods)
+		}
+
+		metrics.RecordPrefixCacheSize(int64(totalEntries))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state",
+			"total entries", totalEntries,
+			"# pods", numPods,
+			"avg entries per pod", avg,
+			"pod with max cache", maxPodName,
+			"max pod size", maxPodEntries,
+			"global max LRU cache capacity per pod", i.maxLRUSize,
+		)
+
 		i.mu.RUnlock()
 	}
 }
-
-// estimateEntrySize estimates the memory size of a cache entry in bytes.
-func (i *indexer) estimateEntrySize() int {
-	size := 0
-
-	// Estimate the size of a node in the linked list.
-	// First get the size of the node struct via unsafe.Sizeof.
-	// The prev and next pointers are 8 bytes each on a 64-bit system.
-	// The BlockHash is a uint64, which is 8 bytes.
-	// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
-	// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
-	// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
-	size += int(unsafe.Sizeof(value{}))
-	// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
-	size += 2 * 63
-
-	// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
-	size += 8      // Size of the BlockHash (uint64).
-	size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
-	size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
-	size += 8      // Size of the pointer to the node in the hash map.
-
-	// Based on the above estimates, the estimated size of an entry is:
-	// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
-	return size
-}
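
To make the new shape concrete: each server (pod) gets its own bounded LRU of block hashes, and a shared hash-to-pods reverse index answers "which pods might have this prefix block cached". The per-pod eviction callback keeps the reverse index consistent when an LRU overflows. A simplified, self-contained sketch of that structure (types and names here are illustrative, not the plugin's own, and locking is omitted):

```go
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

type serverID string

// index pairs a per-server LRU of block hashes with a shared reverse index.
type index struct {
	hashToServers map[uint64]map[serverID]struct{}
	serverToLRU   map[serverID]*lru.Cache[uint64, struct{}]
	capacity      int
}

func newIndex(capacity int) *index {
	return &index{
		hashToServers: map[uint64]map[serverID]struct{}{},
		serverToLRU:   map[serverID]*lru.Cache[uint64, struct{}]{},
		capacity:      capacity,
	}
}

func (ix *index) add(server serverID, hashes ...uint64) {
	c, ok := ix.serverToLRU[server]
	if !ok {
		// The eviction callback removes the (hash, server) pair from the
		// reverse index when this server's LRU overflows.
		c, _ = lru.NewWithEvict[uint64, struct{}](ix.capacity, func(h uint64, _ struct{}) {
			delete(ix.hashToServers[h], server)
			if len(ix.hashToServers[h]) == 0 {
				delete(ix.hashToServers, h)
			}
		})
		ix.serverToLRU[server] = c
	}
	for _, h := range hashes {
		c.Add(h, struct{}{})
		if ix.hashToServers[h] == nil {
			ix.hashToServers[h] = map[serverID]struct{}{}
		}
		ix.hashToServers[h][server] = struct{}{}
	}
}

func main() {
	ix := newIndex(2)
	ix.add("pod-a", 1, 2, 3)         // capacity is 2, so hash 1 falls out of pod-a's LRU
	fmt.Println(ix.hashToServers[1]) // map[] (hash 1 was evicted, reverse index updated)
	fmt.Println(ix.hashToServers[3]) // map[pod-a:{}]
}
```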

pkg/epp/scheduling/framework/plugins/multi/prefix/indexer_test.go

Lines changed: 8 additions & 9 deletions
@@ -22,24 +22,23 @@ import (
 )
 
 func TestIndexer_AddAndGet(t *testing.T) {
-	cache := newIndexer(2)
+	i := newIndexer(2)
 
 	hash1 := BlockHash(1)
 	server := ServerID{Namespace: "default", Name: "server1"}
-
 	// Add an entry to the cache
-	cache.Add([]BlockHash{hash1}, server)
+	i.Add([]BlockHash{hash1}, server)
 
 	// Retrieve the entry
-	assert.Equal(t, 1, cache.ll.Len(), "Cache size should be 1 after adding an entry")
-	servers := cache.Get(hash1)
+	assert.Equal(t, 1, i.podToLRU[server].Len(), "Cache size should be 1 after adding an entry")
+	servers := i.Get(hash1)
 	assert.Contains(t, servers, server, "Cache should contain the added server")
 
 	// Add another entry to the cache, the cache size should be incremented to 2.
-	cache.Add([]BlockHash{BlockHash(2)}, server)
-	assert.Equal(t, 2, cache.ll.Len(), "Cache size should be 2 after adding an entry")
+	i.Add([]BlockHash{BlockHash(2)}, server)
+	assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should be 2 after adding an entry")
 
 	// Add another entry to the cache, which should evict the first one due to max size.
-	cache.Add([]BlockHash{BlockHash(3)}, server)
-	assert.Equal(t, 2, cache.ll.Len(), "Cache size should still be 2 after adding an entry")
+	i.Add([]BlockHash{BlockHash(3)}, server)
+	assert.Equal(t, 2, i.podToLRU[server].Len(), "Cache size should still be 2 after adding an entry")
 }
