@@ -2,6 +2,9 @@ package registry
2
2
3
3
import (
4
4
"os"
5
+ "strings"
6
+ "sync"
7
+ "sync/atomic"
5
8
"testing"
6
9
"time"
7
10
@@ -161,3 +164,121 @@ registries:
161
164
})
162
165
163
166
}
167
+
168
+ func Test_ConcurrentCredentialFetching (t * testing.T ) {
169
+ t .Run ("Multiple goroutines fetching credentials should only call once" , func (t * testing.T ) {
170
+ // Create a mock script that counts how many times it's called
171
+ scriptContent := `#!/bin/sh
172
+ echo "counter" >> /tmp/test_ecr_calls.log
173
+ echo "AWS:mock-token-12345"
174
+ `
175
+ scriptPath := "/tmp/test_ecr_auth.sh"
176
+ err := os .WriteFile (scriptPath , []byte (scriptContent ), 0755 )
177
+ require .NoError (t , err )
178
+ defer os .Remove (scriptPath )
179
+
180
+ // Clean up any existing log file
181
+ os .Remove ("/tmp/test_ecr_calls.log" )
182
+ defer os .Remove ("/tmp/test_ecr_calls.log" )
183
+
184
+ epYAML := `
185
+ registries:
186
+ - name: ECR Registry
187
+ api_url: https://123456789.dkr.ecr.us-east-1.amazonaws.com
188
+ prefix: 123456789.dkr.ecr.us-east-1.amazonaws.com
189
+ credentials: ext:` + scriptPath + `
190
+ credsexpire: 1s
191
+ `
192
+ epl , err := ParseRegistryConfiguration (epYAML )
193
+ require .NoError (t , err )
194
+ require .Len (t , epl .Items , 1 )
195
+
196
+ // Add registry configuration
197
+ err = AddRegistryEndpointFromConfig (epl .Items [0 ])
198
+ require .NoError (t , err )
199
+ ep , err := GetRegistryEndpoint ("123456789.dkr.ecr.us-east-1.amazonaws.com" )
200
+ require .NoError (t , err )
201
+
202
+ // Force credentials to be expired
203
+ ep .CredsUpdated = time .Now ().Add (- 2 * time .Second )
204
+
205
+ // Launch multiple goroutines to fetch credentials concurrently
206
+ var wg sync.WaitGroup
207
+ numGoroutines := 10
208
+ errors := make ([]error , numGoroutines )
209
+
210
+ for i := 0 ; i < numGoroutines ; i ++ {
211
+ wg .Add (1 )
212
+ go func (idx int ) {
213
+ defer wg .Done ()
214
+ errors [idx ] = ep .SetEndpointCredentials (nil )
215
+ }(i )
216
+ }
217
+
218
+ wg .Wait ()
219
+
220
+ // Check that no errors occurred
221
+ for i , err := range errors {
222
+ assert .NoError (t , err , "goroutine %d returned error" , i )
223
+ }
224
+
225
+ // Verify credentials were set
226
+ assert .Equal (t , "AWS" , ep .Username )
227
+ assert .Equal (t , "mock-token-12345" , ep .Password )
228
+
229
+ // Check that the script was called only once
230
+ data , err := os .ReadFile ("/tmp/test_ecr_calls.log" )
231
+ if err != nil {
232
+ // File might not exist if script wasn't called at all
233
+ assert .Equal (t , 0 , 0 )
234
+ } else {
235
+ lines := strings .Count (string (data ), "counter" )
236
+ assert .Equal (t , 1 , lines , "Expected script to be called exactly once, but was called %d times" , lines )
237
+ }
238
+ })
239
+
240
+ t .Run ("Concurrent calls with unexpired credentials should not refetch" , func (t * testing.T ) {
241
+ var callCount int32
242
+
243
+ epYAML := `
244
+ registries:
245
+ - name: Test Registry
246
+ api_url: https://test.registry.io
247
+ prefix: test.registry.io
248
+ credentials: env:TEST_CONCURRENT_CREDS
249
+ credsexpire: 10m
250
+ `
251
+ epl , err := ParseRegistryConfiguration (epYAML )
252
+ require .NoError (t , err )
253
+
254
+ err = AddRegistryEndpointFromConfig (epl .Items [0 ])
255
+ require .NoError (t , err )
256
+ ep , err := GetRegistryEndpoint ("test.registry.io" )
257
+ require .NoError (t , err )
258
+
259
+ // Set environment variable
260
+ os .Setenv ("TEST_CONCURRENT_CREDS" , "user:pass" )
261
+
262
+ // First call to set credentials
263
+ err = ep .SetEndpointCredentials (nil )
264
+ require .NoError (t , err )
265
+ atomic .AddInt32 (& callCount , 1 )
266
+
267
+ // Launch concurrent calls - these should not refetch
268
+ var wg sync.WaitGroup
269
+ for i := 0 ; i < 10 ; i ++ {
270
+ wg .Add (1 )
271
+ go func () {
272
+ defer wg .Done ()
273
+ err := ep .SetEndpointCredentials (nil )
274
+ assert .NoError (t , err )
275
+ }()
276
+ }
277
+ wg .Wait ()
278
+
279
+ // Credentials should still be cached, so total calls should be 1
280
+ assert .Equal (t , int32 (1 ), atomic .LoadInt32 (& callCount ))
281
+ assert .Equal (t , "user" , ep .Username )
282
+ assert .Equal (t , "pass" , ep .Password )
283
+ })
284
+ }
0 commit comments