Skip to content

Commit fbb7048

Browse files
Rework flight cache into a strongly typed version (#1989)
This reworks the flight cache we already have into a strongly typed version based on a map and `sync.OnceValues`. That allows us to get rid of type assertions and unnecessary error handling in its implementation. Also add some tests for the behavior of flight cache.
1 parent ae2a6ba commit fbb7048

File tree

2 files changed

+136
-38
lines changed

2 files changed

+136
-38
lines changed

pkg/apk/apk/cache.go

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -32,66 +32,62 @@ import (
3232
"chainguard.dev/apko/pkg/paths"
3333
)
3434

35-
type flightCache[T any] struct {
36-
flight *singleflight.Group
37-
cache *sync.Map
35+
type flightCache[K comparable, V any] struct {
36+
mux sync.RWMutex
37+
cache map[K]func() (V, error)
3838
}
3939

40-
// TODO: Consider [K, V] if we need a non-string key type.
41-
func newFlightCache[T any]() *flightCache[T] {
42-
return &flightCache[T]{
43-
flight: &singleflight.Group{},
44-
cache: &sync.Map{},
40+
func newFlightCache[K comparable, V any]() *flightCache[K, V] {
41+
return &flightCache[K, V]{
42+
cache: make(map[K]func() (V, error)),
4543
}
4644
}
4745

4846
// Do returns coalesces multiple calls, like singleflight, but also caches
4947
// the result if the call is successful. Failures are not cached to avoid
5048
// permanently failing for transient errors.
51-
func (f *flightCache[T]) Do(key string, fn func() (T, error)) (T, error) {
52-
v, ok := f.cache.Load(key)
53-
if ok {
54-
if t, ok := v.(T); ok {
55-
return t, nil
56-
} else {
57-
// This can't happen but just in case things change.
58-
return t, fmt.Errorf("unexpected type %T", v)
59-
}
49+
func (f *flightCache[K, V]) Do(key K, fn func() (V, error)) (V, error) {
50+
f.mux.RLock()
51+
if v, ok := f.cache[key]; ok {
52+
f.mux.RUnlock()
53+
return v()
6054
}
55+
f.mux.RUnlock()
6156

62-
v, err, _ := f.flight.Do(key, func() (any, error) {
63-
if v, ok := f.cache.Load(key); ok {
64-
return v, nil
65-
}
57+
f.mux.Lock()
6658

67-
// Don't cache errors, but maybe we should.
68-
v, err := fn()
69-
if err != nil {
70-
return nil, err
71-
}
59+
// Doubly-checked-locking in case of race conditions.
60+
if v, ok := f.cache[key]; ok {
61+
f.mux.Unlock()
62+
return v()
63+
}
7264

73-
f.cache.Store(key, v)
65+
v := sync.OnceValues(fn)
66+
f.cache[key] = v
7467

75-
return v, nil
76-
})
68+
// Unlock before calling the function to avoid holding the lock for a potentially long time.
69+
f.mux.Unlock()
7770

78-
t, ok := v.(T)
71+
val, err := v()
7972
if err != nil {
80-
return t, err
81-
}
82-
if !ok {
83-
// This can't happen but just in case things change.
84-
return t, fmt.Errorf("unexpected type %T", v)
73+
f.Forget(key)
8574
}
86-
return t, nil
75+
return val, err
76+
}
77+
78+
// Forget removes the given key from the cache.
79+
func (f *flightCache[K, V]) Forget(key K) {
80+
f.mux.Lock()
81+
defer f.mux.Unlock()
82+
delete(f.cache, key)
8783
}
8884

8985
type Cache struct {
9086
etagCache *sync.Map
9187
headFlight *singleflight.Group
9288
getFlight *singleflight.Group
9389

94-
discoverKeys *flightCache[[]Key]
90+
discoverKeys *flightCache[string, []Key]
9591
}
9692

9793
// NewCache returns a new Cache, which allows us to persist the results of HEAD requests
@@ -109,7 +105,7 @@ func NewCache(etag bool) *Cache {
109105
c := &Cache{
110106
headFlight: &singleflight.Group{},
111107
getFlight: &singleflight.Group{},
112-
discoverKeys: newFlightCache[[]Key](),
108+
discoverKeys: newFlightCache[string, []Key](),
113109
}
114110

115111
if etag {

pkg/apk/apk/cache_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright 2025 Chainguard, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package apk
16+
17+
import (
18+
"sync"
19+
"sync/atomic"
20+
"testing"
21+
22+
"github.com/stretchr/testify/assert"
23+
"github.com/stretchr/testify/require"
24+
"golang.org/x/sync/errgroup"
25+
)
26+
27+
func TestFlightCache(t *testing.T) {
28+
s := newFlightCache[string, int]()
29+
var called int
30+
r1, err := s.Do("test", func() (int, error) {
31+
called++
32+
return 42, nil
33+
})
34+
require.NoError(t, err)
35+
require.Equal(t, 42, r1)
36+
37+
r2, err := s.Do("test", func() (int, error) {
38+
called++
39+
return 1337, nil
40+
})
41+
require.NoError(t, err)
42+
require.Equal(t, r1, r2)
43+
require.Equal(t, 1, called, "Function should only be called once")
44+
45+
s.Forget("test")
46+
47+
r3, err := s.Do("test", func() (int, error) {
48+
called++
49+
return 1337, nil
50+
})
51+
require.NoError(t, err)
52+
require.Equal(t, 1337, r3)
53+
require.Equal(t, 2, called, "Function should be called twice, once before and once after Forget")
54+
55+
differentKey, err := s.Do("test2", func() (int, error) {
56+
return 7, nil
57+
})
58+
require.NoError(t, err)
59+
require.Equal(t, 7, differentKey)
60+
}
61+
62+
func TestFlightCacheCachesNoErrors(t *testing.T) {
63+
s := newFlightCache[string, int]()
64+
var called int
65+
_, err := s.Do("test", func() (int, error) {
66+
called++
67+
return 42, assert.AnError
68+
})
69+
require.ErrorIs(t, assert.AnError, err)
70+
71+
r2, err := s.Do("test", func() (int, error) {
72+
called++
73+
return 1337, nil
74+
})
75+
require.NoError(t, err)
76+
require.Equal(t, 1337, r2)
77+
require.Equal(t, 2, called, "Function should be called twice, once for the error and once for the success")
78+
}
79+
80+
func TestFlightCacheCoalescesCalls(t *testing.T) {
81+
s := newFlightCache[string, int]()
82+
83+
var called atomic.Int32
84+
var mux sync.Mutex
85+
mux.Lock() // Lock to ensure the call below hangs until we unlock.
86+
87+
var eg errgroup.Group
88+
for range 10 {
89+
eg.Go(func() error {
90+
_, err := s.Do("test", func() (int, error) {
91+
mux.Lock() // Hangs until the unlock below.
92+
called.Add(1)
93+
return 42, nil
94+
})
95+
return err
96+
})
97+
}
98+
mux.Unlock() // Allow the calls to proceed.
99+
require.NoError(t, eg.Wait())
100+
101+
require.EqualValues(t, 1, called.Load(), "Function should only be called once")
102+
}

0 commit comments

Comments
 (0)