@@ -17,10 +17,12 @@ import (
1717
1818 "github.com/caddyserver/caddy/v2"
1919 "github.com/caddyserver/caddy/v2/modules/caddyhttp/caddyauth"
20- "github.com/lestrrat-go/jwx/v2/jwa"
21- "github.com/lestrrat-go/jwx/v2/jwk"
22- "github.com/lestrrat-go/jwx/v2/jws"
23- "github.com/lestrrat-go/jwx/v2/jwt"
20+ "github.com/lestrrat-go/httprc/v3"
21+ "github.com/lestrrat-go/httprc/v3/errsink"
22+ "github.com/lestrrat-go/jwx/v3/jwa"
23+ "github.com/lestrrat-go/jwx/v3/jwk"
24+ "github.com/lestrrat-go/jwx/v3/jws"
25+ "github.com/lestrrat-go/jwx/v3/jwt"
2426 "go.uber.org/zap"
2527)
2628
@@ -232,17 +234,28 @@ func (ja *JWTAuth) getOrCreateJWKCache(resolvedURL string) (*jwkCacheEntry, erro
232234 }
233235
234236 // Create a new cache for this URL
235- cache := jwk .NewCache (context .Background (), jwk .WithErrSink (ja ))
236- err := cache .Register (resolvedURL )
237+ client := httprc .NewClient (httprc .WithErrorSink (errsink .NewFunc (func (_ context.Context , err error ) {
238+ ja .Error (err )
239+ })))
240+ cache , err := jwk .NewCache (context .Background (), client )
241+ if err != nil {
242+ return nil , fmt .Errorf ("failed to create JWK cache: %w" , err )
243+ }
244+ err = cache .Register (context .Background (), resolvedURL )
237245 if err != nil {
238246 return nil , fmt .Errorf ("failed to register JWK URL: %w" , err )
239247 }
240248
249+ cachedSet , err := cache .CachedSet (resolvedURL )
250+ if err != nil {
251+ return nil , fmt .Errorf ("failed to create cached JWK set: %w" , err )
252+ }
253+
241254 // Create cache entry before attempting refresh
242255 entry = & jwkCacheEntry {
243256 URL : resolvedURL ,
244257 Cache : cache ,
245- CachedSet : jwk . NewCachedSet ( cache , resolvedURL ) ,
258+ CachedSet : cachedSet ,
246259 }
247260
248261 // Try to refresh the cache immediately
@@ -298,9 +311,8 @@ func (ja *JWTAuth) validateSignatureKeys() error {
298311 }
299312
300313 if ja .SignAlgorithm != "" {
301- var alg jwa.SignatureAlgorithm
302- if err := alg .Accept (ja .SignAlgorithm ); err != nil {
303- return fmt .Errorf ("%w: %v" , ErrInvalidSignAlgorithm , err )
314+ if _ , ok := jwa .LookupSignatureAlgorithm (ja .SignAlgorithm ); ! ok {
315+ return fmt .Errorf ("%w: %s" , ErrInvalidSignAlgorithm , ja .SignAlgorithm )
304316 }
305317 }
306318 }
@@ -329,7 +341,7 @@ func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
329341 }
330342
331343 // Use the key set associated with this URL
332- kid := sig .ProtectedHeaders ().KeyID ()
344+ kid , _ := sig .ProtectedHeaders ().KeyID ()
333345 key , found := cacheEntry .CachedSet .LookupKeyID (kid )
334346 if ! found {
335347 // Trigger an asynchronous refresh if the key is not found
@@ -340,29 +352,41 @@ func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
340352 }
341353 return fmt .Errorf ("key specified by kid %q not found in JWKs from %s" , kid , resolvedURL )
342354 }
343- sink .Key (ja .determineSigningAlgorithm (key .Algorithm (), sig .ProtectedHeaders ().Algorithm ()), key )
344- } else if ja .SignAlgorithm == string (jwa .EdDSA ) {
355+ keyAlg , keyAlgOk := key .Algorithm ()
356+ headerAlg , headerAlgOk := sig .ProtectedHeaders ().Algorithm ()
357+ sink .Key (ja .determineSigningAlgorithm (keyAlg , keyAlgOk , headerAlg , headerAlgOk ), key )
358+ } else if ja .SignAlgorithm == jwa .EdDSA ().String () {
345359 if signKey , ok := ja .parsedSignKey .([]byte ); ! ok {
346360 return fmt .Errorf ("EdDSA key must be base64 encoded bytes" )
347361 } else if len (signKey ) != ed25519 .PublicKeySize {
348362 return fmt .Errorf ("key is not a proper ed25519 length" )
349363 } else {
350- sink .Key (jwa .EdDSA , ed25519 .PublicKey (signKey ))
364+ sink .Key (jwa .EdDSA () , ed25519 .PublicKey (signKey ))
351365 }
352366 } else {
353- sink .Key (ja .determineSigningAlgorithm (sig .ProtectedHeaders ().Algorithm ()), ja .parsedSignKey )
367+ headerAlg , headerAlgOk := sig .ProtectedHeaders ().Algorithm ()
368+ sink .Key (ja .determineSigningAlgorithm (nil , false , headerAlg , headerAlgOk ), ja .parsedSignKey )
354369 }
355370 return nil
356371 }
357372}
358373
359- func (ja * JWTAuth ) determineSigningAlgorithm (alg ... jwa.KeyAlgorithm ) jwa.SignatureAlgorithm {
360- for _ , a := range alg {
361- if a . String () != "" {
362- return jwa . SignatureAlgorithm ( a . String ())
374+ func (ja * JWTAuth ) determineSigningAlgorithm (keyAlg jwa.KeyAlgorithm , keyAlgOk bool , headerAlg jwa. SignatureAlgorithm , headerAlgOk bool ) jwa.SignatureAlgorithm {
375+ if keyAlgOk {
376+ if alg , ok := jwa . LookupSignatureAlgorithm ( keyAlg . String ()); ok {
377+ return alg
363378 }
364379 }
365- return jwa .SignatureAlgorithm (ja .SignAlgorithm ) // can be ""
380+ if headerAlgOk && headerAlg .String () != "" {
381+ return headerAlg
382+ }
383+ if ja .SignAlgorithm == "" {
384+ return jwa .EmptySignatureAlgorithm ()
385+ }
386+ if alg , ok := jwa .LookupSignatureAlgorithm (ja .SignAlgorithm ); ok {
387+ return alg
388+ }
389+ return jwa .EmptySignatureAlgorithm ()
366390}
367391
368392// Authenticate validates the JWT in the request and returns the user, if valid.
@@ -499,13 +523,17 @@ func getTokensFromCookies(r *http.Request, names []string) []string {
499523
500524func getUserID (token Token , names []string ) (string , string ) {
501525 for _ , name := range names {
502- if userClaim , ok := token .Get (name ); ok {
503- switch val := userClaim .(type ) {
504- case string :
505- return name , val
506- case float64 :
507- return name , strconv .FormatFloat (val , 'f' , - 1 , 64 )
508- }
526+ userClaim , ok := getTokenClaim (token , name )
527+ if ! ok {
528+ continue
529+ }
530+ switch val := userClaim .(type ) {
531+ case string :
532+ return name , val
533+ case float64 :
534+ return name , strconv .FormatFloat (val , 'f' , - 1 , 64 )
535+ case json.Number :
536+ return name , val .String ()
509537 }
510538 }
511539 return "" , ""
@@ -531,10 +559,10 @@ func getUserMetadata(token Token, placeholdersMap map[string]string) map[string]
531559 return nil
532560 }
533561
534- claims , _ := token . AsMap ( context . Background ()) // error ignored
562+ claims := tokenAsMap ( token )
535563 metadata := make (map [string ]string )
536564 for claim , placeholder := range placeholdersMap {
537- claimValue , ok := token . Get ( claim )
565+ claimValue , ok := getTokenClaim ( token , claim )
538566
539567 // Query nested claims.
540568 if ! ok && strings .Contains (claim , "." ) {
@@ -550,6 +578,26 @@ func getUserMetadata(token Token, placeholdersMap map[string]string) map[string]
550578 return metadata
551579}
552580
581+ func getTokenClaim (token Token , name string ) (interface {}, bool ) {
582+ var value interface {}
583+ if err := token .Get (name , & value ); err != nil {
584+ return nil , false
585+ }
586+ return value , true
587+ }
588+
589+ func tokenAsMap (token Token ) map [string ]interface {} {
590+ claims := make (map [string ]interface {})
591+ for _ , key := range token .Keys () {
592+ value , ok := getTokenClaim (token , key )
593+ if ! ok {
594+ continue
595+ }
596+ claims [key ] = value
597+ }
598+ return claims
599+ }
600+
553601func stringify (val interface {}) string {
554602 if val == nil {
555603 return ""
0 commit comments