Skip to content

Commit 9badd7f

Browse files
authored
Merge pull request #268 from yyforyongyu/lru-methods
cache: add deletion and iteration methods
2 parents da7f35a + 465459a commit 9badd7f

File tree

4 files changed

+224
-15
lines changed

4 files changed

+224
-15
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ breakpoints.txt
2828

2929
# coverage output
3030
coverage.txt
31+
32+
# go workspace
33+
go.work

cache/lru/lru.go

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@ import (
77
"github.com/lightninglabs/neutrino/cache"
88
)
99

10-
// elementMap is an alias for a map from a generic interface to a list.Element.
11-
type elementMap[K comparable, V any] map[K]V
12-
1310
// entry represents a (key,value) pair entry in the Cache. The Cache's list
1411
// stores entries which let us get the cache key when an entry is evicted.
1512
type entry[K comparable, V cache.Value] struct {
@@ -33,7 +30,7 @@ type Cache[K comparable, V cache.Value] struct {
3330

3431
// cache is a generic cache which allows us to find an elements position
3532
// in the ll list from a given key.
36-
cache elementMap[K, *Element[entry[K, V]]]
33+
cache syncMap[K, *Element[entry[K, V]]]
3734

3835
// mtx is used to make sure the Cache is thread-safe.
3936
mtx sync.RWMutex
@@ -45,7 +42,7 @@ func NewCache[K comparable, V cache.Value](capacity uint64) *Cache[K, V] {
4542
return &Cache[K, V]{
4643
capacity: capacity,
4744
ll: NewList[entry[K, V]](),
48-
cache: make(map[K]*Element[entry[K, V]]),
45+
cache: syncMap[K, *Element[entry[K, V]]]{},
4946
}
5047
}
5148

@@ -84,7 +81,7 @@ func (c *Cache[K, V]) evict(needed uint64) (bool, error) {
8481

8582
// Remove the element from the cache.
8683
c.ll.Remove(elr)
87-
delete(c.cache, ce.key)
84+
c.cache.Delete(ce.key)
8885
evicted = true
8986
}
9087
}
@@ -108,17 +105,22 @@ func (c *Cache[K, V]) Put(key K, value V) (bool, error) {
108105
"cache with capacity %v", vs, c.capacity)
109106
}
110107

108+
// Load the element.
109+
el, ok := c.cache.Load(key)
110+
111+
// Update the internal list inside a lock.
111112
c.mtx.Lock()
112-
defer c.mtx.Unlock()
113113

114114
// If the element already exists, remove it and decrease cache's size.
115-
el, ok := c.cache[key]
116115
if ok {
117116
es, err := el.Value.value.Size()
118117
if err != nil {
118+
c.mtx.Unlock()
119+
119120
return false, fmt.Errorf("couldn't determine size of "+
120121
"existing cache value %v", err)
121122
}
123+
122124
c.ll.Remove(el)
123125
c.size -= es
124126
}
@@ -132,26 +134,31 @@ func (c *Cache[K, V]) Put(key K, value V) (bool, error) {
132134

133135
// We have made enough space in the cache, so just insert it.
134136
el = c.ll.PushFront(entry[K, V]{key, value})
135-
c.cache[key] = el
136137
c.size += vs
137138

139+
// Release the lock.
140+
c.mtx.Unlock()
141+
142+
// Update the cache.
143+
c.cache.Store(key, el)
144+
138145
return evicted, nil
139146
}
140147

