Skip to content

Commit bfd368b

Browse files
authored
Fix CachingFactory concurrent usage issues (#1707)
Signed-off-by: andriikushch <[email protected]>
1 parent fdde350 commit bfd368b

File tree

6 files changed

+174
-62
lines changed

6 files changed

+174
-62
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ require (
2727
github.com/r3labs/diff/v3 v3.0.1
2828
github.com/stretchr/testify v1.10.0
2929
github.com/urfave/cli/v2 v2.27.7
30+
go.uber.org/atomic v1.11.0
3031
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948
3132
golang.org/x/sync v0.15.0
3233
gopkg.in/yaml.v2 v2.4.0

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8
111111
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
112112
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
113113
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
114+
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
115+
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
114116
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA=
115117
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
116118
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=

pkg/clients/v1/factory.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import (
4646
"github.com/aws/aws-sdk-go/service/storagegateway/storagegatewayiface"
4747
"github.com/aws/aws-sdk-go/service/sts"
4848
"github.com/aws/aws-sdk-go/service/sts/stsiface"
49+
"go.uber.org/atomic"
4950

5051
"github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients"
5152
"github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/account"
@@ -64,8 +65,8 @@ type CachingFactory struct {
6465
stscache map[model.Role]stsiface.STSAPI
6566
iamcache map[model.Role]iamiface.IAMAPI
6667
clients map[model.Role]map[string]*cachedClients
67-
cleared bool
68-
refreshed bool
68+
cleared *atomic.Bool
69+
refreshed *atomic.Bool
6970
mu sync.Mutex
7071
fips bool
7172
logger *slog.Logger
@@ -175,16 +176,21 @@ func NewFactory(logger *slog.Logger, jobsCfg model.JobsConfig, fips bool) *Cachi
175176
iamcache: iamcache,
176177
clients: cache,
177178
fips: fips,
178-
cleared: false,
179-
refreshed: false,
179+
cleared: atomic.NewBool(false),
180+
refreshed: atomic.NewBool(false),
180181
logger: logger,
181182
}
182183
}
183184

184-
// Refresh and Clear help to avoid using lock primitives by asserting that
185-
// there are no ongoing writes to the map.
186185
func (c *CachingFactory) Clear() {
187-
if c.cleared {
186+
if c.cleared.Load() {
187+
return
188+
}
189+
190+
c.mu.Lock()
191+
defer c.mu.Unlock()
192+
193+
if c.cleared.Load() {
188194
return
189195
}
190196

@@ -204,19 +210,19 @@ func (c *CachingFactory) Clear() {
204210
cachedClient.tagging = nil
205211
}
206212
}
207-
c.cleared = true
208-
c.refreshed = false
213+
c.cleared.Store(true)
214+
c.refreshed.Store(false)
209215
}
210216

211217
func (c *CachingFactory) Refresh() {
212-
if c.refreshed {
218+
if c.refreshed.Load() {
213219
return
214220
}
215221

216222
c.mu.Lock()
217223
defer c.mu.Unlock()
218224
// Double check Refresh wasn't called concurrently
219-
if c.refreshed {
225+
if c.refreshed.Load() {
220226
return
221227
}
222228

@@ -248,8 +254,8 @@ func (c *CachingFactory) Refresh() {
248254
}
249255
}
250256

251-
c.cleared = false
252-
c.refreshed = true
257+
c.cleared.Store(false)
258+
c.refreshed.Store(true)
253259
}
254260

255261
func createCloudWatchClient(logger *slog.Logger, s *session.Session, region *string, role model.Role, fips bool) cloudwatch_client.Client {
@@ -282,7 +288,7 @@ func createAccountClient(logger *slog.Logger, sts stsiface.STSAPI, iam iamiface.
282288
}
283289

284290
func (c *CachingFactory) GetCloudwatchClient(region string, role model.Role, concurrency cloudwatch_client.ConcurrencyConfig) cloudwatch_client.Client {
285-
if !c.refreshed {
291+
if !c.refreshed.Load() {
286292
// if we have not refreshed then we need to lock in case we are accessing concurrently
287293
c.mu.Lock()
288294
defer c.mu.Unlock()
@@ -295,7 +301,7 @@ func (c *CachingFactory) GetCloudwatchClient(region string, role model.Role, con
295301
}
296302

297303
func (c *CachingFactory) GetTaggingClient(region string, role model.Role, concurrencyLimit int) tagging.Client {
298-
if !c.refreshed {
304+
if !c.refreshed.Load() {
299305
// if we have not refreshed then we need to lock in case we are accessing concurrently
300306
c.mu.Lock()
301307
defer c.mu.Unlock()
@@ -308,7 +314,7 @@ func (c *CachingFactory) GetTaggingClient(region string, role model.Role, concur
308314
}
309315

310316
func (c *CachingFactory) GetAccountClient(region string, role model.Role) account.Client {
311-
if !c.refreshed {
317+
if !c.refreshed.Load() {
312318
// if we have not refreshed then we need to lock in case we are accessing concurrently
313319
c.mu.Lock()
314320
defer c.mu.Unlock()

pkg/clients/v1/factory_test.go

Lines changed: 83 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"fmt"
1717
"sync"
1818
"testing"
19+
"time"
1920

2021
"github.com/aws/aws-sdk-go/aws"
2122
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
@@ -25,6 +26,7 @@ import (
2526
"github.com/aws/aws-sdk-go/service/sts/stsiface"
2627
"github.com/prometheus/common/promslog"
2728
"github.com/stretchr/testify/require"
29+
"go.uber.org/atomic"
2830

2931
"github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch"
3032
"github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model"
@@ -82,15 +84,21 @@ func TestNewClientCache(t *testing.T) {
8284
"an empty config gives an empty cache",
8385
model.JobsConfig{},
8486
false,
85-
&CachingFactory{logger: promslog.NewNopLogger()},
87+
&CachingFactory{
88+
logger: promslog.NewNopLogger(),
89+
refreshed: atomic.NewBool(false),
90+
cleared: atomic.NewBool(false),
91+
},
8692
},
8793
{
8894
"if fips is set then the clients has fips",
8995
model.JobsConfig{},
9096
true,
9197
&CachingFactory{
92-
fips: true,
93-
logger: promslog.NewNopLogger(),
98+
fips: true,
99+
logger: promslog.NewNopLogger(),
100+
refreshed: atomic.NewBool(false),
101+
cleared: atomic.NewBool(false),
94102
},
95103
},
96104
{
@@ -153,7 +161,9 @@ func TestNewClientCache(t *testing.T) {
153161
"ap-northeast-3": &cachedClients{},
154162
},
155163
},
156-
logger: promslog.NewNopLogger(),
164+
logger: promslog.NewNopLogger(),
165+
refreshed: atomic.NewBool(false),
166+
cleared: atomic.NewBool(false),
157167
},
158168
},
159169
{
@@ -239,7 +249,9 @@ func TestNewClientCache(t *testing.T) {
239249
"ap-northeast-1": &cachedClients{onlyStatic: true},
240250
},
241251
},
242-
logger: promslog.NewNopLogger(),
252+
logger: promslog.NewNopLogger(),
253+
refreshed: atomic.NewBool(false),
254+
cleared: atomic.NewBool(false),
243255
},
244256
},
245257
{
@@ -362,7 +374,9 @@ func TestNewClientCache(t *testing.T) {
362374
"ap-northeast-3": &cachedClients{},
363375
},
364376
},
365-
logger: promslog.NewNopLogger(),
377+
logger: promslog.NewNopLogger(),
378+
refreshed: atomic.NewBool(false),
379+
cleared: atomic.NewBool(false),
366380
},
367381
},
368382
{
@@ -451,7 +465,9 @@ func TestNewClientCache(t *testing.T) {
451465
"ap-northeast-1": &cachedClients{onlyStatic: true},
452466
},
453467
},
454-
logger: promslog.NewNopLogger(),
468+
logger: promslog.NewNopLogger(),
469+
refreshed: atomic.NewBool(false),
470+
cleared: atomic.NewBool(false),
455471
},
456472
},
457473
}
@@ -463,12 +479,12 @@ func TestNewClientCache(t *testing.T) {
463479
cache := NewFactory(promslog.NewNopLogger(), test.jobsCfg, test.fips)
464480
t.Logf("the cache is: %v", cache)
465481

466-
if test.cache.cleared != cache.cleared {
482+
if test.cache.cleared.Load() != cache.cleared.Load() {
467483
t.Logf("`cleared` not equal got %v, expected %v", cache.cleared, test.cache.cleared)
468484
t.Fail()
469485
}
470486

471-
if test.cache.refreshed != cache.refreshed {
487+
if test.cache.refreshed.Load() != cache.refreshed.Load() {
472488
t.Logf("`refreshed` not equal got %v, expected %v", cache.refreshed, test.cache.refreshed)
473489
t.Fail()
474490
}
@@ -497,7 +513,7 @@ func TestClear(t *testing.T) {
497513
"a new clear clears all clients",
498514
&CachingFactory{
499515
session: mock.Session,
500-
cleared: false,
516+
cleared: atomic.NewBool(false),
501517
mu: sync.Mutex{},
502518
stscache: map[model.Role]stsiface.STSAPI{
503519
{}: nil,
@@ -512,13 +528,14 @@ func TestClear(t *testing.T) {
512528
},
513529
},
514530
},
515-
logger: promslog.NewNopLogger(),
531+
logger: promslog.NewNopLogger(),
532+
refreshed: atomic.NewBool(false),
516533
},
517534
},
518535
{
519536
"A second call to clear does nothing",
520537
&CachingFactory{
521-
cleared: true,
538+
cleared: atomic.NewBool(true),
522539
mu: sync.Mutex{},
523540
session: mock.Session,
524541
stscache: map[model.Role]stsiface.STSAPI{
@@ -533,7 +550,8 @@ func TestClear(t *testing.T) {
533550
},
534551
},
535552
},
536-
logger: promslog.NewNopLogger(),
553+
logger: promslog.NewNopLogger(),
554+
refreshed: atomic.NewBool(false),
537555
},
538556
},
539557
}
@@ -542,11 +560,11 @@ func TestClear(t *testing.T) {
542560
test := l
543561
t.Run(test.description, func(t *testing.T) {
544562
test.cache.Clear()
545-
if !test.cache.cleared {
563+
if !test.cache.cleared.Load() {
546564
t.Log("Cache cleared flag not set")
547565
t.Fail()
548566
}
549-
if test.cache.refreshed {
567+
if test.cache.refreshed.Load() {
550568
t.Log("Cache cleared flag set")
551569
t.Fail()
552570
}
@@ -591,7 +609,8 @@ func TestRefresh(t *testing.T) {
591609
"a new refresh creates clients",
592610
&CachingFactory{
593611
session: mock.Session,
594-
refreshed: false,
612+
refreshed: atomic.NewBool(false),
613+
cleared: atomic.NewBool(false),
595614
mu: sync.Mutex{},
596615
stscache: map[model.Role]stsiface.STSAPI{
597616
{}: nil,
@@ -613,7 +632,8 @@ func TestRefresh(t *testing.T) {
613632
"a new refresh with static only creates only cloudwatch",
614633
&CachingFactory{
615634
session: mock.Session,
616-
refreshed: false,
635+
refreshed: atomic.NewBool(false),
636+
cleared: atomic.NewBool(false),
617637
mu: sync.Mutex{},
618638
stscache: map[model.Role]stsiface.STSAPI{
619639
{}: nil,
@@ -635,7 +655,8 @@ func TestRefresh(t *testing.T) {
635655
{
636656
"A second call to refreshed does nothing",
637657
&CachingFactory{
638-
refreshed: true,
658+
refreshed: atomic.NewBool(true),
659+
cleared: atomic.NewBool(false),
639660
mu: sync.Mutex{},
640661
session: mock.Session,
641662
stscache: map[model.Role]stsiface.STSAPI{
@@ -662,12 +683,12 @@ func TestRefresh(t *testing.T) {
662683
t.Parallel()
663684
test.cache.Refresh()
664685

665-
if !test.cache.refreshed {
686+
if !test.cache.refreshed.Load() {
666687
t.Log("Cache refreshed flag not set")
667688
t.Fail()
668689
}
669690

670-
if test.cache.cleared {
691+
if test.cache.cleared.Load() {
671692
t.Log("Cache cleared flag set")
672693
t.Fail()
673694
}
@@ -753,7 +774,7 @@ func testGetAWSClient(
753774
{
754775
"locks during unrefreshed parallel call",
755776
&CachingFactory{
756-
refreshed: false,
777+
refreshed: atomic.NewBool(false),
757778
mu: sync.Mutex{},
758779
session: mock.Session,
759780
stscache: map[model.Role]stsiface.STSAPI{
@@ -775,7 +796,7 @@ func testGetAWSClient(
775796
{
776797
"returns clients if available",
777798
&CachingFactory{
778-
refreshed: true,
799+
refreshed: atomic.NewBool(true),
779800
session: mock.Session,
780801
mu: sync.Mutex{},
781802
stscache: map[model.Role]stsiface.STSAPI{
@@ -797,7 +818,7 @@ func testGetAWSClient(
797818
{
798819
"creates a new clients if not available",
799820
&CachingFactory{
800-
refreshed: true,
821+
refreshed: atomic.NewBool(true),
801822
session: mock.Session,
802823
mu: sync.Mutex{},
803824
stscache: map[model.Role]stsiface.STSAPI{
@@ -1150,6 +1171,45 @@ func TestSTSResolvesFIPSEnabledEndpoints(t *testing.T) {
11501171
}
11511172
}
11521173

1174+
func TestRaceConditionRefreshClear(t *testing.T) {
1175+
t.Parallel()
1176+
1177+
// Create a factory with the test config
1178+
factory := NewFactory(promslog.NewNopLogger(), model.JobsConfig{}, false)
1179+
1180+
// Number of concurrent operations to perform
1181+
iterations := 100
1182+
1183+
// Use WaitGroup to synchronize goroutines
1184+
var wg sync.WaitGroup
1185+
wg.Add(iterations) // For both Refresh and Clear calls
1186+
1187+
// Start function to run concurrent operations
1188+
for i := 0; i < iterations; i++ {
1189+
// Launch goroutine to call Refresh
1190+
go func() {
1191+
defer wg.Done()
1192+
factory.Refresh()
1193+
factory.Clear()
1194+
}()
1195+
}
1196+
1197+
// Create a channel to signal completion
1198+
done := make(chan struct{})
1199+
go func() {
1200+
wg.Wait()
1201+
close(done)
1202+
}()
1203+
1204+
// Wait for either completion or timeout
1205+
select {
1206+
case <-done:
1207+
// Test completed successfully
1208+
case <-time.After(60 * time.Second):
1209+
require.Fail(t, "Test timed out after 60 seconds")
1210+
}
1211+
}
1212+
11531213
func testAWSClient(
11541214
t *testing.T,
11551215
name string,

0 commit comments

Comments
 (0)