Skip to content

Commit d753f26

Browse files
authored
chore: add metrics and tests to all cache implementations (#2874)
1 parent c2c3bab commit d753f26

File tree

7 files changed

+244
-28
lines changed

7 files changed

+244
-28
lines changed

pkg/cache/cache.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,13 @@ type Cache[K KeyString, V any] interface {
6868
// Get returns the value for the given key in the cache, if it exists.
6969
Get(key K) (V, bool)
7070

71-
// Set sets a value for the key in the cache, with the given cost.
71+
// GetTTL returns the TTL of entries in the cache.
72+
// If zero is used, entries are not deleted.
73+
GetTTL() time.Duration
74+
75+
// Set is a best-effort attempt to set a value for the key in the cache, with the given cost.
76+
// If GetTTL returns zero, the entry never expires.
77+
// Returns true if the value could be set, false if the cost was too high.
7278
Set(key K, entry V, cost int64) bool
7379

7480
// Wait waits for the cache to process and apply updates.
@@ -78,6 +84,7 @@ type Cache[K KeyString, V any] interface {
7884
Close()
7985

8086
// GetMetrics returns the metrics block for the cache.
87+
// Some implementations may chose to not return some of these metrics.
8188
GetMetrics() Metrics
8289

8390
zerolog.LogObjectMarshaler
@@ -106,6 +113,7 @@ type noopCache[K KeyString, V any] struct{}
106113
var _ Cache[StringKey, any] = (*noopCache[StringKey, any])(nil)
107114

108115
func (no *noopCache[K, V]) Get(_ K) (V, bool) { return *new(V), false }
116+
func (no *noopCache[K, V]) GetTTL() time.Duration { return time.Duration(0) }
109117
func (no *noopCache[K, V]) Set(_ K, _ V, _ int64) bool { return false }
110118
func (no *noopCache[K, V]) Wait() {}
111119
func (no *noopCache[K, V]) Close() {}

pkg/cache/cache_otter.go

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package cache
55
import (
66
"math"
77
"sync/atomic"
8+
"time"
89

910
"github.com/ccoveille/go-safecast/v2"
1011
"github.com/maypok86/otter/v2"
@@ -13,15 +14,20 @@ import (
1314
)
1415

1516
func NewOtterCacheWithMetrics[K KeyString, V any](name string, config *Config) (Cache[K, V], error) {
16-
return NewOtterCache[K, V](config)
17+
cache, err := NewOtterCache[K, V](name, config)
18+
if err != nil {
19+
return nil, err
20+
}
21+
mustRegisterCache(name, cache)
22+
return cache, nil
1723
}
1824

1925
type valueAndCost[V any] struct {
2026
value V
2127
cost uint32
2228
}
2329

24-
func NewOtterCache[K KeyString, V any](config *Config) (Cache[K, V], error) {
30+
func NewOtterCache[K KeyString, V any](name string, config *Config) (Cache[K, V], error) {
2531
uintCost, err := safecast.Convert[uint64](config.MaxCost)
2632
if err != nil {
2733
return nil, err
@@ -41,14 +47,22 @@ func NewOtterCache[K KeyString, V any](config *Config) (Cache[K, V], error) {
4147

4248
cache, err := otter.New(opts)
4349
return &otterCache[K, V]{
50+
name,
4451
cache,
4552
otterMetrics{atomic.Uint64{}, counter},
53+
config.DefaultTTL,
4654
}, err
4755
}
4856

4957
type otterCache[K KeyString, V any] struct {
58+
name string
5059
cache *otter.Cache[string, valueAndCost[V]]
5160
metrics otterMetrics
61+
ttl time.Duration
62+
}
63+
64+
func (wtc *otterCache[K, V]) GetTTL() time.Duration {
65+
return wtc.ttl
5266
}
5367

5468
func (wtc *otterCache[K, V]) Get(key K) (V, bool) {
@@ -67,13 +81,23 @@ func (wtc *otterCache[K, V]) Set(key K, value V, cost int64) bool {
6781
// was too big, so we set to maxint in that case.
6882
uintCost = math.MaxUint32
6983
}
84+
7085
wtc.metrics.costAdded.Add(uint64(uintCost))
86+
_, ok := wtc.Get(key)
87+
if ok {
88+
wtc.cache.Invalidate(key.KeyString())
89+
}
7190
wtc.cache.Set(key.KeyString(), valueAndCost[V]{value, uintCost})
72-
return true // Otter doesn't drop insertions for performance
91+
if wtc.ttl > 0 {
92+
wtc.cache.SetExpiresAfter(key.KeyString(), wtc.ttl)
93+
}
94+
return true
7395
}
7496

75-
func (wtc *otterCache[K, V]) Wait() {}
76-
func (wtc *otterCache[K, V]) Close() {}
97+
func (wtc *otterCache[K, V]) Wait() {}
98+
func (wtc *otterCache[K, V]) Close() {
99+
unregisterCache(wtc.name)
100+
}
77101

78102
type otterMetrics struct {
79103
costAdded atomic.Uint64

pkg/cache/cache_ristretto.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,32 @@ func NewRistrettoCacheWithMetrics[K KeyString, V any](name string, config *Confi
4343
return nil, err
4444
}
4545

46-
cache := wrapped[K, V]{name, config, config.DefaultTTL, rcache}
46+
cache := ristretoCache[K, V]{name, config, config.DefaultTTL, rcache}
4747
mustRegisterCache(name, cache)
4848
return &cache, nil
4949
}
5050

5151
// NewRistrettoCache creates a new ristretto cache from the given config.
5252
func NewRistrettoCache[K KeyString, V any](config *Config) (Cache[K, V], error) {
5353
rcache, err := ristretto.NewCache(ristrettoConfig(config))
54-
return &wrapped[K, V]{"", config, config.DefaultTTL, rcache}, err
54+
return &ristretoCache[K, V]{"", config, config.DefaultTTL, rcache}, err
5555
}
5656

57-
type wrapped[K any, V any] struct {
57+
type ristretoCache[K any, V any] struct {
5858
name string
5959
config *Config
6060
defaultTTL time.Duration
6161
ristretto *ristretto.Cache
6262
}
6363

64-
func (w wrapped[K, V]) Set(key K, entry V, cost int64) bool {
64+
func (w ristretoCache[K, V]) Set(key K, entry V, cost int64) bool {
6565
if w.defaultTTL <= 0 {
6666
return w.ristretto.Set(key, entry, cost)
6767
}
6868
return w.ristretto.SetWithTTL(key, entry, cost, w.defaultTTL)
6969
}
7070

71-
func (w wrapped[K, V]) Get(key K) (V, bool) {
71+
func (w ristretoCache[K, V]) Get(key K) (V, bool) {
7272
found, ok := w.ristretto.Get(key)
7373
if !ok {
7474
return *new(V), false
@@ -77,16 +77,20 @@ func (w wrapped[K, V]) Get(key K) (V, bool) {
7777
return found.(V), true
7878
}
7979

80-
func (w wrapped[K, V]) Wait() {
80+
func (w ristretoCache[K, V]) Wait() {
8181
w.ristretto.Wait()
8282
}
8383

84-
var _ Cache[StringKey, any] = (*wrapped[StringKey, any])(nil)
84+
func (w ristretoCache[K, V]) GetTTL() time.Duration {
85+
return w.defaultTTL
86+
}
87+
88+
var _ Cache[StringKey, any] = (*ristretoCache[StringKey, any])(nil)
8589

86-
func (w wrapped[K, V]) GetMetrics() Metrics { return w.ristretto.Metrics }
87-
func (w wrapped[K, V]) MarshalZerologObject(e *zerolog.Event) { e.EmbedObject(w.config) }
90+
func (w ristretoCache[K, V]) GetMetrics() Metrics { return w.ristretto.Metrics }
91+
func (w ristretoCache[K, V]) MarshalZerologObject(e *zerolog.Event) { e.EmbedObject(w.config) }
8892

89-
func (w wrapped[K, V]) Close() {
93+
func (w ristretoCache[K, V]) Close() {
9094
w.ristretto.Close()
9195
unregisterCache(w.name)
9296
}

pkg/cache/cache_test.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
//go:build !wasm
2+
3+
package cache
4+
5+
import (
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestCacheWithMetrics(t *testing.T) {
13+
t.Parallel()
14+
15+
config := &Config{
16+
NumCounters: 10000,
17+
MaxCost: 1000,
18+
DefaultTTL: 10 * time.Hour,
19+
}
20+
21+
t.Run("otter", func(t *testing.T) {
22+
t.Parallel()
23+
testCacheImplementation(t, func() (Cache[StringKey, string], error) {
24+
return NewOtterCacheWithMetrics[StringKey, string]("test-otter", config)
25+
})
26+
})
27+
28+
t.Run("ristretto", func(t *testing.T) {
29+
t.Parallel()
30+
testCacheImplementation(t, func() (Cache[StringKey, string], error) {
31+
// Use the metrics version for proper metrics tracking
32+
return NewRistrettoCacheWithMetrics[StringKey, string]("test-ristretto", config)
33+
})
34+
})
35+
36+
t.Run("theine", func(t *testing.T) {
37+
t.Parallel()
38+
testCacheImplementation(t, func() (Cache[StringKey, string], error) {
39+
return NewTheineCacheWithMetrics[StringKey, string]("test-theine", config)
40+
})
41+
})
42+
}
43+
44+
func testCacheImplementation(t *testing.T, factory func() (Cache[StringKey, string], error)) {
45+
t.Run("Set and Get", func(t *testing.T) {
46+
cache, err := factory()
47+
require.NoError(t, err)
48+
defer cache.Close()
49+
50+
// Set multiple entries
51+
entries := []struct {
52+
key StringKey
53+
value string
54+
}{
55+
{"key1", "value1"},
56+
{"key2", "value2"},
57+
{"key3", "value3"},
58+
}
59+
60+
for _, entry := range entries {
61+
ok := cache.Set(entry.key, entry.value, 10)
62+
require.True(t, ok)
63+
}
64+
65+
// Wait for all sets to be processed
66+
cache.Wait()
67+
68+
// Verify all entries
69+
for _, entry := range entries {
70+
retrieved, found := cache.Get(entry.key)
71+
require.True(t, found, "expected key %s to be found", entry.key)
72+
require.Equal(t, entry.value, retrieved, "expected value for key %s to match", entry.key)
73+
}
74+
})
75+
76+
t.Run("Set same key with diff values", func(t *testing.T) {
77+
cache, err := factory()
78+
require.NoError(t, err)
79+
defer cache.Close()
80+
81+
ok := cache.Set(StringKey("metric-key-1"), "value1", 10)
82+
require.True(t, ok)
83+
cache.Wait()
84+
val, found := cache.Get("metric-key-1")
85+
require.True(t, found)
86+
require.Equal(t, "value1", val)
87+
88+
// same key set, diff value
89+
ok = cache.Set(StringKey("metric-key-1"), "value2", 10)
90+
require.True(t, ok)
91+
cache.Wait()
92+
val, found = cache.Get("metric-key-1")
93+
require.True(t, found)
94+
require.Equal(t, "value2", val)
95+
})
96+
97+
t.Run("Close multiple times", func(t *testing.T) {
98+
cache, err := factory()
99+
require.NoError(t, err)
100+
101+
for i := 0; i < 10; i++ {
102+
cache.Close()
103+
}
104+
})
105+
106+
t.Run("GetTTL", func(t *testing.T) {
107+
cache, err := factory()
108+
require.NoError(t, err)
109+
defer cache.Close()
110+
111+
require.Equal(t, 10*time.Hour, cache.GetTTL())
112+
})
113+
114+
t.Run("GetMetrics", func(t *testing.T) {
115+
cache, err := factory()
116+
require.NoError(t, err)
117+
defer cache.Close()
118+
119+
// Set some values
120+
ok := cache.Set(StringKey("metric-key-1"), "value1", 10)
121+
require.True(t, ok)
122+
ok = cache.Set(StringKey("metric-key-2"), "value2", 20)
123+
require.True(t, ok)
124+
125+
cache.Wait()
126+
127+
// Perform some gets (hits and misses)
128+
_, ok = cache.Get(StringKey("metric-key-1")) // hit
129+
require.True(t, ok)
130+
_, ok = cache.Get(StringKey("metric-key-2")) // hit
131+
require.True(t, ok)
132+
_, ok = cache.Get(StringKey("non-existent")) // miss
133+
require.False(t, ok)
134+
135+
metrics := cache.GetMetrics()
136+
require.NotNil(t, metrics, "expected metrics to be available")
137+
138+
// Verify hits and misses are tracked
139+
hits := metrics.Hits()
140+
misses := metrics.Misses()
141+
require.GreaterOrEqual(t, hits, uint64(1), "expected at least one hit")
142+
require.GreaterOrEqual(t, misses, uint64(1), "expected at least one miss")
143+
144+
// Verify cost tracking
145+
costAdded := metrics.CostAdded()
146+
require.GreaterOrEqual(t, costAdded, uint64(10), "expected cost to be tracked")
147+
})
148+
}

0 commit comments

Comments
 (0)