Skip to content

Commit 6e59794

Browse files
committed
💸 cache: new package
This commit adds a simple bounded LRU cache implementation.
1 parent 492bc8d commit 6e59794

File tree

2 files changed

+350
-0
lines changed

2 files changed

+350
-0
lines changed

cache/cache.go

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
package cache
2+
3+
import (
4+
"iter"
5+
"math"
6+
)
7+
8+
// Entry represents a key-value pair in the cache.
9+
type Entry[K comparable, V any] struct {
10+
Key K
11+
Value V
12+
}
13+
14+
type boundedNode[K comparable, V any] struct {
15+
prev *boundedNode[K, V]
16+
next *boundedNode[K, V]
17+
Entry[K, V]
18+
}
19+
20+
// BoundedCache is a simple in-memory cache with a fixed upper bound on the number of entries.
21+
// It uses a doubly linked list to maintain the order of access, where the tail is the most recently used node.
22+
// When the cache reaches its capacity, insertions will evict the least recently used node (the head of the list).
23+
type BoundedCache[K comparable, V any] struct {
24+
nodeByKey map[K]*boundedNode[K, V]
25+
capacity int
26+
27+
// head is the least recently used node.
28+
head *boundedNode[K, V]
29+
// tail is the most recently used node.
30+
tail *boundedNode[K, V]
31+
}
32+
33+
// NewBoundedCache returns a new bounded cache with the given capacity.
34+
// If capacity is not positive, the cache will be effectively unbounded.
35+
func NewBoundedCache[K comparable, V any](capacity int) *BoundedCache[K, V] {
36+
if capacity <= 0 {
37+
capacity = math.MaxInt
38+
}
39+
return &BoundedCache[K, V]{
40+
nodeByKey: make(map[K]*boundedNode[K, V]),
41+
capacity: capacity,
42+
}
43+
}
44+
45+
// Len returns the number of entries in the cache.
46+
func (c *BoundedCache[K, V]) Len() int {
47+
return len(c.nodeByKey)
48+
}
49+
50+
// Capacity returns the maximum number of entries the cache can hold.
51+
func (c *BoundedCache[K, V]) Capacity() int {
52+
return c.capacity
53+
}
54+
55+
// Contains returns whether the cache contains the given key.
56+
//
57+
// Unlike Get, this method does not update the access order of the cache.
58+
func (c *BoundedCache[K, V]) Contains(key K) bool {
59+
_, ok := c.nodeByKey[key]
60+
return ok
61+
}
62+
63+
// Get returns the value associated with key.
64+
func (c *BoundedCache[K, V]) Get(key K) (value V, ok bool) {
65+
node, ok := c.nodeByKey[key]
66+
if !ok {
67+
return value, false
68+
}
69+
c.moveToTail(node)
70+
return node.Value, true
71+
}
72+
73+
// GetEntry returns the entry associated with key.
74+
func (c *BoundedCache[K, V]) GetEntry(key K) (entry *Entry[K, V], ok bool) {
75+
node, ok := c.nodeByKey[key]
76+
if !ok {
77+
return nil, false
78+
}
79+
c.moveToTail(node)
80+
return &node.Entry, true
81+
}
82+
83+
// Set inserts or updates the value associated with key.
84+
func (c *BoundedCache[K, V]) Set(key K, value V) {
85+
node, ok := c.nodeByKey[key]
86+
if !ok {
87+
c.insert(key, value)
88+
return
89+
}
90+
node.Value = value
91+
c.moveToTail(node)
92+
}
93+
94+
// Insert adds a new key-value pair to the cache if the key does not already exist.
95+
// It returns true if the insertion was successful, false if the key already exists.
96+
func (c *BoundedCache[K, V]) Insert(key K, value V) bool {
97+
if _, ok := c.nodeByKey[key]; ok {
98+
return false
99+
}
100+
c.insert(key, value)
101+
return true
102+
}
103+
104+
// InsertUnchecked adds a new key-value pair to the cache without checking if the key already exists.
105+
//
106+
// WARNING: This is undefined behavior if the key already exists in the cache.
107+
func (c *BoundedCache[K, V]) InsertUnchecked(key K, value V) {
108+
c.insert(key, value)
109+
}
110+
111+
// Remove deletes the value associated with key and returns whether the key was found.
112+
func (c *BoundedCache[K, V]) Remove(key K) bool {
113+
node, ok := c.nodeByKey[key]
114+
if !ok {
115+
return false
116+
}
117+
c.remove(node)
118+
return true
119+
}
120+
121+
// Clear removes all entries from the cache.
122+
func (c *BoundedCache[K, V]) Clear() {
123+
clear(c.nodeByKey)
124+
c.head = nil
125+
c.tail = nil
126+
}
127+
128+
// All returns an iterator over all entries in the cache,
129+
// starting from the least recently used (head) to the most recently used (tail).
130+
func (c *BoundedCache[K, V]) All() iter.Seq2[K, V] {
131+
return func(yield func(K, V) bool) {
132+
for node := c.head; node != nil; node = node.next {
133+
if !yield(node.Key, node.Value) {
134+
break
135+
}
136+
}
137+
}
138+
}
139+
140+
// Backward returns an iterator over all entries in the cache,
141+
// starting from the most recently used (tail) to the least recently used (head).
142+
func (c *BoundedCache[K, V]) Backward() iter.Seq2[K, V] {
143+
return func(yield func(K, V) bool) {
144+
for node := c.tail; node != nil; node = node.prev {
145+
if !yield(node.Key, node.Value) {
146+
break
147+
}
148+
}
149+
}
150+
}
151+
152+
// insert inserts the key-value pair as a new node at the tail of the list.
153+
func (c *BoundedCache[K, V]) insert(key K, value V) {
154+
// If the cache is at capacity, remove the least recently used item.
155+
if len(c.nodeByKey) == c.capacity {
156+
c.remove(c.head)
157+
}
158+
159+
node := &boundedNode[K, V]{
160+
prev: c.tail,
161+
Entry: Entry[K, V]{Key: key, Value: value},
162+
}
163+
164+
c.nodeByKey[key] = node
165+
166+
if c.tail != nil {
167+
c.tail.next = node
168+
} else {
169+
c.head = node
170+
}
171+
c.tail = node
172+
}
173+
174+
// moveToTail promotes the given node to the tail of the list,
175+
// indicating that it was recently accessed.
176+
func (c *BoundedCache[K, V]) moveToTail(node *boundedNode[K, V]) {
177+
// Check if the node is already at the tail.
178+
if node.next == nil {
179+
return
180+
}
181+
182+
// Detach from the list.
183+
node.next.prev = node.prev
184+
if node.prev != nil {
185+
node.prev.next = node.next
186+
} else {
187+
c.head = node.next
188+
}
189+
190+
// Attach to the tail.
191+
node.prev = c.tail
192+
node.next = nil
193+
c.tail.next = node
194+
c.tail = node
195+
}
196+
197+
// remove deletes the given node from the cache.
198+
func (c *BoundedCache[K, V]) remove(node *boundedNode[K, V]) {
199+
delete(c.nodeByKey, node.Key)
200+
201+
if node.prev != nil {
202+
node.prev.next = node.next
203+
} else {
204+
c.head = node.next
205+
}
206+
207+
if node.next != nil {
208+
node.next.prev = node.prev
209+
} else {
210+
c.tail = node.prev
211+
}
212+
}

