@@ -17,10 +17,12 @@ import (
1717
1818 "github.com/caddyserver/caddy/v2"
1919 "github.com/caddyserver/caddy/v2/modules/caddyhttp/caddyauth"
20- "github.com/lestrrat-go/jwx/v2/jwa"
21- "github.com/lestrrat-go/jwx/v2/jwk"
22- "github.com/lestrrat-go/jwx/v2/jws"
23- "github.com/lestrrat-go/jwx/v2/jwt"
20+ "github.com/lestrrat-go/httprc/v3"
21+ "github.com/lestrrat-go/httprc/v3/errsink"
22+ "github.com/lestrrat-go/jwx/v3/jwa"
23+ "github.com/lestrrat-go/jwx/v3/jwk"
24+ "github.com/lestrrat-go/jwx/v3/jws"
25+ "github.com/lestrrat-go/jwx/v3/jwt"
2426 "go.uber.org/zap"
2527)
2628
@@ -232,17 +234,28 @@ func (ja *JWTAuth) getOrCreateJWKCache(resolvedURL string) (*jwkCacheEntry, erro
232234 }
233235
234236 // Create a new cache for this URL
235- cache := jwk .NewCache (context .Background (), jwk .WithErrSink (ja ))
236- err := cache .Register (resolvedURL )
237+ client := httprc .NewClient (httprc .WithErrorSink (errsink .NewFunc (func (_ context.Context , err error ) {
238+ ja .Error (err )
239+ })))
240+ cache , err := jwk .NewCache (context .Background (), client )
241+ if err != nil {
242+ return nil , fmt .Errorf ("failed to create JWK cache: %w" , err )
243+ }
244+ err = cache .Register (context .Background (), resolvedURL )
237245 if err != nil {
238246 return nil , fmt .Errorf ("failed to register JWK URL: %w" , err )
239247 }
240248
249+ cachedSet , err := cache .CachedSet (resolvedURL )
250+ if err != nil {
251+ return nil , fmt .Errorf ("failed to create cached JWK set: %w" , err )
252+ }
253+
241254 // Create cache entry before attempting refresh
242255 entry = & jwkCacheEntry {
243256 URL : resolvedURL ,
244257 Cache : cache ,
245- CachedSet : jwk . NewCachedSet ( cache , resolvedURL ) ,
258+ CachedSet : cachedSet ,
246259 }
247260
248261 // Try to refresh the cache immediately
@@ -298,9 +311,8 @@ func (ja *JWTAuth) validateSignatureKeys() error {
298311 }
299312
300313 if ja .SignAlgorithm != "" {
301- var alg jwa.SignatureAlgorithm
302- if err := alg .Accept (ja .SignAlgorithm ); err != nil {
303- return fmt .Errorf ("%w: %v" , ErrInvalidSignAlgorithm , err )
314+ if _ , ok := jwa .LookupSignatureAlgorithm (ja .SignAlgorithm ); ! ok {
315+ return fmt .Errorf ("%w: %s" , ErrInvalidSignAlgorithm , ja .SignAlgorithm )
304316 }
305317 }
306318 }
@@ -329,7 +341,7 @@ func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
329341 }
330342
331343 // Use the key set associated with this URL
332- kid := sig .ProtectedHeaders ().KeyID ()
344+ kid , _ := sig .ProtectedHeaders ().KeyID ()
333345 key , found := cacheEntry .CachedSet .LookupKeyID (kid )
334346 if ! found {
335347 // Trigger an asynchronous refresh if the key is not found
@@ -340,29 +352,36 @@ func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
340352 }
341353 return fmt .Errorf ("key specified by kid %q not found in JWKs from %s" , kid , resolvedURL )
342354 }
343- sink .Key (ja .determineSigningAlgorithm (key .Algorithm (), sig .ProtectedHeaders ().Algorithm ()), key )
344- } else if ja .SignAlgorithm == string (jwa .EdDSA ) {
355+ keyAlg , _ := key .Algorithm ()
356+ headerAlg , _ := sig .ProtectedHeaders ().Algorithm ()
357+ sink .Key (ja .determineSigningAlgorithm (safeString (keyAlg ), safeString (headerAlg )), key )
358+ } else if ja .SignAlgorithm == jwa .EdDSA ().String () {
345359 if signKey , ok := ja .parsedSignKey .([]byte ); ! ok {
346360 return fmt .Errorf ("EdDSA key must be base64 encoded bytes" )
347361 } else if len (signKey ) != ed25519 .PublicKeySize {
348362 return fmt .Errorf ("key is not a proper ed25519 length" )
349363 } else {
350- sink .Key (jwa .EdDSA , ed25519 .PublicKey (signKey ))
364+ sink .Key (jwa .EdDSA () , ed25519 .PublicKey (signKey ))
351365 }
352366 } else {
353- sink .Key (ja .determineSigningAlgorithm (sig .ProtectedHeaders ().Algorithm ()), ja .parsedSignKey )
367+ headerAlg , _ := sig .ProtectedHeaders ().Algorithm ()
368+ sink .Key (ja .determineSigningAlgorithm (safeString (headerAlg )), ja .parsedSignKey )
354369 }
355370 return nil
356371 }
357372}
358373
359- func (ja * JWTAuth ) determineSigningAlgorithm (alg ... jwa.KeyAlgorithm ) jwa.SignatureAlgorithm {
360- for _ , a := range alg {
361- if a .String () != "" {
362- return jwa .SignatureAlgorithm (a .String ())
374+ func (ja * JWTAuth ) determineSigningAlgorithm (algoNames ... string ) jwa.SignatureAlgorithm {
375+ algoNames = append (algoNames , ja .SignAlgorithm ) // fallback to ja.SignAlgorithm
376+ for _ , name := range algoNames {
377+ if name == "" {
378+ continue
379+ }
380+ if alg , ok := jwa .LookupSignatureAlgorithm (name ); ok {
381+ return alg
363382 }
364383 }
365- return jwa .SignatureAlgorithm ( ja . SignAlgorithm ) // can be ""
384+ return jwa .EmptySignatureAlgorithm ()
366385}
367386
368387// Authenticate validates the JWT in the request and returns the user, if valid.
@@ -499,13 +518,17 @@ func getTokensFromCookies(r *http.Request, names []string) []string {
499518
500519func getUserID (token Token , names []string ) (string , string ) {
501520 for _ , name := range names {
502- if userClaim , ok := token .Get (name ); ok {
503- switch val := userClaim .(type ) {
504- case string :
505- return name , val
506- case float64 :
507- return name , strconv .FormatFloat (val , 'f' , - 1 , 64 )
508- }
521+ userClaim , ok := getTokenClaim (token , name )
522+ if ! ok {
523+ continue
524+ }
525+ switch val := userClaim .(type ) {
526+ case string :
527+ return name , val
528+ case float64 :
529+ return name , strconv .FormatFloat (val , 'f' , - 1 , 64 )
530+ case json.Number :
531+ return name , val .String ()
509532 }
510533 }
511534 return "" , ""
@@ -531,10 +554,10 @@ func getUserMetadata(token Token, placeholdersMap map[string]string) map[string]
531554 return nil
532555 }
533556
534- claims , _ := token . AsMap ( context . Background ()) // error ignored
557+ claims := tokenAsMap ( token )
535558 metadata := make (map [string ]string )
536559 for claim , placeholder := range placeholdersMap {
537- claimValue , ok := token . Get ( claim )
560+ claimValue , ok := getTokenClaim ( token , claim )
538561
539562 // Query nested claims.
540563 if ! ok && strings .Contains (claim , "." ) {
@@ -550,6 +573,26 @@ func getUserMetadata(token Token, placeholdersMap map[string]string) map[string]
550573 return metadata
551574}
552575
576+ func getTokenClaim (token Token , name string ) (interface {}, bool ) {
577+ var value interface {}
578+ if err := token .Get (name , & value ); err != nil {
579+ return nil , false
580+ }
581+ return value , true
582+ }
583+
584+ func tokenAsMap (token Token ) map [string ]interface {} {
585+ claims := make (map [string ]interface {})
586+ for _ , key := range token .Keys () {
587+ value , ok := getTokenClaim (token , key )
588+ if ! ok {
589+ continue
590+ }
591+ claims [key ] = value
592+ }
593+ return claims
594+ }
595+
553596func stringify (val interface {}) string {
554597 if val == nil {
555598 return ""
@@ -621,6 +664,13 @@ func parsePEMFormattedPublicKey(pubKey string) ([]byte, error) {
621664 return nil , ErrInvalidPublicKey
622665}
623666
667+ func safeString (s fmt.Stringer ) string {
668+ if s == nil {
669+ return ""
670+ }
671+ return s .String ()
672+ }
673+
624674// Interface guards
625675var (
626676 _ caddy.Provisioner = (* JWTAuth )(nil )
0 commit comments