Skip to content

Commit 2d6977e

Browse files
committed
feat: add TTL-based expiration for JWKS cache entries
- Add CacheTTL option to CachingClientOpts with 1-hour default - Introduce expiresAt field to track cache entry expiration time - Implement cache expiration checks during key retrieval - Rename InvalidateCacheIfNeeded to InvalidateCacheIfPossible with bool return - Update cache invalidation to reset TTL on successful refresh - Add comprehensive test coverage for TTL expiration scenarios
1 parent e086c1e commit 2d6977e

File tree

3 files changed

+447
-18
lines changed

3 files changed

+447
-18
lines changed

jwks/caching_client.go

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,23 @@ import (
1717

1818
const DefaultCacheUpdateMinInterval = time.Minute * 1
1919

20+
// DefaultCacheTTL is the default time-to-live for cached JWKS entries.
21+
// After this duration, cached entries are considered expired and will be refreshed.
22+
// This prevents revoked keys from remaining in cache indefinitely.
23+
const DefaultCacheTTL = time.Hour * 1
24+
2025
// CachingClientOpts contains options for CachingClient.
2126
type CachingClientOpts struct {
2227
ClientOpts
2328

2429
// CacheUpdateMinInterval is a minimal interval between cache updates for the same issuer.
2530
CacheUpdateMinInterval time.Duration
31+
32+
// CacheTTL is the time-to-live for cached JWKS entries.
33+
// After this duration, cached entries expire and will be refreshed on next access.
34+
// This prevents revoked keys from remaining in cache indefinitely.
35+
// Default: DefaultCacheTTL (1 hour).
36+
CacheTTL time.Duration
2637
}
2738

2839
// CachingClient is a Client for getting keys from remote JWKS with a caching mechanism.
@@ -31,30 +42,40 @@ type CachingClient struct {
3142
rawClient *Client
3243
issuerCache map[string]issuerCacheEntry
3344
cacheUpdateMinInterval time.Duration
45+
cacheTTL time.Duration
3446
}
3547

3648
const missingKeysCacheSize = 100
3749

3850
type issuerCacheEntry struct {
3951
updatedAt time.Time
52+
expiresAt time.Time
4053
keys map[string]interface{}
4154
missingKeys *lrucache.LRUCache[string, time.Time]
4255
}
4356

57+
func (ice *issuerCacheEntry) isExpired() bool {
58+
return time.Now().After(ice.expiresAt)
59+
}
60+
4461
// NewCachingClient returns a new Client that can cache fetched data.
4562
func NewCachingClient() *CachingClient {
4663
return NewCachingClientWithOpts(CachingClientOpts{})
4764
}
4865

4966
// NewCachingClientWithOpts returns a new Client that can cache fetched data with options.
5067
func NewCachingClientWithOpts(opts CachingClientOpts) *CachingClient {
51-
if opts.CacheUpdateMinInterval == 0 {
68+
if opts.CacheUpdateMinInterval <= 0 {
5269
opts.CacheUpdateMinInterval = DefaultCacheUpdateMinInterval
5370
}
71+
if opts.CacheTTL <= 0 {
72+
opts.CacheTTL = DefaultCacheTTL
73+
}
5474
return &CachingClient{
5575
rawClient: NewClientWithOpts(opts.ClientOpts),
5676
issuerCache: make(map[string]issuerCacheEntry),
5777
cacheUpdateMinInterval: opts.CacheUpdateMinInterval,
78+
cacheTTL: opts.CacheTTL,
5879
}
5980
}
6081

@@ -76,35 +97,38 @@ func (cc *CachingClient) GetRSAPublicKey(ctx context.Context, issuerURL, keyID s
7697
return nil, &JWKNotFoundError{IssuerURL: issuerURL, KeyID: keyID}
7798
}
7899

79-
// InvalidateCacheIfNeeded does cache invalidation for specific issuer URL if it's necessary.
80-
func (cc *CachingClient) InvalidateCacheIfNeeded(ctx context.Context, issuerURL string) error {
100+
// InvalidateCacheIfPossible does cache invalidation for specific issuer URL if possible.
101+
// It returns true if the cache was invalidated, false if invalidation was skipped due to rate limiting.
102+
func (cc *CachingClient) InvalidateCacheIfPossible(ctx context.Context, issuerURL string) (invalidated bool, err error) {
81103
cc.mu.Lock()
82104
defer cc.mu.Unlock()
83105

84106
var missingKeys *lrucache.LRUCache[string, time.Time]
85107
issCache, found := cc.issuerCache[issuerURL]
86108
if found {
87109
if time.Since(issCache.updatedAt) < cc.cacheUpdateMinInterval {
88-
return nil
110+
return false, nil
89111
}
90112
missingKeys = issCache.missingKeys
91113
} else {
92114
var err error
93115
if missingKeys, err = lrucache.New[string, time.Time](missingKeysCacheSize, nil); err != nil {
94-
return fmt.Errorf("new lru cache for missing keys: %w", err)
116+
return false, fmt.Errorf("new lru cache for missing keys: %w", err)
95117
}
96118
}
97119

98120
pubKeys, err := cc.rawClient.getRSAPubKeysForIssuer(ctx, issuerURL)
99121
if err != nil {
100-
return fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err)
122+
return false, fmt.Errorf("get rsa public keys for issuer %q: %w", issuerURL, err)
101123
}
124+
now := time.Now()
102125
cc.issuerCache[issuerURL] = issuerCacheEntry{
103-
updatedAt: time.Now(),
126+
updatedAt: now,
127+
expiresAt: now.Add(cc.cacheTTL),
104128
keys: pubKeys,
105129
missingKeys: missingKeys,
106130
}
107-
return nil
131+
return true, nil
108132
}
109133

110134
func (cc *CachingClient) getPubKeyFromCache(
@@ -117,6 +141,12 @@ func (cc *CachingClient) getPubKeyFromCache(
117141
if !issFound {
118142
return nil, false, true
119143
}
144+
145+
// Check if cache entry has expired based on TTL (if TTL is configured)
146+
if issCache.isExpired() {
147+
return nil, false, true
148+
}
149+
120150
if pubKey, found = issCache.keys[keyID]; found {
121151
return
122152
}
@@ -135,12 +165,14 @@ func (cc *CachingClient) getPubKeyFromCacheAndInvalidate(
135165

136166
var missingKeys *lrucache.LRUCache[string, time.Time]
137167
if issCache, issFound := cc.issuerCache[issuerURL]; issFound {
138-
if pubKey, found = issCache.keys[keyID]; found {
139-
return pubKey, true, nil
140-
}
141-
missedAt, miss := issCache.missingKeys.Get(keyID)
142-
if miss && time.Since(missedAt) < cc.cacheUpdateMinInterval {
143-
return nil, false, nil
168+
if !issCache.isExpired() {
169+
if pubKey, found = issCache.keys[keyID]; found {
170+
return pubKey, true, nil
171+
}
172+
missedAt, miss := issCache.missingKeys.Get(keyID)
173+
if miss && time.Since(missedAt) < cc.cacheUpdateMinInterval {
174+
return nil, false, nil
175+
}
144176
}
145177
missingKeys = issCache.missingKeys
146178
} else {
@@ -158,8 +190,10 @@ func (cc *CachingClient) getPubKeyFromCacheAndInvalidate(
158190
if !found {
159191
missingKeys.Add(keyID, time.Now())
160192
}
193+
now := time.Now()
161194
cc.issuerCache[issuerURL] = issuerCacheEntry{
162-
updatedAt: time.Now(),
195+
updatedAt: now,
196+
expiresAt: now.Add(cc.cacheTTL),
163197
keys: pubKeys,
164198
missingKeys: missingKeys,
165199
}

0 commit comments

Comments
 (0)