141148
// Get will return value for a given key, making the element the most recently
142149
// accessed item in the process. Will return nil if the key isn't found.
143150
func (c *Cache[K, V]) Get(key K) (V, error) {
144-
c.mtx.Lock()
145-
defer c.mtx.Unlock()
146-
147151
var defaultVal V
148152

149-
el, ok := c.cache[key]
153+
el, ok := c.cache.Load(key)
150154
if !ok {
151155
// Element not found in the cache.
152156
return defaultVal, cache.ErrElementNotFound
153157
}
154158

159+
c.mtx.Lock()
160+
defer c.mtx.Unlock()
161+
155162
// When the cache needs to evict a element to make space for another
156163
// one, it starts eviction from the back, so by moving this element to
157164
// the front, it's eviction is delayed because it's recently accessed.
@@ -166,3 +173,45 @@ func (c *Cache[K, V]) Len() int {
166173

167174
return c.ll.Len()
168175
}
176+
177+
// Delete removes an item from the cache.
178+
func (c *Cache[K, V]) Delete(key K) {
179+
c.LoadAndDelete(key)
180+
}
181+
182+
// LoadAndDelete queries an item and deletes it from the cache using the
183+
// specified key.
184+
func (c *Cache[K, V]) LoadAndDelete(key K) (V, bool) {
185+
var defaultVal V
186+
187+
// Noop if the element doesn't exist.
188+
el, ok := c.cache.LoadAndDelete(key)
189+
if !ok {
190+
return defaultVal, false
191+
}
192+
193+
c.mtx.Lock()
194+
defer c.mtx.Unlock()
195+
196+
// Get its size.
197+
vs, err := el.Value.value.Size()
198+
if err != nil {
199+
return defaultVal, false
200+
}
201+
202+
// Remove the element from the list and update the cache's size.
203+
c.ll.Remove(el)
204+
c.size -= vs
205+
206+
return el.Value.value, true
207+
}
208+
209+
// Range iterates the cache.
210+
func (c *Cache[K, V]) Range(visitor func(K, V) bool) {
211+
// valueVisitor is a closure to help unwrap the value from the cache.
212+
valueVisitor := func(key K, value *Element[entry[K, V]]) bool {
213+
return visitor(key, value.Value.value)
214+
}
215+
216+
c.cache.Range(valueVisitor)
217+
}

cache/lru/lru_test.go

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func TestElementSizeCapacityEvictsEverything(t *testing.T) {
9797
// Insert element with size=capacity of cache, should evict everything.
9898
c.Put(4, &sizeable{value: 4, size: 3})
9999
require.Equal(t, c.Len(), 1)
100-
require.Equal(t, len(c.cache), 1)
100+
require.Equal(t, c.cache.Len(), 1)
101101
four := getSizeableValue(c.Get(4))
102102
require.Equal(t, four, 4)
103103

@@ -110,7 +110,7 @@ func TestElementSizeCapacityEvictsEverything(t *testing.T) {
110110
// Insert element with size=capacity of cache.
111111
c.Put(4, &sizeable{value: 4, size: 6})
112112
require.Equal(t, c.Len(), 1)
113-
require.Equal(t, len(c.cache), 1)
113+
require.Equal(t, c.cache.Len(), 1)
114114
four = getSizeableValue(c.Get(4))
115115
require.Equal(t, four, 4)
116116
}
@@ -296,3 +296,94 @@ func TestConcurrencyBigCache(t *testing.T) {
296296

297297
wg.Wait()
298298
}
299+
300+
// TestLoadAndDelete checks the `LoadAndDelete` method.
301+
func TestLoadAndDelete(t *testing.T) {
302+
t.Parallel()
303+
304+
c := NewCache[int, *sizeable](3)
305+
306+
// Create a test item.
307+
item1 := &sizeable{value: 1, size: 1}
308+
309+
// Put the item.
310+
_, err := c.Put(0, item1)
311+
require.NoError(t, err)
312+
313+
// Load the item and check it's returned as expected.
314+
loadedItem, loaded := c.LoadAndDelete(0)
315+
require.True(t, loaded)
316+
require.Equal(t, item1, loadedItem)
317+
318+
// Now check that the item has been deleted.
319+
_, err = c.Get(0)
320+
require.ErrorIs(t, err, cache.ErrElementNotFound)
321+
322+
// Load the item again should give us a nil value and false.
323+
loadedItem, loaded = c.LoadAndDelete(0)
324+
require.False(t, loaded)
325+
require.Nil(t, loadedItem)
326+
327+
// The length should be 0.
328+
require.Zero(t, c.Len())
329+
require.Zero(t, c.size)
330+
}
331+
332+
// TestRangeIteration checks that the `Range` method works as expected.
333+
func TestRangeIteration(t *testing.T) {
334+
t.Parallel()
335+
336+
c := NewCache[int, *sizeable](100)
337+
338+
// Create test items.
339+
const numItems = 10
340+
for i := 0; i < numItems; i++ {
341+
_, err := c.Put(i, &sizeable{value: i, size: 1})
342+
require.NoError(t, err)
343+
}
344+
345+
// Create a dummy visitor that just counts the number of items visited.
346+
visited := 0
347+
testVisitor := func(key int, value *sizeable) bool {
348+
visited++
349+
return true
350+
}
351+
352+
// Call the method.
353+
c.Range(testVisitor)
354+
355+
// Check the number of items visited.
356+
require.Equal(t, numItems, visited)
357+
}
358+
359+
// TestRangeAbort checks that the `Range` will abort when the visitor returns
360+
// false.
361+
func TestRangeAbort(t *testing.T) {
362+
t.Parallel()
363+
364+
c := NewCache[int, *sizeable](100)
365+
366+
// Create test items.
367+
const numItems = 10
368+
for i := 0; i < numItems; i++ {
369+
_, err := c.Put(i, &sizeable{value: i, size: 1})
370+
require.NoError(t, err)
371+
}
372+
373+
// Create a visitor that counts the number of items visited and returns
374+
// false when visited 5 times.
375+
visited := 0
376+
testVisitor := func(key int, value *sizeable) bool {
377+
visited++
378+
if visited >= numItems/2 {
379+
return false
380+
}
381+
return true
382+
}
383+
384+
// Call the method.
385+
c.Range(testVisitor)
386+
387+
// Check the number of items visited.
388+
require.Equal(t, numItems/2, visited)
389+
}

cache/lru/sync_map.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package lru
2+
3+
import "sync"
4+
5+
// syncMap wraps a sync.Map with type parameters such that it's easier to
6+
// access the items stored in the map since no type assertion is needed. It
7+
// also requires explicit type definition when declaring and initiating the
8+
// variables, which helps us understanding what's stored in a given map.
9+
//
10+
// NOTE: this is unexported to avoid confusion with `lnd`'s `SyncMap`.
11+
type syncMap[K comparable, V any] struct {
12+
sync.Map
13+
}
14+
15+
// Store puts an item in the map.
16+
func (m *syncMap[K, V]) Store(key K, value V) {
17+
m.Map.Store(key, value)
18+
}
19+
20+
// Load queries an item from the map using the specified key. If the item
21+
// cannot be found, an empty value and false will be returned. If the stored
22+
// item fails the type assertion, a nil value and false will be returned.
23+
func (m *syncMap[K, V]) Load(key K) (V, bool) {
24+
result, ok := m.Map.Load(key)
25+
if !ok {
26+
return *new(V), false // nolint: gocritic
27+
}
28+
29+
item, ok := result.(V)
30+
return item, ok
31+
}
32+
33+
// Delete removes an item from the map specified by the key.
34+
func (m *syncMap[K, V]) Delete(key K) {
35+
m.Map.Delete(key)
36+
}
37+
38+
// LoadAndDelete queries an item and deletes it from the map using the
39+
// specified key.
40+
func (m *syncMap[K, V]) LoadAndDelete(key K) (V, bool) {
41+
result, loaded := m.Map.LoadAndDelete(key)
42+
if !loaded {
43+
return *new(V), loaded // nolint: gocritic
44+
}
45+
46+
item, ok := result.(V)
47+
return item, ok
48+
}
49+
50+
// Range iterates the map.
51+
func (m *syncMap[K, V]) Range(visitor func(K, V) bool) {
52+
m.Map.Range(func(k any, v any) bool {
53+
return visitor(k.(K), v.(V))
54+
})
55+
}
56+
57+
// Len returns the number of items in the map.
58+
func (m *syncMap[K, V]) Len() int {
59+
var count int
60+
m.Range(func(K, V) bool {
61+
count++
62+
return true
63+
})
64+
65+
return count
66+
}

0 commit comments

Comments
 (0)