Skip to content

Commit 2a121c1

Browse files
author
Thibaut Cholley
committed
Support placeholder on JWKURL
1 parent 1c34bf9 commit 2a121c1

File tree

2 files changed

+152
-28
lines changed

2 files changed

+152
-28
lines changed

jwt.go

Lines changed: 96 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,11 @@ type JWTAuth struct {
150150
logger *zap.Logger
151151
parsedSignKey interface{} // can be []byte, *rsa.PublicKey, *ecdsa.PublicKey, etc.
152152

153-
jwkCache *jwk.Cache
154-
jwkCachedSet jwk.Set
153+
// JWK cache by resolved URL to support placeholders in JWKURL
154+
jwkCaches map[string]*struct {
155+
cache *jwk.Cache
156+
cachedSet jwk.Set
157+
}
155158
}
156159

157160
// CaddyModule implements caddy.Module interface.
@@ -179,18 +182,72 @@ func (ja *JWTAuth) usingJWK() bool {
179182
}
180183

181184
func (ja *JWTAuth) setupJWKLoader() {
185+
// Initialiser le cache pour toutes les URL possibles
186+
ja.jwkCaches = make(map[string]*struct {
187+
cache *jwk.Cache
188+
cachedSet jwk.Set
189+
})
190+
ja.logger.Info("JWK cache initialized for placeholder URL", zap.String("placeholder_url", ja.JWKURL))
191+
}
192+
193+
// getOrCreateJWKCache retrieves or creates a cache for a specific JWK URL
194+
func (ja *JWTAuth) getOrCreateJWKCache(resolvedURL string) (*struct {
195+
cache *jwk.Cache
196+
cachedSet jwk.Set
197+
}, error) {
198+
// If the URL is empty, return an error
199+
if resolvedURL == "" {
200+
return nil, fmt.Errorf("resolved JWK URL is empty")
201+
}
202+
203+
// Check if cache already exists for this URL
204+
if entry, ok := ja.jwkCaches[resolvedURL]; ok {
205+
return entry, nil
206+
}
207+
208+
// Create a new cache for this URL
182209
cache := jwk.NewCache(context.Background(), jwk.WithErrSink(ja))
183-
cache.Register(ja.JWKURL)
184-
ja.jwkCache = cache
185-
ja.refreshJWKCache()
186-
ja.jwkCachedSet = jwk.NewCachedSet(cache, ja.JWKURL)
187-
ja.logger.Info("using JWKs from URL", zap.String("url", ja.JWKURL), zap.Int("loaded_keys", ja.jwkCachedSet.Len()))
210+
err := cache.Register(resolvedURL)
211+
if err != nil {
212+
return nil, fmt.Errorf("failed to register JWK URL: %w", err)
213+
}
214+
215+
// Try to refresh the cache immediately
216+
_, err = cache.Refresh(context.Background(), resolvedURL)
217+
if err != nil {
218+
ja.logger.Warn("failed to refresh JWK cache during initialization", zap.Error(err), zap.String("url", resolvedURL))
219+
// Continue even in case of error, the important thing is that the URL is registered
220+
}
221+
222+
cachedSet := jwk.NewCachedSet(cache, resolvedURL)
223+
entry := &struct {
224+
cache *jwk.Cache
225+
cachedSet jwk.Set
226+
}{
227+
cache: cache,
228+
cachedSet: cachedSet,
229+
}
230+
231+
// Register the entry in the cache
232+
ja.jwkCaches[resolvedURL] = entry
233+
ja.logger.Info("new JWK cache created", zap.String("url", resolvedURL), zap.Int("loaded_keys", cachedSet.Len()))
234+
235+
return entry, nil
188236
}
189237

