Skip to content

Commit f706f2b

Browse files
authored
fix: prevent concurrent registry credential refreshing (#1175)
Signed-off-by: Jongwon Youn <[email protected]>
1 parent 2cd8c7d commit f706f2b

File tree

3 files changed

+135
-1
lines changed

3 files changed

+135
-1
lines changed

registry-scanner/pkg/registry/endpoints.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/argoproj-labs/argocd-image-updater/registry-scanner/pkg/log"
1414

1515
"go.uber.org/ratelimit"
16+
"golang.org/x/sync/singleflight"
1617
)
1718

1819
// TagListSort defines how the registry returns the list of tags
@@ -117,6 +118,9 @@ var defaultRegistry *RegistryEndpoint
117118
// Simple RW mutex for concurrent access to registries map
118119
var registryLock sync.RWMutex
119120

121+
// credentialGroup ensures only one credential refresh happens per registry
122+
var credentialGroup singleflight.Group
123+
120124
func AddRegistryEndpointFromConfig(epc RegistryConfiguration) error {
121125
ep := NewRegistryEndpoint(epc.Prefix, epc.Name, epc.ApiURL, epc.Credentials, epc.DefaultNS, epc.Insecure, TagListSortFromString(epc.TagSortMode), epc.Limit, epc.CredsExpire)
122126
return AddRegistryEndpoint(ep)

registry-scanner/pkg/registry/registry.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,17 @@ func (ep *RegistryEndpoint) expireCredentials() bool {
188188
return false
189189
}
190190

191-
// Sets endpoint credentials for this registry from a reference to a K8s secret
191+
// SetEndpointCredentials Sets endpoint credentials for this registry from a reference to a K8s secret
192192
func (ep *RegistryEndpoint) SetEndpointCredentials(kubeClient *kube.KubernetesClient) error {
193+
// Use singleflight to prevent concurrent credential fetching for the same registry
194+
_, err, _ := credentialGroup.Do(ep.RegistryAPI, func() (interface{}, error) {
195+
return nil, ep.setEndpointCredentialsInternal(kubeClient)
196+
})
197+
return err
198+
}
199+
200+
// setEndpointCredentialsInternal performs the actual credential fetching
201+
func (ep *RegistryEndpoint) setEndpointCredentialsInternal(kubeClient *kube.KubernetesClient) error {
193202
if ep.expireCredentials() {
194203
log.Debugf("expired credentials for registry %s (updated:%s, expiry:%0fs)", ep.RegistryAPI, ep.CredsUpdated, ep.CredsExpire.Seconds())
195204
}

registry-scanner/pkg/registry/registry_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package registry
22

33
import (
44
"os"
5+
"strings"
6+
"sync"
7+
"sync/atomic"
58
"testing"
69
"time"
710

@@ -161,3 +164,121 @@ registries:
161164
})
162165

163166
}
167+
168+
func Test_ConcurrentCredentialFetching(t *testing.T) {
169+
t.Run("Multiple goroutines fetching credentials should only call once", func(t *testing.T) {
170+
// Create a mock script that counts how many times it's called
171+
scriptContent := `#!/bin/sh
172+
echo "counter" >> /tmp/test_ecr_calls.log
173+
echo "AWS:mock-token-12345"
174+
`
175+
scriptPath := "/tmp/test_ecr_auth.sh"
176+
err := os.WriteFile(scriptPath, []byte(scriptContent), 0755)
177+
require.NoError(t, err)
178+
defer os.Remove(scriptPath)
179+
180+
// Clean up any existing log file
181+
os.Remove("/tmp/test_ecr_calls.log")
182+
defer os.Remove("/tmp/test_ecr_calls.log")
183+
184+
epYAML := `
185+
registries:
186+
- name: ECR Registry
187+
api_url: https://123456789.dkr.ecr.us-east-1.amazonaws.com
188+
prefix: 123456789.dkr.ecr.us-east-1.amazonaws.com
189+
credentials: ext:` + scriptPath + `
190+
credsexpire: 1s
191+
`
192+
epl, err := ParseRegistryConfiguration(epYAML)
193+
require.NoError(t, err)
194+
require.Len(t, epl.Items, 1)
195+
196+
// Add registry configuration
197+
err = AddRegistryEndpointFromConfig(epl.Items[0])
198+
require.NoError(t, err)
199+
ep, err := GetRegistryEndpoint("123456789.dkr.ecr.us-east-1.amazonaws.com")
200+
require.NoError(t, err)
201+
202+
// Force credentials to be expired
203+
ep.CredsUpdated = time.Now().Add(-2 * time.Second)
204+
205+
// Launch multiple goroutines to fetch credentials concurrently
206+
var wg sync.WaitGroup
207+
numGoroutines := 10
208+
errors := make([]error, numGoroutines)
209+
210+
for i := 0; i < numGoroutines; i++ {
211+
wg.Add(1)
212+
go func(idx int) {
213+
defer wg.Done()
214+
errors[idx] = ep.SetEndpointCredentials(nil)
215+
}(i)
216+
}
217+
218+
wg.Wait()
219+
220+
// Check that no errors occurred
221+
for i, err := range errors {
222+
assert.NoError(t, err, "goroutine %d returned error", i)
223+
}
224+
225+
// Verify credentials were set
226+
assert.Equal(t, "AWS", ep.Username)
227+
assert.Equal(t, "mock-token-12345", ep.Password)
228+
229+
// Check that the script was called only once
230+
data, err := os.ReadFile("/tmp/test_ecr_calls.log")
231+
if err != nil {
232+
// File might not exist if script wasn't called at all
233+
assert.Equal(t, 0, 0)
234+
} else {
235+
lines := strings.Count(string(data), "counter")
236+
assert.Equal(t, 1, lines, "Expected script to be called exactly once, but was called %d times", lines)
237+
}
238+
})
239+
240+
t.Run("Concurrent calls with unexpired credentials should not refetch", func(t *testing.T) {
241+
var callCount int32
242+
243+
epYAML := `
244+
registries:
245+
- name: Test Registry
246+
api_url: https://test.registry.io
247+
prefix: test.registry.io
248+
credentials: env:TEST_CONCURRENT_CREDS
249+
credsexpire: 10m
250+
`
251+
epl, err := ParseRegistryConfiguration(epYAML)
252+
require.NoError(t, err)
253+
254+
err = AddRegistryEndpointFromConfig(epl.Items[0])
255+
require.NoError(t, err)
256+
ep, err := GetRegistryEndpoint("test.registry.io")
257+
require.NoError(t, err)
258+
259+
// Set environment variable
260+
os.Setenv("TEST_CONCURRENT_CREDS", "user:pass")
261+
262+
// First call to set credentials
263+
err = ep.SetEndpointCredentials(nil)
264+
require.NoError(t, err)
265+
atomic.AddInt32(&callCount, 1)
266+
267+
// Launch concurrent calls - these should not refetch
268+
var wg sync.WaitGroup
269+
for i := 0; i < 10; i++ {
270+
wg.Add(1)
271+
go func() {
272+
defer wg.Done()
273+
err := ep.SetEndpointCredentials(nil)
274+
assert.NoError(t, err)
275+
}()
276+
}
277+
wg.Wait()
278+
279+
// Credentials should still be cached, so total calls should be 1
280+
assert.Equal(t, int32(1), atomic.LoadInt32(&callCount))
281+
assert.Equal(t, "user", ep.Username)
282+
assert.Equal(t, "pass", ep.Password)
283+
})
284+
}

0 commit comments

Comments
 (0)