@@ -150,8 +150,11 @@ type JWTAuth struct {
150150 logger * zap.Logger
151151 parsedSignKey interface {} // can be []byte, *rsa.PublicKey, *ecdsa.PublicKey, etc.
152152
153- jwkCache * jwk.Cache
154- jwkCachedSet jwk.Set
153+ // JWK cache by resolved URL to support placeholders in JWKURL
154+ jwkCaches map [string ]* struct {
155+ cache * jwk.Cache
156+ cachedSet jwk.Set
157+ }
155158}
156159
157160// CaddyModule implements caddy.Module interface.
@@ -179,18 +182,72 @@ func (ja *JWTAuth) usingJWK() bool {
179182}
180183
181184func (ja * JWTAuth ) setupJWKLoader () {
185+ // Initialiser le cache pour toutes les URL possibles
186+ ja .jwkCaches = make (map [string ]* struct {
187+ cache * jwk.Cache
188+ cachedSet jwk.Set
189+ })
190+ ja .logger .Info ("JWK cache initialized for placeholder URL" , zap .String ("placeholder_url" , ja .JWKURL ))
191+ }
192+
193+ // getOrCreateJWKCache retrieves or creates a cache for a specific JWK URL
194+ func (ja * JWTAuth ) getOrCreateJWKCache (resolvedURL string ) (* struct {
195+ cache * jwk.Cache
196+ cachedSet jwk.Set
197+ }, error ) {
198+ // If the URL is empty, return an error
199+ if resolvedURL == "" {
200+ return nil , fmt .Errorf ("resolved JWK URL is empty" )
201+ }
202+
203+ // Check if cache already exists for this URL
204+ if entry , ok := ja .jwkCaches [resolvedURL ]; ok {
205+ return entry , nil
206+ }
207+
208+ // Create a new cache for this URL
182209 cache := jwk .NewCache (context .Background (), jwk .WithErrSink (ja ))
183- cache .Register (ja .JWKURL )
184- ja .jwkCache = cache
185- ja .refreshJWKCache ()
186- ja .jwkCachedSet = jwk .NewCachedSet (cache , ja .JWKURL )
187- ja .logger .Info ("using JWKs from URL" , zap .String ("url" , ja .JWKURL ), zap .Int ("loaded_keys" , ja .jwkCachedSet .Len ()))
210+ err := cache .Register (resolvedURL )
211+ if err != nil {
212+ return nil , fmt .Errorf ("failed to register JWK URL: %w" , err )
213+ }
214+
215+ // Try to refresh the cache immediately
216+ _ , err = cache .Refresh (context .Background (), resolvedURL )
217+ if err != nil {
218+ ja .logger .Warn ("failed to refresh JWK cache during initialization" , zap .Error (err ), zap .String ("url" , resolvedURL ))
219+ // Continue even in case of error, the important thing is that the URL is registered
220+ }
221+
222+ cachedSet := jwk .NewCachedSet (cache , resolvedURL )
223+ entry := & struct {
224+ cache * jwk.Cache
225+ cachedSet jwk.Set
226+ }{
227+ cache : cache ,
228+ cachedSet : cachedSet ,
229+ }
230+
231+ // Register the entry in the cache
232+ ja .jwkCaches [resolvedURL ] = entry
233+ ja .logger .Info ("new JWK cache created" , zap .String ("url" , resolvedURL ), zap .Int ("loaded_keys" , cachedSet .Len ()))
234+
235+ return entry , nil
188236}
189237
190- // refreshJWKCache refreshes the JWK cache. It validates the JWKs from the given URL.
191- func (ja * JWTAuth ) refreshJWKCache () {
192- _ , err := ja .jwkCache .Refresh (context .Background (), ja .JWKURL )
193- ja .logger .Warn ("failed to refresh JWK cache" , zap .Error (err ))
238+ // refreshJWKCache refreshes the JWK cache for a specific URL. It validates the JWKs from the given URL.
239+ func (ja * JWTAuth ) refreshJWKCache (resolvedURL string ) error {
240+ entry , err := ja .getOrCreateJWKCache (resolvedURL )
241+ if err != nil {
242+ return err
243+ }
244+
245+ _ , err = entry .cache .Refresh (context .Background (), resolvedURL )
246+ if err != nil {
247+ ja .logger .Warn ("failed to refresh JWK cache" , zap .Error (err ), zap .String ("url" , resolvedURL ))
248+ return err
249+ }
250+ return nil
194251}
195252
196253// Validate implements caddy.Validator interface.
@@ -216,6 +273,8 @@ func (ja *JWTAuth) Validate() error {
216273
217274func (ja * JWTAuth ) validateSignatureKeys () error {
218275 if ja .usingJWK () {
276+ // Initialize the cache structure without attempting to resolve placeholders
277+ // URLs will be resolved during usage
219278 ja .setupJWKLoader ()
220279 } else {
221280 if keyBytes , asymmetric , err := parseSignKey (ja .SignKey ); err != nil {
@@ -241,19 +300,33 @@ func (ja *JWTAuth) validateSignatureKeys() error {
241300 return nil
242301}
243302
244- func (ja * JWTAuth ) keyProvider () jws.KeyProviderFunc {
245- return func (_ context.Context , sink jws.KeySink , sig * jws.Signature , _ * jws.Message ) error {
303+ func (ja * JWTAuth ) keyProvider (request * http. Request ) jws.KeyProviderFunc {
304+ return func (context context.Context , sink jws.KeySink , sig * jws.Signature , _ * jws.Message ) error {
246305 if ja .usingJWK () {
306+ // Resolve JWKURL with placeholders
307+ replacer := request .Context ().Value (caddy .ReplacerCtxKey ).(* caddy.Replacer )
308+ resolvedURL := replacer .ReplaceAll (ja .JWKURL , "" )
309+
310+ ja .logger .Info ("JWK unresolved" , zap .String ("placeholder_url" , ja .JWKURL ))
311+ ja .logger .Info ("JWK resolved" , zap .String ("placeholder_url" , resolvedURL ))
312+
313+ // Get or create the cache for this URL
314+ cacheEntry , err := ja .getOrCreateJWKCache (resolvedURL )
315+ if err != nil {
316+ return fmt .Errorf ("failed to get JWK cache: %w" , err )
317+ }
318+
319+ // Use the key set associated with this URL
247320 kid := sig .ProtectedHeaders ().KeyID ()
248- key , found := ja . jwkCachedSet .LookupKeyID (kid )
321+ key , found := cacheEntry . cachedSet .LookupKeyID (kid )
249322 if ! found {
250- // trigger a refresh if the key is not found
251- go ja .refreshJWKCache ()
323+ // Trigger an asynchronous refresh if the key is not found
324+ go ja .refreshJWKCache (resolvedURL )
252325
253326 if kid == "" {
254327 return fmt .Errorf ("missing kid in JWT header" )
255328 }
256- return fmt .Errorf ("key specified by kid %q not found in JWKs" , kid )
329+ return fmt .Errorf ("key specified by kid %q not found in JWKs from %s " , kid , resolvedURL )
257330 }
258331 sink .Key (ja .determineSigningAlgorithm (key .Algorithm (), sig .ProtectedHeaders ().Algorithm ()), key )
259332 } else if ja .SignAlgorithm == string (jwa .EdDSA ) {
@@ -304,7 +377,7 @@ func (ja *JWTAuth) Authenticate(rw http.ResponseWriter, r *http.Request) (User,
304377 jwt .WithVerify (! ja .SkipVerification ),
305378 }
306379 if ! ja .SkipVerification {
307- jwtOptions = append (jwtOptions , jwt .WithKeyProvider (ja .keyProvider ()))
380+ jwtOptions = append (jwtOptions , jwt .WithKeyProvider (ja .keyProvider (r )))
308381 }
309382 gotToken , err = jwt .ParseString (tokenString , jwtOptions ... )
310383
@@ -515,15 +588,15 @@ func desensitizedTokenString(token string) string {
515588func parseSignKey (signKey string ) (keyBytes []byte , asymmetric bool , err error ) {
516589 repl := caddy .NewReplacer ()
517590 // Replace placeholders in the signKey such as {file./path/to/sign_key.txt}
518- signKey = repl .ReplaceAll (signKey , "" )
519- if len (signKey ) == 0 {
591+ resolvedSignKey : = repl .ReplaceAll (signKey , "" )
592+ if len (resolvedSignKey ) == 0 {
520593 return nil , false , ErrMissingKeys
521594 }
522- if strings .Contains (signKey , "-----BEGIN PUBLIC KEY-----" ) {
523- keyBytes , err = parsePEMFormattedPublicKey (signKey )
595+ if strings .Contains (resolvedSignKey , "-----BEGIN PUBLIC KEY-----" ) {
596+ keyBytes , err = parsePEMFormattedPublicKey (resolvedSignKey )
524597 return keyBytes , true , err
525598 }
526- keyBytes , err = base64 .StdEncoding .DecodeString (signKey )
599+ keyBytes , err = base64 .StdEncoding .DecodeString (resolvedSignKey )
527600 return keyBytes , false , err
528601}
529602
0 commit comments