@@ -20,154 +20,130 @@ import (
 	"context"
 	"sync"
 	"time"
-	"unsafe"
-
-	"container/list"
 
+	lru "github.com/hashicorp/golang-lru/v2"
 	"sigs.k8s.io/controller-runtime/pkg/log"
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )
 
-func newIndexer(maxCacheSize int) *indexer {
-	t := &indexer{
-		maxCacheSize: maxCacheSize,
-		table:        make(map[BlockHash]map[ServerID]*list.Element),
-		ll:           list.New(),
-	}
-	go t.ReportCacheSize(time.Second)
-	return t
-}
-
 // An indexer maintains an LRU cache of prompt prefix hashes and the server(s) that might have that
-// prefix cached .
+// prefix cached.
 type indexer struct {
-	mu           sync.RWMutex
-	maxCacheSize int
-	table        map[BlockHash]map[ServerID]*list.Element // from any prefix cache to the cache entry to find the server
-	ll           *list.List                                // LinkedList to keep track of the order of entries
+	mu         sync.RWMutex
+	hashToPods map[BlockHash]podSet                         // the lookup data structure to find pods that have the BlockHash cached
+	podToLRU   map[ServerID]*lru.Cache[BlockHash, struct{}] // key is pod namespacedName, value is an LRU cache
+	maxLRUSize int
 }
 
-// value is the value stored in the linked list.
-type value struct {
-	server ServerID
-	hash   BlockHash
-}
-
-// Get returns the set of servers that have the given prefix hash cached.
-func (i *indexer) Get(hash BlockHash) map[ServerID]bool {
-	i.mu.RLock()
-	defer i.mu.RUnlock()
-	res := map[ServerID]bool{}
-	for server := range i.table[hash] {
-		res[server] = true
+// newIndexer initializes an indexer with size limits and starts cache size reporting.
+func newIndexer(maxLRUSize int) *indexer {
+	ix := &indexer{
+		hashToPods: make(map[BlockHash]podSet),
+		podToLRU:   make(map[ServerID]*lru.Cache[BlockHash, struct{}]),
+		maxLRUSize: maxLRUSize,
 	}
-	return res
+
+	go ix.ReportLRUSize(time.Second)
+	return ix
 }
 
-// Add adds a list of prefix hashes of a single request to the server the request was sent to.
-// The intuition is that this server is likely to have the prefix cached, so next time a request
-// sharing the longest prefix should be sent to the same server to take advantage of the cache hit.
-func (i *indexer) Add(hashes []BlockHash, server ServerID) {
+// Add adds a list of prefix hashes to the cache, tied to the server.
+func (i *indexer) Add(hashes []BlockHash, pod ServerID) {
 	i.mu.Lock()
-	defer i.mu.Unlock()
-	for _, hash := range hashes {
-		i.add(hash, server)
+	// Check if an LRU cache already exists for this pod.
+	lruForPod, exists := i.podToLRU[pod]
+	if !exists {
+		newLRU, _ := lru.NewWithEvict[BlockHash, struct{}](i.maxLRUSize, i.makeEvictionFn(pod))
+		i.podToLRU[pod] = newLRU
+		lruForPod = newLRU
 	}
-}
 
-func (i *indexer) check(hash BlockHash, server ServerID) (*list.Element, bool) {
-	servers, ok := i.table[hash]
-	if !ok {
-		return nil, false
+	i.mu.Unlock()
+
+	// Add to the per-pod LRU (may evict).
+	for _, hash := range hashes {
+		lruForPod.Add(hash, struct{}{})
 	}
-	e, ok := servers[server]
-	return e, ok
-}
 
-func (i *indexer) add(hash BlockHash, server ServerID) {
-	e, exists := i.check(hash, server)
-	if exists {
-		i.ll.MoveToBack(e)
-	} else {
-		i.create(hash, server)
+	// Update hashToPods once under lock.
+	i.mu.Lock()
+	for _, hash := range hashes {
+		pods := i.hashToPods[hash]
+		if pods == nil {
+			pods = make(podSet)
+		}
+		pods[pod] = struct{}{}
+		i.hashToPods[hash] = pods
 	}
+
+	i.mu.Unlock()
 }
 
-func (i *indexer) create(hash BlockHash, server ServerID) {
-	for i.ll.Len() >= i.maxCacheSize {
-		// Evict the least recently used entry if we've exceeded the max cache size
-		i.evict()
-	}
+// Get returns a set of servers that have the given prefix hash cached.
+func (i *indexer) Get(hash BlockHash) podSet {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
 
-	if _, ok := i.table[hash]; !ok {
-		i.table[hash] = make(map[ServerID]*list.Element)
-	}
-	v := &value{
-		server: server,
-		hash:   hash,
+	res := podSet{}
+	pods, ok := i.hashToPods[hash]
+	if !ok {
+		return res
 	}
-	e := i.ll.PushBack(v)
-	i.table[hash][server] = e
+
+	return pods
 }
 
-// evict removes the least recently used entry from the cache
-func (i *indexer) evict() {
-	oldestNode := i.ll.Front()
-	if oldestNode == nil {
-		return
+// makeEvictionFn returns a per-pod LRU eviction callback that removes the pod from hashToPods on eviction.
+func (i *indexer) makeEvictionFn(pod ServerID) func(BlockHash, struct{}) {
+	return func(hash BlockHash, _ struct{}) {
+		i.mu.Lock()
+		defer i.mu.Unlock()
+		// Remove the pod from the hash→pods map
+		if podSet, ok := i.hashToPods[hash]; ok {
+			delete(podSet, pod)
+			if len(podSet) == 0 {
+				delete(i.hashToPods, hash)
+			}
+		}
 	}
-	i.ll.Remove(oldestNode)
-
-	v := oldestNode.Value.(*value)
-	hash := v.hash
-	server := v.server
-	// Remove from the hash map
-	serverMap := i.table[hash]
-	delete(serverMap, server)
-
-	// If this was the last server for this hash, remove the hash entry entirely
-	if len(serverMap) == 0 {
-		delete(i.table, hash)
-	}
-
-	log.FromContext(context.TODO()).V(logutil.TRACE).Info("Evicted LRU entry", "hash", hash, "server", server)
 }
 
-// ReportCacheSize starts a goroutine that periodically reports the cache size metric
-func (i *indexer) ReportCacheSize(interval time.Duration) {
+// ReportLRUSize starts a goroutine that periodically reports the LRU cache size metric.
+func (i *indexer) ReportLRUSize(interval time.Duration) {
 	ticker := time.NewTicker(interval)
 	defer ticker.Stop()
 	for range ticker.C {
 		i.mu.RLock()
-		metrics.RecordPrefixCacheSize(int64(i.ll.Len()))
-		log.FromContext(context.TODO()).V(logutil.TRACE).Info("LRU", "# entries", i.ll.Len(), "estimated size MB", i.ll.Len()*i.estimateEntrySize()/1000000)
+		totalEntries := 0
+		maxPodEntries := 0
+		maxPodName := ServerID{}
+
+		for pod, lruCache := range i.podToLRU {
+			size := lruCache.Len()
+			totalEntries += size
+			if size > maxPodEntries {
+				maxPodEntries = size
+				maxPodName = pod
+			}
+		}
+
+		numPods := len(i.podToLRU)
+		avg := 0.0
+		if numPods > 0 {
+			avg = float64(totalEntries) / float64(numPods)
+		}
+
+		metrics.RecordPrefixCacheSize(int64(totalEntries))
+		log.FromContext(context.TODO()).V(logutil.TRACE).Info("Prefix cache state",
+			"total entries", totalEntries,
+			"# pods", numPods,
+			"avg entries per pod", avg,
+			"pod with max cache", maxPodName,
+			"max pod size", maxPodEntries,
+			"global max LRU cache capacity per pod", i.maxLRUSize,
+		)
+
 		i.mu.RUnlock()
 	}
 }
-
-// estimateEntrySize estimates the memory size of a cache entry in bytes.
-func (i *indexer) estimateEntrySize() int {
-	size := 0
-
-	// Estimate the size of a node in the linked list.
-	// First get the size of the node struct via unsafe.Sizeof.
-	// The prev and next pointers are 8 bytes each on a 64-bit system.
-	// The BlockHash is a uint64, which is 8 bytes.
-	// The ServerID is a NamespacedName, which contains two strings (Name and Namespace).
-	// The headers for the strings are 16 bytes each (8 bytes for the pointer and 8 bytes for the length).
-	// So unsafe.Sizeof(node{}) should return 2*8 + 8 + 2*16 = 48 bytes.
-	size += int(unsafe.Sizeof(value{}))
-	// Size of the Name and Namespace strings in ServerID, assuming 63 bytes each (max length for Kubernetes NamespacedName).
-	size += 2 * 63
-
-	// Estimate the size of an entry in the hash map. Note the overhead of the map headers and buckets are ignored.
-	size += 8      // Size of the BlockHash (uint64).
-	size += 2 * 16 // Size of the ServerID string headers (NamespacedName).
-	size += 2 * 63 // Size of the Name and Namespace strings in ServerID.
-	size += 8      // Size of the pointer to the node in the hash map.
-
-	// Based on the above estimates, the estimated size of an entry is:
-	// (48 + 2*63) + (8 + 2*16 + 2*63 + 8) = 348 bytes.
-	return size
-}
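For readers following the change, here is a minimal, self-contained sketch of the per-pod LRU pattern this diff introduces, built on github.com/hashicorp/golang-lru/v2. The types blockHash, serverID, and sketchIndexer are stand-ins invented for illustration, and locking is omitted; the real code uses the package's BlockHash, ServerID, podSet, and the mutex-guarded maps shown above, so treat this as an approximation rather than the extension's actual API.

// sketch_indexer.go: illustrative only, not part of the patch.
package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

type blockHash uint64
type serverID string

// sketchIndexer mirrors the two-level structure: a reverse index from hash to
// pods, plus one bounded LRU of hashes per pod.
type sketchIndexer struct {
	hashToPods map[blockHash]map[serverID]struct{}
	podToLRU   map[serverID]*lru.Cache[blockHash, struct{}]
	maxLRUSize int
}

func newSketchIndexer(maxLRUSize int) *sketchIndexer {
	return &sketchIndexer{
		hashToPods: make(map[blockHash]map[serverID]struct{}),
		podToLRU:   make(map[serverID]*lru.Cache[blockHash, struct{}]),
		maxLRUSize: maxLRUSize,
	}
}

// add records that a pod likely has these prefix hashes cached.
func (ix *sketchIndexer) add(hashes []blockHash, pod serverID) {
	c, ok := ix.podToLRU[pod]
	if !ok {
		// Evicting a hash from this pod's LRU also drops the pod from hashToPods.
		c, _ = lru.NewWithEvict[blockHash, struct{}](ix.maxLRUSize, func(h blockHash, _ struct{}) {
			delete(ix.hashToPods[h], pod)
			if len(ix.hashToPods[h]) == 0 {
				delete(ix.hashToPods, h)
			}
		})
		ix.podToLRU[pod] = c
	}
	for _, h := range hashes {
		c.Add(h, struct{}{})
		if ix.hashToPods[h] == nil {
			ix.hashToPods[h] = make(map[serverID]struct{})
		}
		ix.hashToPods[h][pod] = struct{}{}
	}
}

// get returns the pods believed to have the given hash cached.
func (ix *sketchIndexer) get(h blockHash) map[serverID]struct{} {
	return ix.hashToPods[h]
}

func main() {
	ix := newSketchIndexer(2)             // capacity of 2 hashes per pod, to force an eviction
	ix.add([]blockHash{1, 2, 3}, "pod-a") // hash 1 is evicted from pod-a's LRU
	fmt.Println(ix.get(1), ix.get(3))     // map[] map[pod-a:{}]
}

The property mirrored here is the one the eviction callback in the patch provides: dropping a hash from one pod's LRU removes only that pod from hashToPods, while the hash stays indexed for any other pod that still holds it.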