@@ -2,6 +2,9 @@ package registry
22
33import (
44 "os"
5+ "strings"
6+ "sync"
7+ "sync/atomic"
58 "testing"
69 "time"
710
@@ -161,3 +164,121 @@ registries:
161164 })
162165
163166}
167+
168+ func Test_ConcurrentCredentialFetching (t * testing.T ) {
169+ t .Run ("Multiple goroutines fetching credentials should only call once" , func (t * testing.T ) {
170+ // Create a mock script that counts how many times it's called
171+ scriptContent := `#!/bin/sh
172+ echo "counter" >> /tmp/test_ecr_calls.log
173+ echo "AWS:mock-token-12345"
174+ `
175+ scriptPath := "/tmp/test_ecr_auth.sh"
176+ err := os .WriteFile (scriptPath , []byte (scriptContent ), 0755 )
177+ require .NoError (t , err )
178+ defer os .Remove (scriptPath )
179+
180+ // Clean up any existing log file
181+ os .Remove ("/tmp/test_ecr_calls.log" )
182+ defer os .Remove ("/tmp/test_ecr_calls.log" )
183+
184+ epYAML := `
185+ registries:
186+ - name: ECR Registry
187+ api_url: https://123456789.dkr.ecr.us-east-1.amazonaws.com
188+ prefix: 123456789.dkr.ecr.us-east-1.amazonaws.com
189+ credentials: ext:` + scriptPath + `
190+ credsexpire: 1s
191+ `
192+ epl , err := ParseRegistryConfiguration (epYAML )
193+ require .NoError (t , err )
194+ require .Len (t , epl .Items , 1 )
195+
196+ // Add registry configuration
197+ err = AddRegistryEndpointFromConfig (epl .Items [0 ])
198+ require .NoError (t , err )
199+ ep , err := GetRegistryEndpoint ("123456789.dkr.ecr.us-east-1.amazonaws.com" )
200+ require .NoError (t , err )
201+
202+ // Force credentials to be expired
203+ ep .CredsUpdated = time .Now ().Add (- 2 * time .Second )
204+
205+ // Launch multiple goroutines to fetch credentials concurrently
206+ var wg sync.WaitGroup
207+ numGoroutines := 10
208+ errors := make ([]error , numGoroutines )
209+
210+ for i := 0 ; i < numGoroutines ; i ++ {
211+ wg .Add (1 )
212+ go func (idx int ) {
213+ defer wg .Done ()
214+ errors [idx ] = ep .SetEndpointCredentials (nil )
215+ }(i )
216+ }
217+
218+ wg .Wait ()
219+
220+ // Check that no errors occurred
221+ for i , err := range errors {
222+ assert .NoError (t , err , "goroutine %d returned error" , i )
223+ }
224+
225+ // Verify credentials were set
226+ assert .Equal (t , "AWS" , ep .Username )
227+ assert .Equal (t , "mock-token-12345" , ep .Password )
228+
229+ // Check that the script was called only once
230+ data , err := os .ReadFile ("/tmp/test_ecr_calls.log" )
231+ if err != nil {
232+ // File might not exist if script wasn't called at all
233+ assert .Equal (t , 0 , 0 )
234+ } else {
235+ lines := strings .Count (string (data ), "counter" )
236+ assert .Equal (t , 1 , lines , "Expected script to be called exactly once, but was called %d times" , lines )
237+ }
238+ })
239+
240+ t .Run ("Concurrent calls with unexpired credentials should not refetch" , func (t * testing.T ) {
241+ var callCount int32
242+
243+ epYAML := `
244+ registries:
245+ - name: Test Registry
246+ api_url: https://test.registry.io
247+ prefix: test.registry.io
248+ credentials: env:TEST_CONCURRENT_CREDS
249+ credsexpire: 10m
250+ `
251+ epl , err := ParseRegistryConfiguration (epYAML )
252+ require .NoError (t , err )
253+
254+ err = AddRegistryEndpointFromConfig (epl .Items [0 ])
255+ require .NoError (t , err )
256+ ep , err := GetRegistryEndpoint ("test.registry.io" )
257+ require .NoError (t , err )
258+
259+ // Set environment variable
260+ os .Setenv ("TEST_CONCURRENT_CREDS" , "user:pass" )
261+
262+ // First call to set credentials
263+ err = ep .SetEndpointCredentials (nil )
264+ require .NoError (t , err )
265+ atomic .AddInt32 (& callCount , 1 )
266+
267+ // Launch concurrent calls - these should not refetch
268+ var wg sync.WaitGroup
269+ for i := 0 ; i < 10 ; i ++ {
270+ wg .Add (1 )
271+ go func () {
272+ defer wg .Done ()
273+ err := ep .SetEndpointCredentials (nil )
274+ assert .NoError (t , err )
275+ }()
276+ }
277+ wg .Wait ()
278+
279+ // Credentials should still be cached, so total calls should be 1
280+ assert .Equal (t , int32 (1 ), atomic .LoadInt32 (& callCount ))
281+ assert .Equal (t , "user" , ep .Username )
282+ assert .Equal (t , "pass" , ep .Password )
283+ })
284+ }
0 commit comments