Skip to content

Commit 35c303d

Browse files
committed
Keep signing algorithm fallback
1 parent e7e27fe commit 35c303d

File tree

2 files changed

+29
-17
lines changed

2 files changed

+29
-17
lines changed

jwt.go

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,9 @@ func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
352352
}
353353
return fmt.Errorf("key specified by kid %q not found in JWKs from %s", kid, resolvedURL)
354354
}
355-
keyAlg, keyAlgOk := key.Algorithm()
356-
headerAlg, headerAlgOk := sig.ProtectedHeaders().Algorithm()
357-
sink.Key(ja.determineSigningAlgorithm(keyAlg, keyAlgOk, headerAlg, headerAlgOk), key)
355+
keyAlg, _ := key.Algorithm()
356+
headerAlg, _ := sig.ProtectedHeaders().Algorithm()
357+
sink.Key(ja.determineSigningAlgorithm(safeString(keyAlg), safeString(headerAlg)), key)
358358
} else if ja.SignAlgorithm == jwa.EdDSA().String() {
359359
if signKey, ok := ja.parsedSignKey.([]byte); !ok {
360360
return fmt.Errorf("EdDSA key must be base64 encoded bytes")
@@ -364,28 +364,23 @@ func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
364364
sink.Key(jwa.EdDSA(), ed25519.PublicKey(signKey))
365365
}
366366
} else {
367-
headerAlg, headerAlgOk := sig.ProtectedHeaders().Algorithm()
368-
sink.Key(ja.determineSigningAlgorithm(nil, false, headerAlg, headerAlgOk), ja.parsedSignKey)
367+
headerAlg, _ := sig.ProtectedHeaders().Algorithm()
368+
sink.Key(ja.determineSigningAlgorithm(safeString(headerAlg)), ja.parsedSignKey)
369369
}
370370
return nil
371371
}
372372
}
373373

374-
func (ja *JWTAuth) determineSigningAlgorithm(keyAlg jwa.KeyAlgorithm, keyAlgOk bool, headerAlg jwa.SignatureAlgorithm, headerAlgOk bool) jwa.SignatureAlgorithm {
375-
if keyAlgOk {
376-
if alg, ok := jwa.LookupSignatureAlgorithm(keyAlg.String()); ok {
374+
func (ja *JWTAuth) determineSigningAlgorithm(algoNames ...string) jwa.SignatureAlgorithm {
375+
algoNames = append(algoNames, ja.SignAlgorithm) // fallback to ja.SignAlgorithm
376+
for _, name := range algoNames {
377+
if name == "" {
378+
continue
379+
}
380+
if alg, ok := jwa.LookupSignatureAlgorithm(name); ok {
377381
return alg
378382
}
379383
}
380-
if headerAlgOk && headerAlg.String() != "" {
381-
return headerAlg
382-
}
383-
if ja.SignAlgorithm == "" {
384-
return jwa.EmptySignatureAlgorithm()
385-
}
386-
if alg, ok := jwa.LookupSignatureAlgorithm(ja.SignAlgorithm); ok {
387-
return alg
388-
}
389384
return jwa.EmptySignatureAlgorithm()
390385
}
391386

@@ -669,6 +664,13 @@ func parsePEMFormattedPublicKey(pubKey string) ([]byte, error) {
669664
return nil, ErrInvalidPublicKey
670665
}
671666

667+
func safeString(s fmt.Stringer) string {
668+
if s == nil {
669+
return ""
670+
}
671+
return s.String()
672+
}
673+
672674
// Interface guards
673675
var (
674676
_ caddy.Provisioner = (*JWTAuth)(nil)

jwt_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,16 @@ func TestValidate_InvalidMetaClaims(t *testing.T) {
225225
assert.Contains(t, ja.Validate().Error(), "invalid meta claim")
226226
}
227227

228+
func TestDetermineSigningAlgorithmFallback(t *testing.T) {
229+
ja := &JWTAuth{SignAlgorithm: jwa.HS256().String()}
230+
231+
alg := ja.determineSigningAlgorithm()
232+
assert.Equal(t, jwa.HS256(), alg)
233+
234+
alg = ja.determineSigningAlgorithm(jwa.EdDSA().String())
235+
assert.Equal(t, jwa.EdDSA(), alg)
236+
}
237+
228238
func TestAuthenticate_FromAuthorizationHeader(t *testing.T) {
229239
claims := MapClaims{"sub": "ggicci"}
230240
ja := &JWTAuth{SignKey: TestSignKey, logger: testLogger}

0 commit comments

Comments
 (0)