|
4 | 4 | "context" |
5 | 5 | "errors" |
6 | 6 | "net/http" |
7 | | - "sync" |
8 | | - "sync/atomic" |
9 | 7 | "testing" |
10 | 8 | "time" |
11 | 9 |
|
@@ -228,183 +226,6 @@ func TestExternalTokenProvider(t *testing.T) { |
228 | 226 | }) |
229 | 227 | } |
230 | 228 |
|
231 | | -func TestCachedTokenProvider(t *testing.T) { |
232 | | - t.Run("caches_valid_token", func(t *testing.T) { |
233 | | - callCount := 0 |
234 | | - baseProvider := &mockProvider{ |
235 | | - tokenFunc: func() (*Token, error) { |
236 | | - callCount++ |
237 | | - return &Token{ |
238 | | - AccessToken: "cached-token", |
239 | | - TokenType: "Bearer", |
240 | | - ExpiresAt: time.Now().Add(1 * time.Hour), |
241 | | - }, nil |
242 | | - }, |
243 | | - name: "mock", |
244 | | - } |
245 | | - |
246 | | - cachedProvider := NewCachedTokenProvider(baseProvider) |
247 | | - |
248 | | - // First call - should fetch from base provider |
249 | | - token1, err1 := cachedProvider.GetToken(context.Background()) |
250 | | - require.NoError(t, err1) |
251 | | - assert.Equal(t, "cached-token", token1.AccessToken) |
252 | | - assert.Equal(t, 1, callCount) |
253 | | - |
254 | | - // Second call - should use cache |
255 | | - token2, err2 := cachedProvider.GetToken(context.Background()) |
256 | | - require.NoError(t, err2) |
257 | | - assert.Equal(t, "cached-token", token2.AccessToken) |
258 | | - assert.Equal(t, 1, callCount) // Should still be 1 |
259 | | - }) |
260 | | - |
261 | | - t.Run("refreshes_expired_token", func(t *testing.T) { |
262 | | - callCount := 0 |
263 | | - baseProvider := &mockProvider{ |
264 | | - tokenFunc: func() (*Token, error) { |
265 | | - callCount++ |
266 | | - // Return token that expires soon |
267 | | - return &Token{ |
268 | | - AccessToken: "token-" + string(rune(callCount)), |
269 | | - TokenType: "Bearer", |
270 | | - ExpiresAt: time.Now().Add(2 * time.Minute), // Within refresh threshold |
271 | | - }, nil |
272 | | - }, |
273 | | - name: "mock", |
274 | | - } |
275 | | - |
276 | | - cachedProvider := NewCachedTokenProvider(baseProvider) |
277 | | - cachedProvider.RefreshThreshold = 5 * time.Minute |
278 | | - |
279 | | - // First call |
280 | | - token1, err1 := cachedProvider.GetToken(context.Background()) |
281 | | - require.NoError(t, err1) |
282 | | - assert.Equal(t, "token-\x01", token1.AccessToken) |
283 | | - assert.Equal(t, 1, callCount) |
284 | | - |
285 | | - // Second call - should refresh because token expires within threshold |
286 | | - token2, err2 := cachedProvider.GetToken(context.Background()) |
287 | | - require.NoError(t, err2) |
288 | | - assert.Equal(t, "token-\x02", token2.AccessToken) |
289 | | - assert.Equal(t, 2, callCount) |
290 | | - }) |
291 | | - |
292 | | - t.Run("handles_provider_error", func(t *testing.T) { |
293 | | - baseProvider := &mockProvider{ |
294 | | - tokenFunc: func() (*Token, error) { |
295 | | - return nil, errors.New("provider error") |
296 | | - }, |
297 | | - name: "mock", |
298 | | - } |
299 | | - |
300 | | - cachedProvider := NewCachedTokenProvider(baseProvider) |
301 | | - token, err := cachedProvider.GetToken(context.Background()) |
302 | | - |
303 | | - assert.Error(t, err) |
304 | | - assert.Nil(t, token) |
305 | | - assert.Contains(t, err.Error(), "provider error") |
306 | | - }) |
307 | | - |
308 | | - t.Run("no_expiry_token_not_refreshed", func(t *testing.T) { |
309 | | - callCount := 0 |
310 | | - baseProvider := &mockProvider{ |
311 | | - tokenFunc: func() (*Token, error) { |
312 | | - callCount++ |
313 | | - return &Token{ |
314 | | - AccessToken: "permanent-token", |
315 | | - TokenType: "Bearer", |
316 | | - ExpiresAt: time.Time{}, // No expiry |
317 | | - }, nil |
318 | | - }, |
319 | | - name: "mock", |
320 | | - } |
321 | | - |
322 | | - cachedProvider := NewCachedTokenProvider(baseProvider) |
323 | | - |
324 | | - // Multiple calls should all use cache |
325 | | - for i := 0; i < 5; i++ { |
326 | | - token, err := cachedProvider.GetToken(context.Background()) |
327 | | - require.NoError(t, err) |
328 | | - assert.Equal(t, "permanent-token", token.AccessToken) |
329 | | - } |
330 | | - |
331 | | - assert.Equal(t, 1, callCount) // Should only be called once |
332 | | - }) |
333 | | - |
334 | | - t.Run("clear_cache", func(t *testing.T) { |
335 | | - callCount := 0 |
336 | | - baseProvider := &mockProvider{ |
337 | | - tokenFunc: func() (*Token, error) { |
338 | | - callCount++ |
339 | | - return &Token{ |
340 | | - AccessToken: "token-" + string(rune(callCount)), |
341 | | - TokenType: "Bearer", |
342 | | - ExpiresAt: time.Now().Add(1 * time.Hour), |
343 | | - }, nil |
344 | | - }, |
345 | | - name: "mock", |
346 | | - } |
347 | | - |
348 | | - cachedProvider := NewCachedTokenProvider(baseProvider) |
349 | | - |
350 | | - // First call |
351 | | - token1, _ := cachedProvider.GetToken(context.Background()) |
352 | | - assert.Equal(t, "token-\x01", token1.AccessToken) |
353 | | - assert.Equal(t, 1, callCount) |
354 | | - |
355 | | - // Clear cache |
356 | | - cachedProvider.ClearCache() |
357 | | - |
358 | | - // Next call should fetch new token |
359 | | - token2, _ := cachedProvider.GetToken(context.Background()) |
360 | | - assert.Equal(t, "token-\x02", token2.AccessToken) |
361 | | - assert.Equal(t, 2, callCount) |
362 | | - }) |
363 | | - |
364 | | - t.Run("concurrent_access", func(t *testing.T) { |
365 | | - var callCount atomic.Int32 |
366 | | - baseProvider := &mockProvider{ |
367 | | - tokenFunc: func() (*Token, error) { |
368 | | - // Simulate slow token fetch |
369 | | - time.Sleep(100 * time.Millisecond) |
370 | | - callCount.Add(1) |
371 | | - return &Token{ |
372 | | - AccessToken: "concurrent-token", |
373 | | - TokenType: "Bearer", |
374 | | - ExpiresAt: time.Now().Add(1 * time.Hour), |
375 | | - }, nil |
376 | | - }, |
377 | | - name: "mock", |
378 | | - } |
379 | | - |
380 | | - cachedProvider := NewCachedTokenProvider(baseProvider) |
381 | | - |
382 | | - // Launch multiple goroutines |
383 | | - var wg sync.WaitGroup |
384 | | - for i := 0; i < 10; i++ { |
385 | | - wg.Add(1) |
386 | | - go func() { |
387 | | - defer wg.Done() |
388 | | - token, err := cachedProvider.GetToken(context.Background()) |
389 | | - assert.NoError(t, err) |
390 | | - assert.Equal(t, "concurrent-token", token.AccessToken) |
391 | | - }() |
392 | | - } |
393 | | - |
394 | | - wg.Wait() |
395 | | - |
396 | | - // Should only fetch token once despite concurrent access |
397 | | - assert.Equal(t, int32(1), callCount.Load()) |
398 | | - }) |
399 | | - |
400 | | - t.Run("provider_name", func(t *testing.T) { |
401 | | - baseProvider := &mockProvider{name: "test-provider"} |
402 | | - cachedProvider := NewCachedTokenProvider(baseProvider) |
403 | | - |
404 | | - assert.Equal(t, "cached[test-provider]", cachedProvider.Name()) |
405 | | - }) |
406 | | -} |
407 | | - |
408 | 229 | // Mock provider for testing |
409 | 230 | type mockProvider struct { |
410 | 231 | tokenFunc func() (*Token, error) |
|
0 commit comments