@@ -16,6 +16,7 @@ import (
1616 "fmt"
1717 "sync"
1818 "testing"
19+ "time"
1920
2021 "github.com/aws/aws-sdk-go/aws"
2122 "github.com/aws/aws-sdk-go/aws/credentials/stscreds"
@@ -25,6 +26,7 @@ import (
2526 "github.com/aws/aws-sdk-go/service/sts/stsiface"
2627 "github.com/prometheus/common/promslog"
2728 "github.com/stretchr/testify/require"
29+ "go.uber.org/atomic"
2830
2931 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/clients/cloudwatch"
3032 "github.com/prometheus-community/yet-another-cloudwatch-exporter/pkg/model"
@@ -82,15 +84,21 @@ func TestNewClientCache(t *testing.T) {
8284 "an empty config gives an empty cache" ,
8385 model.JobsConfig {},
8486 false ,
85- & CachingFactory {logger : promslog .NewNopLogger ()},
87+ & CachingFactory {
88+ logger : promslog .NewNopLogger (),
89+ refreshed : atomic .NewBool (false ),
90+ cleared : atomic .NewBool (false ),
91+ },
8692 },
8793 {
8894 "if fips is set then the clients has fips" ,
8995 model.JobsConfig {},
9096 true ,
9197 & CachingFactory {
92- fips : true ,
93- logger : promslog .NewNopLogger (),
98+ fips : true ,
99+ logger : promslog .NewNopLogger (),
100+ refreshed : atomic .NewBool (false ),
101+ cleared : atomic .NewBool (false ),
94102 },
95103 },
96104 {
@@ -153,7 +161,9 @@ func TestNewClientCache(t *testing.T) {
153161 "ap-northeast-3" : & cachedClients {},
154162 },
155163 },
156- logger : promslog .NewNopLogger (),
164+ logger : promslog .NewNopLogger (),
165+ refreshed : atomic .NewBool (false ),
166+ cleared : atomic .NewBool (false ),
157167 },
158168 },
159169 {
@@ -239,7 +249,9 @@ func TestNewClientCache(t *testing.T) {
239249 "ap-northeast-1" : & cachedClients {onlyStatic : true },
240250 },
241251 },
242- logger : promslog .NewNopLogger (),
252+ logger : promslog .NewNopLogger (),
253+ refreshed : atomic .NewBool (false ),
254+ cleared : atomic .NewBool (false ),
243255 },
244256 },
245257 {
@@ -362,7 +374,9 @@ func TestNewClientCache(t *testing.T) {
362374 "ap-northeast-3" : & cachedClients {},
363375 },
364376 },
365- logger : promslog .NewNopLogger (),
377+ logger : promslog .NewNopLogger (),
378+ refreshed : atomic .NewBool (false ),
379+ cleared : atomic .NewBool (false ),
366380 },
367381 },
368382 {
@@ -451,7 +465,9 @@ func TestNewClientCache(t *testing.T) {
451465 "ap-northeast-1" : & cachedClients {onlyStatic : true },
452466 },
453467 },
454- logger : promslog .NewNopLogger (),
468+ logger : promslog .NewNopLogger (),
469+ refreshed : atomic .NewBool (false ),
470+ cleared : atomic .NewBool (false ),
455471 },
456472 },
457473 }
@@ -463,12 +479,12 @@ func TestNewClientCache(t *testing.T) {
463479 cache := NewFactory (promslog .NewNopLogger (), test .jobsCfg , test .fips )
464480 t .Logf ("the cache is: %v" , cache )
465481
466- if test .cache .cleared != cache .cleared {
482+ if test .cache .cleared . Load () != cache .cleared . Load () {
467483 t .Logf ("`cleared` not equal got %v, expected %v" , cache .cleared , test .cache .cleared )
468484 t .Fail ()
469485 }
470486
471- if test .cache .refreshed != cache .refreshed {
487+ if test .cache .refreshed . Load () != cache .refreshed . Load () {
472488 t .Logf ("`refreshed` not equal got %v, expected %v" , cache .refreshed , test .cache .refreshed )
473489 t .Fail ()
474490 }
@@ -497,7 +513,7 @@ func TestClear(t *testing.T) {
497513 "a new clear clears all clients" ,
498514 & CachingFactory {
499515 session : mock .Session ,
500- cleared : false ,
516+ cleared : atomic . NewBool ( false ) ,
501517 mu : sync.Mutex {},
502518 stscache : map [model.Role ]stsiface.STSAPI {
503519 {}: nil ,
@@ -512,13 +528,14 @@ func TestClear(t *testing.T) {
512528 },
513529 },
514530 },
515- logger : promslog .NewNopLogger (),
531+ logger : promslog .NewNopLogger (),
532+ refreshed : atomic .NewBool (false ),
516533 },
517534 },
518535 {
519536 "A second call to clear does nothing" ,
520537 & CachingFactory {
521- cleared : true ,
538+ cleared : atomic . NewBool ( true ) ,
522539 mu : sync.Mutex {},
523540 session : mock .Session ,
524541 stscache : map [model.Role ]stsiface.STSAPI {
@@ -533,7 +550,8 @@ func TestClear(t *testing.T) {
533550 },
534551 },
535552 },
536- logger : promslog .NewNopLogger (),
553+ logger : promslog .NewNopLogger (),
554+ refreshed : atomic .NewBool (false ),
537555 },
538556 },
539557 }
@@ -542,11 +560,11 @@ func TestClear(t *testing.T) {
542560 test := l
543561 t .Run (test .description , func (t * testing.T ) {
544562 test .cache .Clear ()
545- if ! test .cache .cleared {
563+ if ! test .cache .cleared . Load () {
546564 t .Log ("Cache cleared flag not set" )
547565 t .Fail ()
548566 }
549- if test .cache .refreshed {
567+ if test .cache .refreshed . Load () {
550568 t .Log ("Cache cleared flag set" )
551569 t .Fail ()
552570 }
@@ -591,7 +609,8 @@ func TestRefresh(t *testing.T) {
591609 "a new refresh creates clients" ,
592610 & CachingFactory {
593611 session : mock .Session ,
594- refreshed : false ,
612+ refreshed : atomic .NewBool (false ),
613+ cleared : atomic .NewBool (false ),
595614 mu : sync.Mutex {},
596615 stscache : map [model.Role ]stsiface.STSAPI {
597616 {}: nil ,
@@ -613,7 +632,8 @@ func TestRefresh(t *testing.T) {
613632 "a new refresh with static only creates only cloudwatch" ,
614633 & CachingFactory {
615634 session : mock .Session ,
616- refreshed : false ,
635+ refreshed : atomic .NewBool (false ),
636+ cleared : atomic .NewBool (false ),
617637 mu : sync.Mutex {},
618638 stscache : map [model.Role ]stsiface.STSAPI {
619639 {}: nil ,
@@ -635,7 +655,8 @@ func TestRefresh(t *testing.T) {
635655 {
636656 "A second call to refreshed does nothing" ,
637657 & CachingFactory {
638- refreshed : true ,
658+ refreshed : atomic .NewBool (true ),
659+ cleared : atomic .NewBool (false ),
639660 mu : sync.Mutex {},
640661 session : mock .Session ,
641662 stscache : map [model.Role ]stsiface.STSAPI {
@@ -662,12 +683,12 @@ func TestRefresh(t *testing.T) {
662683 t .Parallel ()
663684 test .cache .Refresh ()
664685
665- if ! test .cache .refreshed {
686+ if ! test .cache .refreshed . Load () {
666687 t .Log ("Cache refreshed flag not set" )
667688 t .Fail ()
668689 }
669690
670- if test .cache .cleared {
691+ if test .cache .cleared . Load () {
671692 t .Log ("Cache cleared flag set" )
672693 t .Fail ()
673694 }
@@ -753,7 +774,7 @@ func testGetAWSClient(
753774 {
754775 "locks during unrefreshed parallel call" ,
755776 & CachingFactory {
756- refreshed : false ,
777+ refreshed : atomic . NewBool ( false ) ,
757778 mu : sync.Mutex {},
758779 session : mock .Session ,
759780 stscache : map [model.Role ]stsiface.STSAPI {
@@ -775,7 +796,7 @@ func testGetAWSClient(
775796 {
776797 "returns clients if available" ,
777798 & CachingFactory {
778- refreshed : true ,
799+ refreshed : atomic . NewBool ( true ) ,
779800 session : mock .Session ,
780801 mu : sync.Mutex {},
781802 stscache : map [model.Role ]stsiface.STSAPI {
@@ -797,7 +818,7 @@ func testGetAWSClient(
797818 {
798819 "creates a new clients if not available" ,
799820 & CachingFactory {
800- refreshed : true ,
821+ refreshed : atomic . NewBool ( true ) ,
801822 session : mock .Session ,
802823 mu : sync.Mutex {},
803824 stscache : map [model.Role ]stsiface.STSAPI {
@@ -1150,6 +1171,45 @@ func TestSTSResolvesFIPSEnabledEndpoints(t *testing.T) {
11501171 }
11511172}
11521173
1174+ func TestRaceConditionRefreshClear (t * testing.T ) {
1175+ t .Parallel ()
1176+
1177+ // Create a factory with the test config
1178+ factory := NewFactory (promslog .NewNopLogger (), model.JobsConfig {}, false )
1179+
1180+ // Number of concurrent operations to perform
1181+ iterations := 100
1182+
1183+ // Use WaitGroup to synchronize goroutines
1184+ var wg sync.WaitGroup
1185+ wg .Add (iterations ) // For both Refresh and Clear calls
1186+
1187+ // Start function to run concurrent operations
1188+ for i := 0 ; i < iterations ; i ++ {
1189+ // Launch goroutine to call Refresh
1190+ go func () {
1191+ defer wg .Done ()
1192+ factory .Refresh ()
1193+ factory .Clear ()
1194+ }()
1195+ }
1196+
1197+ // Create a channel to signal completion
1198+ done := make (chan struct {})
1199+ go func () {
1200+ wg .Wait ()
1201+ close (done )
1202+ }()
1203+
1204+ // Wait for either completion or timeout
1205+ select {
1206+ case <- done :
1207+ // Test completed successfully
1208+ case <- time .After (60 * time .Second ):
1209+ require .Fail (t , "Test timed out after 60 seconds" )
1210+ }
1211+ }
1212+
11531213func testAWSClient (
11541214 t * testing.T ,
11551215 name string ,
0 commit comments