cache/cache_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package cache_test
2+
3+
import (
4+
"math"
5+
"slices"
6+
"testing"
7+
8+
"github.com/database64128/shadowsocks-go/cache"
9+
)
10+
11+
func TestBoundedCache(t *testing.T) {
12+
c := cache.NewBoundedCache[int, int](3)
13+
assertBoundedCacheLenCapacityContent(t, c, nil, 3)
14+
c.Set(1, -1)
15+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{1, -1}}, 3)
16+
if !c.Insert(2, -2) {
17+
t.Error("c.Insert(2, -2) = false, want true")
18+
}
19+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{1, -1}, {2, -2}}, 3)
20+
c.InsertUnchecked(3, -3)
21+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{1, -1}, {2, -2}, {3, -3}}, 3)
22+
c.Set(4, -4)
23+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{2, -2}, {3, -3}, {4, -4}}, 3)
24+
if value, ok := c.Get(2); value != -2 || !ok {
25+
t.Errorf("c.Get(2) = %d, %v, want -2, true", value, ok)
26+
}
27+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{3, -3}, {4, -4}, {2, -2}}, 3)
28+
if entry, ok := c.GetEntry(4); entry == nil || entry.Key != 4 || entry.Value != -4 || !ok {
29+
t.Errorf("c.GetEntry(4) = %v, %v, want {Key: 4, Value: -4}, true", entry, ok)
30+
}
31+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{3, -3}, {2, -2}, {4, -4}}, 3)
32+
if !c.Remove(4) {
33+
t.Error("c.Remove(4) = false, want true")
34+
}
35+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{3, -3}, {2, -2}}, 3)
36+
c.Set(2, 2)
37+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{3, -3}, {2, 2}}, 3)
38+
if value, ok := c.Get(3); value != -3 || !ok {
39+
t.Errorf("c.Get(3) = %d, %v, want -3, true", value, ok)
40+
}
41+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{{2, 2}, {3, -3}}, 3)
42+
c.Clear()
43+
assertBoundedCacheLenCapacityContent(t, c, nil, 3)
44+
}
45+
46+
func TestBoundedCacheUnboundedCapacity(t *testing.T) {
47+
c := cache.NewBoundedCache[int, int](0)
48+
assertBoundedCacheLenCapacityContent(t, c, nil, math.MaxInt)
49+
c.Set(1, -1)
50+
c.Set(2, -2)
51+
c.Set(3, -3)
52+
for range c.All() {
53+
for range c.Backward() {
54+
break
55+
}
56+
break
57+
}
58+
c.Set(4, -4)
59+
c.Set(5, -5)
60+
c.Set(6, -6)
61+
assertBoundedCacheLenCapacityContent(t, c, []cache.Entry[int, int]{
62+
{1, -1}, {2, -2}, {3, -3}, {4, -4}, {5, -5}, {6, -6},
63+
}, math.MaxInt)
64+
}
65+
66+
func assertBoundedCacheLenCapacityContent(t *testing.T, c *cache.BoundedCache[int, int], want []cache.Entry[int, int], expectedCapacity int) {
67+
t.Helper()
68+
69+
if got := c.Len(); got != len(want) {
70+
t.Errorf("c.Len() = %d, want %d", got, len(want))
71+
}
72+
if got := c.Capacity(); got != expectedCapacity {
73+
t.Errorf("c.Capacity() = %d, want %d", got, expectedCapacity)
74+
}
75+
76+
got := make([]cache.Entry[int, int], 0, len(want))
77+
for key, value := range c.All() {
78+
got = append(got, cache.Entry[int, int]{Key: key, Value: value})
79+
}
80+
if !slices.Equal(got, want) {
81+
t.Errorf("c.All() = %v, want %v", got, want)
82+
}
83+
84+
got = got[:0]
85+
for key, value := range c.Backward() {
86+
got = append(got, cache.Entry[int, int]{Key: key, Value: value})
87+
}
88+
if !slicesReverseEqual(got, want) {
89+
t.Errorf("c.Backward() = %v, want %v", got, want)
90+
}
91+
92+
for key := range 10 {
93+
if index := slices.IndexFunc(want, func(e cache.Entry[int, int]) bool {
94+
return e.Key == key
95+
}); index != -1 {
96+
expectedEntry := want[index]
97+
expectedValue := expectedEntry.Value
98+
99+
if !c.Contains(key) {
100+
t.Errorf("c.Contains(%d) = false, want true", key)
101+
}
102+
103+
if c.Insert(key, expectedValue) {
104+
t.Errorf("c.Insert(%d, %d) = true, want false", key, expectedValue)
105+
}
106+
} else {
107+
if c.Contains(key) {
108+
t.Errorf("c.Contains(%d) = true, want false", key)
109+
}
110+
111+
value, ok := c.Get(key)
112+
if value != 0 || ok {
113+
t.Errorf("c.Get(%d) = %d, %v, want 0, false", key, value, ok)
114+
}
115+
116+
entry, ok := c.GetEntry(key)
117+
if entry != nil || ok {
118+
t.Errorf("c.GetEntry(%d) = %v, %v, want nil, false", key, entry, ok)
119+
}
120+
121+
if c.Remove(key) {
122+
t.Errorf("c.Remove(%d) = true, want false", key)
123+
}
124+
}
125+
}
126+
}
127+
128+
func slicesReverseEqual[S ~[]E, E comparable](s1, s2 S) bool {
129+
if len(s1) != len(s2) {
130+
return false
131+
}
132+
for i := range s1 {
133+
if s1[i] != s2[len(s2)-1-i] {
134+
return false
135+
}
136+
}
137+
return true
138+
}

0 commit comments

Comments
 (0)