Skip to content

Commit d4951f0

Browse files
committed
flightcontrol: add cached group support
Signed-off-by: Tonis Tiigi <[email protected]>
1 parent bd5d50e commit d4951f0

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

util/flightcontrol/cached.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package flightcontrol
2+
3+
import (
4+
"context"
5+
"sync"
6+
7+
"github.com/pkg/errors"
8+
)
9+
10+
// Group is a flightcontrol synchronization group that memoizes the results of a function
11+
// and returns the cached result if the function is called with the same key.
12+
// Don't use with long-running groups as the results are cached indefinitely.
13+
type CachedGroup[T any] struct {
14+
// CacheError defines if error results should also be cached.
15+
// It is not safe to change this value after the first call to Do.
16+
// Context cancellation errors are never cached.
17+
CacheError bool
18+
g Group[T]
19+
mu sync.Mutex
20+
cache map[string]result[T]
21+
}
22+
23+
type result[T any] struct {
24+
v T
25+
err error
26+
}
27+
28+
// Do executes a context function syncronized by the key or returns the cached result for the key.
29+
func (g *CachedGroup[T]) Do(ctx context.Context, key string, fn func(ctx context.Context) (T, error)) (T, error) {
30+
return g.g.Do(ctx, key, func(ctx context.Context) (T, error) {
31+
g.mu.Lock()
32+
if v, ok := g.cache[key]; ok {
33+
g.mu.Unlock()
34+
if v.err != nil {
35+
if g.CacheError {
36+
return v.v, v.err
37+
}
38+
} else {
39+
return v.v, nil
40+
}
41+
}
42+
g.mu.Unlock()
43+
v, err := fn(ctx)
44+
if err != nil {
45+
select {
46+
case <-ctx.Done():
47+
if errors.Is(err, context.Cause(ctx)) {
48+
return v, err
49+
}
50+
default:
51+
}
52+
}
53+
if err == nil || g.CacheError {
54+
g.mu.Lock()
55+
if g.cache == nil {
56+
g.cache = make(map[string]result[T])
57+
}
58+
g.cache[key] = result[T]{v: v, err: err}
59+
g.mu.Unlock()
60+
}
61+
return v, err
62+
})
63+
}

util/flightcontrol/cached_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package flightcontrol
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/pkg/errors"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestCached(t *testing.T) {
13+
var g CachedGroup[int]
14+
15+
ctx := context.TODO()
16+
17+
v, err := g.Do(ctx, "11", func(ctx context.Context) (int, error) {
18+
return 1, nil
19+
})
20+
require.NoError(t, err)
21+
require.Equal(t, 1, v)
22+
23+
v, err = g.Do(ctx, "22", func(ctx context.Context) (int, error) {
24+
return 2, nil
25+
})
26+
require.NoError(t, err)
27+
require.Equal(t, 2, v)
28+
29+
didCall := false
30+
v, err = g.Do(ctx, "11", func(ctx context.Context) (int, error) {
31+
didCall = true
32+
return 3, nil
33+
})
34+
require.NoError(t, err)
35+
require.Equal(t, 1, v)
36+
require.Equal(t, false, didCall)
37+
38+
// by default, errors are not cached
39+
_, err = g.Do(ctx, "33", func(ctx context.Context) (int, error) {
40+
return 0, errors.Errorf("some error")
41+
})
42+
43+
require.Error(t, err)
44+
require.ErrorContains(t, err, "some error")
45+
46+
v, err = g.Do(ctx, "33", func(ctx context.Context) (int, error) {
47+
return 3, nil
48+
})
49+
50+
require.NoError(t, err)
51+
require.Equal(t, 3, v)
52+
}
53+
54+
func TestCachedError(t *testing.T) {
55+
var g CachedGroup[string]
56+
g.CacheError = true
57+
58+
ctx := context.TODO()
59+
60+
_, err := g.Do(ctx, "11", func(ctx context.Context) (string, error) {
61+
return "", errors.Errorf("first error")
62+
})
63+
require.Error(t, err)
64+
require.ErrorContains(t, err, "first error")
65+
66+
_, err = g.Do(ctx, "11", func(ctx context.Context) (string, error) {
67+
return "never-ran", nil
68+
})
69+
require.Error(t, err)
70+
require.ErrorContains(t, err, "first error")
71+
72+
// context errors are never cached
73+
ctx, cancel := context.WithTimeoutCause(context.TODO(), 10*time.Millisecond, nil)
74+
defer cancel()
75+
_, err = g.Do(ctx, "22", func(ctx context.Context) (string, error) {
76+
select {
77+
case <-ctx.Done():
78+
return "", context.Cause(ctx)
79+
case <-time.After(10 * time.Second):
80+
return "", errors.Errorf("unexpected error")
81+
}
82+
})
83+
require.Error(t, err)
84+
require.ErrorContains(t, err, "context deadline exceeded")
85+
86+
select {
87+
case <-ctx.Done():
88+
default:
89+
require.Fail(t, "expected context to be done")
90+
}
91+
92+
v, err := g.Do(ctx, "22", func(ctx context.Context) (string, error) {
93+
return "did-run", nil
94+
})
95+
require.NoError(t, err)
96+
require.Equal(t, "did-run", v)
97+
}

0 commit comments

Comments
 (0)