@@ -12,6 +12,7 @@ import (
1212 "net/http"
1313 "strconv"
1414 "strings"
15+ "sync"
1516 "time"
1617
1718 "github.com/caddyserver/caddy/v2"
@@ -30,6 +31,23 @@ func init() {
3031type User = caddyauth.User
3132type Token = jwt.Token
3233
34+ // jwkCacheEntry stores the JWK cache information for a specific URL
35+ type jwkCacheEntry struct {
36+ URL string
37+ Cache * jwk.Cache
38+ CachedSet jwk.Set
39+ }
40+
41+ // refresh refreshes the JWK cache for this entry
42+ func (entry * jwkCacheEntry ) refresh (ctx context.Context , logger * zap.Logger ) error {
43+ _ , err := entry .Cache .Refresh (ctx , entry .URL )
44+ if err != nil {
45+ logger .Warn ("failed to refresh JWK cache" , zap .Error (err ), zap .String ("url" , entry .URL ))
46+ return err
47+ }
48+ return nil
49+ }
50+
3351// JWTAuth facilitates JWT (JSON Web Token) authentication.
3452type JWTAuth struct {
3553 // SignKey is the key used by the signing algorithm to verify the signature.
@@ -150,8 +168,10 @@ type JWTAuth struct {
150168 logger * zap.Logger
151169 parsedSignKey interface {} // can be []byte, *rsa.PublicKey, *ecdsa.PublicKey, etc.
152170
153- jwkCache * jwk.Cache
154- jwkCachedSet jwk.Set
171+ // JWK cache by resolved URL to support placeholders in JWKURL
172+ jwkCaches map [string ]* jwkCacheEntry
173+
174+ mutex sync.RWMutex
155175}
156176
157177// CaddyModule implements caddy.Module interface.
@@ -179,18 +199,64 @@ func (ja *JWTAuth) usingJWK() bool {
179199}
180200
181201func (ja * JWTAuth ) setupJWKLoader () {
182- cache := jwk .NewCache (context .Background (), jwk .WithErrSink (ja ))
183- cache .Register (ja .JWKURL )
184- ja .jwkCache = cache
185- ja .refreshJWKCache ()
186- ja .jwkCachedSet = jwk .NewCachedSet (cache , ja .JWKURL )
187- ja .logger .Info ("using JWKs from URL" , zap .String ("url" , ja .JWKURL ), zap .Int ("loaded_keys" , ja .jwkCachedSet .Len ()))
202+ // Initialize cache for all possible URLs
203+ ja .mutex .Lock ()
204+ ja .jwkCaches = make (map [string ]* jwkCacheEntry )
205+ ja .mutex .Unlock ()
206+ ja .logger .Info ("JWK cache initialized for JWK URL" , zap .String ("jwk_url" , ja .JWKURL ))
188207}
189208
190- // refreshJWKCache refreshes the JWK cache. It validates the JWKs from the given URL.
191- func (ja * JWTAuth ) refreshJWKCache () {
192- _ , err := ja .jwkCache .Refresh (context .Background (), ja .JWKURL )
193- ja .logger .Warn ("failed to refresh JWK cache" , zap .Error (err ))
209+ // getOrCreateJWKCache retrieves or creates a cache for a specific JWK URL
210+ func (ja * JWTAuth ) getOrCreateJWKCache (resolvedURL string ) (* jwkCacheEntry , error ) {
211+ // If the URL is empty, return an error
212+ if resolvedURL == "" {
213+ return nil , fmt .Errorf ("resolved JWK URL is empty" )
214+ }
215+
216+ // First, check if cache already exists for this URL using a read lock
217+ ja .mutex .RLock ()
218+ entry , ok := ja .jwkCaches [resolvedURL ]
219+ ja .mutex .RUnlock ()
220+
221+ if ok {
222+ return entry , nil
223+ }
224+
225+ // If not found, acquire a write lock to create a new entry
226+ ja .mutex .Lock ()
227+ defer ja .mutex .Unlock ()
228+
229+ // Double-check if another goroutine created the cache while we were waiting for the lock
230+ if entry , ok := ja .jwkCaches [resolvedURL ]; ok {
231+ return entry , nil
232+ }
233+
234+ // Create a new cache for this URL
235+ cache := jwk .NewCache (context .Background (), jwk .WithErrSink (ja ))
236+ err := cache .Register (resolvedURL )
237+ if err != nil {
238+ return nil , fmt .Errorf ("failed to register JWK URL: %w" , err )
239+ }
240+
241+ // Create cache entry before attempting refresh
242+ entry = & jwkCacheEntry {
243+ URL : resolvedURL ,
244+ Cache : cache ,
245+ CachedSet : jwk .NewCachedSet (cache , resolvedURL ),
246+ }
247+
248+ // Try to refresh the cache immediately
249+ err = entry .refresh (context .Background (), ja .logger )
250+ if err != nil {
251+ ja .logger .Warn ("failed to refresh JWK cache during initialization" , zap .Error (err ), zap .String ("url" , resolvedURL ))
252+ // Continue even in case of error, the important thing is that the URL is registered
253+ }
254+
255+ // Register the entry in the cache
256+ ja .jwkCaches [resolvedURL ] = entry
257+ ja .logger .Info ("new JWK cache created" , zap .String ("url" , resolvedURL ), zap .Int ("loaded_keys" , entry .CachedSet .Len ()))
258+
259+ return entry , nil
194260}
195261
196262// Validate implements caddy.Validator interface.
@@ -216,6 +282,8 @@ func (ja *JWTAuth) Validate() error {
216282
217283func (ja * JWTAuth ) validateSignatureKeys () error {
218284 if ja .usingJWK () {
285+ // Initialize the cache structure without attempting to resolve placeholders
286+ // URLs will be resolved during usage
219287 ja .setupJWKLoader ()
220288 } else {
221289 if keyBytes , asymmetric , err := parseSignKey (ja .SignKey ); err != nil {
@@ -241,19 +309,36 @@ func (ja *JWTAuth) validateSignatureKeys() error {
241309 return nil
242310}
243311
244- func (ja * JWTAuth ) keyProvider () jws.KeyProviderFunc {
245- return func (_ context.Context , sink jws.KeySink , sig * jws.Signature , _ * jws.Message ) error {
312+ // resolveJWKURL prend une requête HTTP et résout l'URL JWK avec des espaces réservés
313+ func (ja * JWTAuth ) resolveJWKURL (request * http.Request ) string {
314+ replacer := request .Context ().Value (caddy .ReplacerCtxKey ).(* caddy.Replacer )
315+ return replacer .ReplaceAll (ja .JWKURL , "" )
316+ }
317+
318+ func (ja * JWTAuth ) keyProvider (request * http.Request ) jws.KeyProviderFunc {
319+ return func (curContext context.Context , sink jws.KeySink , sig * jws.Signature , _ * jws.Message ) error {
246320 if ja .usingJWK () {
321+ resolvedURL := ja .resolveJWKURL (request )
322+
323+ ja .logger .Info ("JWK URL" , zap .String ("unresolved" , ja .JWKURL ), zap .String ("resolved" , resolvedURL ))
324+
325+ // Get or create the cache for this URL
326+ cacheEntry , err := ja .getOrCreateJWKCache (resolvedURL )
327+ if err != nil {
328+ return fmt .Errorf ("failed to get JWK cache: %w" , err )
329+ }
330+
331+ // Use the key set associated with this URL
247332 kid := sig .ProtectedHeaders ().KeyID ()
248- key , found := ja . jwkCachedSet .LookupKeyID (kid )
333+ key , found := cacheEntry . CachedSet .LookupKeyID (kid )
249334 if ! found {
250- // trigger a refresh if the key is not found
251- go ja . refreshJWKCache ( )
335+ // Trigger an asynchronous refresh if the key is not found
336+ go cacheEntry . refresh ( context . Background (), ja . logger )
252337
253338 if kid == "" {
254339 return fmt .Errorf ("missing kid in JWT header" )
255340 }
256- return fmt .Errorf ("key specified by kid %q not found in JWKs" , kid )
341+ return fmt .Errorf ("key specified by kid %q not found in JWKs from %s " , kid , resolvedURL )
257342 }
258343 sink .Key (ja .determineSigningAlgorithm (key .Algorithm (), sig .ProtectedHeaders ().Algorithm ()), key )
259344 } else if ja .SignAlgorithm == string (jwa .EdDSA ) {
@@ -304,7 +389,7 @@ func (ja *JWTAuth) Authenticate(rw http.ResponseWriter, r *http.Request) (User,
304389 jwt .WithVerify (! ja .SkipVerification ),
305390 }
306391 if ! ja .SkipVerification {
307- jwtOptions = append (jwtOptions , jwt .WithKeyProvider (ja .keyProvider ()))
392+ jwtOptions = append (jwtOptions , jwt .WithKeyProvider (ja .keyProvider (r )))
308393 }
309394 gotToken , err = jwt .ParseString (tokenString , jwtOptions ... )
310395
@@ -515,15 +600,15 @@ func desensitizedTokenString(token string) string {
515600func parseSignKey (signKey string ) (keyBytes []byte , asymmetric bool , err error ) {
516601 repl := caddy .NewReplacer ()
517602 // Replace placeholders in the signKey such as {file./path/to/sign_key.txt}
518- signKey = repl .ReplaceAll (signKey , "" )
519- if len (signKey ) == 0 {
603+ resolvedSignKey : = repl .ReplaceAll (signKey , "" )
604+ if len (resolvedSignKey ) == 0 {
520605 return nil , false , ErrMissingKeys
521606 }
522- if strings .Contains (signKey , "-----BEGIN PUBLIC KEY-----" ) {
523- keyBytes , err = parsePEMFormattedPublicKey (signKey )
607+ if strings .Contains (resolvedSignKey , "-----BEGIN PUBLIC KEY-----" ) {
608+ keyBytes , err = parsePEMFormattedPublicKey (resolvedSignKey )
524609 return keyBytes , true , err
525610 }
526- keyBytes , err = base64 .StdEncoding .DecodeString (signKey )
611+ keyBytes , err = base64 .StdEncoding .DecodeString (resolvedSignKey )
527612 return keyBytes , false , err
528613}
529614
0 commit comments