190-
// refreshJWKCache refreshes the JWK cache. It validates the JWKs from the given URL.
191-
func (ja *JWTAuth) refreshJWKCache() {
192-
_, err := ja.jwkCache.Refresh(context.Background(), ja.JWKURL)
193-
ja.logger.Warn("failed to refresh JWK cache", zap.Error(err))
238+
// refreshJWKCache refreshes the JWK cache for a specific URL. It validates the JWKs from the given URL.
239+
func (ja *JWTAuth) refreshJWKCache(resolvedURL string) error {
240+
entry, err := ja.getOrCreateJWKCache(resolvedURL)
241+
if err != nil {
242+
return err
243+
}
244+
245+
_, err = entry.cache.Refresh(context.Background(), resolvedURL)
246+
if err != nil {
247+
ja.logger.Warn("failed to refresh JWK cache", zap.Error(err), zap.String("url", resolvedURL))
248+
return err
249+
}
250+
return nil
194251
}
195252

196253
// Validate implements caddy.Validator interface.
@@ -216,6 +273,8 @@ func (ja *JWTAuth) Validate() error {
216273

217274
func (ja *JWTAuth) validateSignatureKeys() error {
218275
if ja.usingJWK() {
276+
// Initialize the cache structure without attempting to resolve placeholders
277+
// URLs will be resolved during usage
219278
ja.setupJWKLoader()
220279
} else {
221280
if keyBytes, asymmetric, err := parseSignKey(ja.SignKey); err != nil {
@@ -241,19 +300,33 @@ func (ja *JWTAuth) validateSignatureKeys() error {
241300
return nil
242301
}
243302

244-
func (ja *JWTAuth) keyProvider() jws.KeyProviderFunc {
245-
return func(_ context.Context, sink jws.KeySink, sig *jws.Signature, _ *jws.Message) error {
303+
func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
304+
return func(context context.Context, sink jws.KeySink, sig *jws.Signature, _ *jws.Message) error {
246305
if ja.usingJWK() {
306+
// Resolve JWKURL with placeholders
307+
replacer := request.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
308+
resolvedURL := replacer.ReplaceAll(ja.JWKURL, "")
309+
310+
ja.logger.Info("JWK unresolved", zap.String("placeholder_url", ja.JWKURL))
311+
ja.logger.Info("JWK resolved", zap.String("placeholder_url", resolvedURL))
312+
313+
// Get or create the cache for this URL
314+
cacheEntry, err := ja.getOrCreateJWKCache(resolvedURL)
315+
if err != nil {
316+
return fmt.Errorf("failed to get JWK cache: %w", err)
317+
}
318+
319+
// Use the key set associated with this URL
247320
kid := sig.ProtectedHeaders().KeyID()
248-
key, found := ja.jwkCachedSet.LookupKeyID(kid)
321+
key, found := cacheEntry.cachedSet.LookupKeyID(kid)
249322
if !found {
250-
// trigger a refresh if the key is not found
251-
go ja.refreshJWKCache()
323+
// Trigger an asynchronous refresh if the key is not found
324+
go ja.refreshJWKCache(resolvedURL)
252325

253326
if kid == "" {
254327
return fmt.Errorf("missing kid in JWT header")
255328
}
256-
return fmt.Errorf("key specified by kid %q not found in JWKs", kid)
329+
return fmt.Errorf("key specified by kid %q not found in JWKs from %s", kid, resolvedURL)
257330
}
258331
sink.Key(ja.determineSigningAlgorithm(key.Algorithm(), sig.ProtectedHeaders().Algorithm()), key)
259332
} else if ja.SignAlgorithm == string(jwa.EdDSA) {
@@ -304,7 +377,7 @@ func (ja *JWTAuth) Authenticate(rw http.ResponseWriter, r *http.Request) (User,
304377
jwt.WithVerify(!ja.SkipVerification),
305378
}
306379
if !ja.SkipVerification {
307-
jwtOptions = append(jwtOptions, jwt.WithKeyProvider(ja.keyProvider()))
380+
jwtOptions = append(jwtOptions, jwt.WithKeyProvider(ja.keyProvider(r)))
308381
}
309382
gotToken, err = jwt.ParseString(tokenString, jwtOptions...)
310383

@@ -515,15 +588,15 @@ func desensitizedTokenString(token string) string {
515588
func parseSignKey(signKey string) (keyBytes []byte, asymmetric bool, err error) {
516589
repl := caddy.NewReplacer()
517590
// Replace placeholders in the signKey such as {file./path/to/sign_key.txt}
518-
signKey = repl.ReplaceAll(signKey, "")
519-
if len(signKey) == 0 {
591+
resolvedSignKey := repl.ReplaceAll(signKey, "")
592+
if len(resolvedSignKey) == 0 {
520593
return nil, false, ErrMissingKeys
521594
}
522-
if strings.Contains(signKey, "-----BEGIN PUBLIC KEY-----") {
523-
keyBytes, err = parsePEMFormattedPublicKey(signKey)
595+
if strings.Contains(resolvedSignKey, "-----BEGIN PUBLIC KEY-----") {
596+
keyBytes, err = parsePEMFormattedPublicKey(resolvedSignKey)
524597
return keyBytes, true, err
525598
}
526-
keyBytes, err = base64.StdEncoding.DecodeString(signKey)
599+
keyBytes, err = base64.StdEncoding.DecodeString(resolvedSignKey)
527600
return keyBytes, false, err
528601
}
529602

jwt_test.go

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package caddyjwt
22

33
import (
4+
"context"
45
"crypto/ed25519"
56
"crypto/rand"
67
"crypto/rsa"
78
"encoding/base64"
89
"encoding/json"
10+
"github.com/caddyserver/caddy/v2"
911
"net/http"
1012
"net/http/httptest"
1113
"net/url"
@@ -153,9 +155,20 @@ func issueTokenStringEdDSA(claims MapClaims) string {
153155
return string(tokenBytes)
154156
}
155157

156-
func issueTokenStringJWK(claims MapClaims) string {
158+
func issueTokenStringJWK(claims MapClaims, options ...func(*jwt.SignOption)) string {
157159
token := buildToken(claims)
158-
tokenBytes, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, jwkKey))
160+
161+
// Options par défaut
162+
signOptions := []jwt.SignOption{jwt.WithKey(jwa.RS256, jwkKey)}
163+
164+
// Appliquer les options supplémentaires
165+
for _, opt := range options {
166+
var option jwt.SignOption
167+
opt(&option)
168+
signOptions = append(signOptions, option)
169+
}
170+
171+
tokenBytes, err := jwt.Sign(token, signOptions...)
159172
panicOnError(err)
160173

161174
return string(tokenBytes)
@@ -772,45 +785,83 @@ func TestJWK(t *testing.T) {
772785
time.Sleep(3 * time.Second)
773786
ja := &JWTAuth{JWKURL: TestJWKURL, logger: testLogger}
774787
assert.Nil(t, ja.Validate())
775-
assert.Equal(t, 1, ja.jwkCachedSet.Len())
776788

789+
// Le cache sera créé lors de la première authentification
777790
token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
778791
rw := httptest.NewRecorder()
779792
r, _ := http.NewRequest("GET", "/", nil)
793+
794+
repl := caddy.NewReplacer()
795+
ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
796+
r = r.WithContext(ctx)
797+
780798
r.Header.Add("Authorization", "Bearer "+token)
781799
gotUser, authenticated, err := ja.Authenticate(rw, r)
782800
assert.Nil(t, err)
783801
assert.True(t, authenticated)
784802
assert.Equal(t, User{ID: "ggicci"}, gotUser)
803+
804+
// Vérifier que le cache a bien été créé pour l'URL
805+
assert.NotNil(t, ja.jwkCaches)
806+
cachedEntry, exists := ja.jwkCaches[TestJWKURL]
807+
assert.True(t, exists, "Le cache devrait exister pour l'URL de test")
808+
assert.NotNil(t, cachedEntry)
809+
assert.Equal(t, 1, cachedEntry.cachedSet.Len())
785810
}
786811

787812
func TestJWKSet(t *testing.T) {
788813
time.Sleep(3 * time.Second)
789814
ja := &JWTAuth{JWKURL: TestJWKSetURL, logger: testLogger}
790815
assert.Nil(t, ja.Validate())
791-
assert.Equal(t, 2, ja.jwkCachedSet.Len())
792816

817+
// Authentifier pour créer le cache
793818
token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
794819
rw := httptest.NewRecorder()
795820
r, _ := http.NewRequest("GET", "/", nil)
821+
822+
repl := caddy.NewReplacer()
823+
ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
824+
r = r.WithContext(ctx)
825+
796826
r.Header.Add("Authorization", "Bearer "+token)
797827
gotUser, authenticated, err := ja.Authenticate(rw, r)
798828
assert.Nil(t, err)
799829
assert.True(t, authenticated)
800830
assert.Equal(t, User{ID: "ggicci"}, gotUser)
831+
832+
// Vérifier que le cache a été créé correctement
833+
assert.NotNil(t, ja.jwkCaches)
834+
cachedEntry, exists := ja.jwkCaches[TestJWKSetURL]
835+
assert.True(t, exists, "Le cache devrait exister pour l'URL du set JWK")
836+
assert.NotNil(t, cachedEntry)
837+
assert.Equal(t, 2, cachedEntry.cachedSet.Len())
801838
}
802839

803840
func TestJWKSet_KeyNotFound(t *testing.T) {
804841
time.Sleep(3 * time.Second)
805842
ja := &JWTAuth{JWKURL: TestJWKSetURLInapplicable, logger: testLogger}
806843
assert.Nil(t, ja.Validate())
807-
assert.Equal(t, 2, ja.jwkCachedSet.Len())
808844

845+
// Première requête pour créer le cache
809846
token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
810847
rw := httptest.NewRecorder()
811848
r, _ := http.NewRequest("GET", "/", nil)
849+
850+
repl := caddy.NewReplacer()
851+
ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
852+
r = r.WithContext(ctx)
853+
812854
r.Header.Add("Authorization", "Bearer "+token)
813855
gotUser, authenticated, err := ja.Authenticate(rw, r)
856+
857+
// Vérifier que le cache a été créé correctement
858+
assert.NotNil(t, ja.jwkCaches)
859+
cachedEntry, exists := ja.jwkCaches[TestJWKSetURLInapplicable]
860+
assert.True(t, exists, "Le cache devrait exister pour l'URL inapplicable")
861+
assert.NotNil(t, cachedEntry)
862+
assert.Equal(t, 2, cachedEntry.cachedSet.Len())
863+
864+
// Vérifier que l'authentification a échoué car la clé n'est pas trouvée
814865
assert.Error(t, err)
815866
assert.False(t, authenticated)
816867
assert.Empty(t, gotUser.ID)

0 commit comments

Comments
 (0)