Skip to content

Commit a0761d8

Browse files
authored
Merge pull request #96 from Cafeine42/main
Support for Placeholder Integration on JWKURL
2 parents a248851 + 3ce36e8 commit a0761d8

File tree

2 files changed

+165
-29
lines changed

2 files changed

+165
-29
lines changed

jwt.go

Lines changed: 109 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"net/http"
1313
"strconv"
1414
"strings"
15+
"sync"
1516
"time"
1617

1718
"github.com/caddyserver/caddy/v2"
@@ -30,6 +31,23 @@ func init() {
3031
type User = caddyauth.User
3132
type Token = jwt.Token
3233

34+
// jwkCacheEntry stores the JWK cache information for a specific URL
35+
type jwkCacheEntry struct {
36+
URL string
37+
Cache *jwk.Cache
38+
CachedSet jwk.Set
39+
}
40+
41+
// refresh refreshes the JWK cache for this entry
42+
func (entry *jwkCacheEntry) refresh(ctx context.Context, logger *zap.Logger) error {
43+
_, err := entry.Cache.Refresh(ctx, entry.URL)
44+
if err != nil {
45+
logger.Warn("failed to refresh JWK cache", zap.Error(err), zap.String("url", entry.URL))
46+
return err
47+
}
48+
return nil
49+
}
50+
3351
// JWTAuth facilitates JWT (JSON Web Token) authentication.
3452
type JWTAuth struct {
3553
// SignKey is the key used by the signing algorithm to verify the signature.
@@ -150,8 +168,10 @@ type JWTAuth struct {
150168
logger *zap.Logger
151169
parsedSignKey interface{} // can be []byte, *rsa.PublicKey, *ecdsa.PublicKey, etc.
152170

153-
jwkCache *jwk.Cache
154-
jwkCachedSet jwk.Set
171+
// JWK cache by resolved URL to support placeholders in JWKURL
172+
jwkCaches map[string]*jwkCacheEntry
173+
174+
mutex sync.RWMutex
155175
}
156176

157177
// CaddyModule implements caddy.Module interface.
@@ -179,18 +199,64 @@ func (ja *JWTAuth) usingJWK() bool {
179199
}
180200

181201
func (ja *JWTAuth) setupJWKLoader() {
182-
cache := jwk.NewCache(context.Background(), jwk.WithErrSink(ja))
183-
cache.Register(ja.JWKURL)
184-
ja.jwkCache = cache
185-
ja.refreshJWKCache()
186-
ja.jwkCachedSet = jwk.NewCachedSet(cache, ja.JWKURL)
187-
ja.logger.Info("using JWKs from URL", zap.String("url", ja.JWKURL), zap.Int("loaded_keys", ja.jwkCachedSet.Len()))
202+
// Initialize cache for all possible URLs
203+
ja.mutex.Lock()
204+
ja.jwkCaches = make(map[string]*jwkCacheEntry)
205+
ja.mutex.Unlock()
206+
ja.logger.Info("JWK cache initialized for JWK URL", zap.String("jwk_url", ja.JWKURL))
188207
}
189208

190-
// refreshJWKCache refreshes the JWK cache. It validates the JWKs from the given URL.
191-
func (ja *JWTAuth) refreshJWKCache() {
192-
_, err := ja.jwkCache.Refresh(context.Background(), ja.JWKURL)
193-
ja.logger.Warn("failed to refresh JWK cache", zap.Error(err))
209+
// getOrCreateJWKCache retrieves or creates a cache for a specific JWK URL
210+
func (ja *JWTAuth) getOrCreateJWKCache(resolvedURL string) (*jwkCacheEntry, error) {
211+
// If the URL is empty, return an error
212+
if resolvedURL == "" {
213+
return nil, fmt.Errorf("resolved JWK URL is empty")
214+
}
215+
216+
// First, check if cache already exists for this URL using a read lock
217+
ja.mutex.RLock()
218+
entry, ok := ja.jwkCaches[resolvedURL]
219+
ja.mutex.RUnlock()
220+
221+
if ok {
222+
return entry, nil
223+
}
224+
225+
// If not found, acquire a write lock to create a new entry
226+
ja.mutex.Lock()
227+
defer ja.mutex.Unlock()
228+
229+
// Double-check if another goroutine created the cache while we were waiting for the lock
230+
if entry, ok := ja.jwkCaches[resolvedURL]; ok {
231+
return entry, nil
232+
}
233+
234+
// Create a new cache for this URL
235+
cache := jwk.NewCache(context.Background(), jwk.WithErrSink(ja))
236+
err := cache.Register(resolvedURL)
237+
if err != nil {
238+
return nil, fmt.Errorf("failed to register JWK URL: %w", err)
239+
}
240+
241+
// Create cache entry before attempting refresh
242+
entry = &jwkCacheEntry{
243+
URL: resolvedURL,
244+
Cache: cache,
245+
CachedSet: jwk.NewCachedSet(cache, resolvedURL),
246+
}
247+
248+
// Try to refresh the cache immediately
249+
err = entry.refresh(context.Background(), ja.logger)
250+
if err != nil {
251+
ja.logger.Warn("failed to refresh JWK cache during initialization", zap.Error(err), zap.String("url", resolvedURL))
252+
// Continue even in case of error, the important thing is that the URL is registered
253+
}
254+
255+
// Register the entry in the cache
256+
ja.jwkCaches[resolvedURL] = entry
257+
ja.logger.Info("new JWK cache created", zap.String("url", resolvedURL), zap.Int("loaded_keys", entry.CachedSet.Len()))
258+
259+
return entry, nil
194260
}
195261

196262
// Validate implements caddy.Validator interface.
@@ -216,6 +282,8 @@ func (ja *JWTAuth) Validate() error {
216282

217283
func (ja *JWTAuth) validateSignatureKeys() error {
218284
if ja.usingJWK() {
285+
// Initialize the cache structure without attempting to resolve placeholders
286+
// URLs will be resolved during usage
219287
ja.setupJWKLoader()
220288
} else {
221289
if keyBytes, asymmetric, err := parseSignKey(ja.SignKey); err != nil {
@@ -241,19 +309,36 @@ func (ja *JWTAuth) validateSignatureKeys() error {
241309
return nil
242310
}
243311

244-
func (ja *JWTAuth) keyProvider() jws.KeyProviderFunc {
245-
return func(_ context.Context, sink jws.KeySink, sig *jws.Signature, _ *jws.Message) error {
312+
// resolveJWKURL prend une requête HTTP et résout l'URL JWK avec des espaces réservés
313+
func (ja *JWTAuth) resolveJWKURL(request *http.Request) string {
314+
replacer := request.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
315+
return replacer.ReplaceAll(ja.JWKURL, "")
316+
}
317+
318+
func (ja *JWTAuth) keyProvider(request *http.Request) jws.KeyProviderFunc {
319+
return func(curContext context.Context, sink jws.KeySink, sig *jws.Signature, _ *jws.Message) error {
246320
if ja.usingJWK() {
321+
resolvedURL := ja.resolveJWKURL(request)
322+
323+
ja.logger.Info("JWK URL", zap.String("unresolved", ja.JWKURL), zap.String("resolved", resolvedURL))
324+
325+
// Get or create the cache for this URL
326+
cacheEntry, err := ja.getOrCreateJWKCache(resolvedURL)
327+
if err != nil {
328+
return fmt.Errorf("failed to get JWK cache: %w", err)
329+
}
330+
331+
// Use the key set associated with this URL
247332
kid := sig.ProtectedHeaders().KeyID()
248-
key, found := ja.jwkCachedSet.LookupKeyID(kid)
333+
key, found := cacheEntry.CachedSet.LookupKeyID(kid)
249334
if !found {
250-
// trigger a refresh if the key is not found
251-
go ja.refreshJWKCache()
335+
// Trigger an asynchronous refresh if the key is not found
336+
go cacheEntry.refresh(context.Background(), ja.logger)
252337

253338
if kid == "" {
254339
return fmt.Errorf("missing kid in JWT header")
255340
}
256-
return fmt.Errorf("key specified by kid %q not found in JWKs", kid)
341+
return fmt.Errorf("key specified by kid %q not found in JWKs from %s", kid, resolvedURL)
257342
}
258343
sink.Key(ja.determineSigningAlgorithm(key.Algorithm(), sig.ProtectedHeaders().Algorithm()), key)
259344
} else if ja.SignAlgorithm == string(jwa.EdDSA) {
@@ -304,7 +389,7 @@ func (ja *JWTAuth) Authenticate(rw http.ResponseWriter, r *http.Request) (User,
304389
jwt.WithVerify(!ja.SkipVerification),
305390
}
306391
if !ja.SkipVerification {
307-
jwtOptions = append(jwtOptions, jwt.WithKeyProvider(ja.keyProvider()))
392+
jwtOptions = append(jwtOptions, jwt.WithKeyProvider(ja.keyProvider(r)))
308393
}
309394
gotToken, err = jwt.ParseString(tokenString, jwtOptions...)
310395

@@ -515,15 +600,15 @@ func desensitizedTokenString(token string) string {
515600
func parseSignKey(signKey string) (keyBytes []byte, asymmetric bool, err error) {
516601
repl := caddy.NewReplacer()
517602
// Replace placeholders in the signKey such as {file./path/to/sign_key.txt}
518-
signKey = repl.ReplaceAll(signKey, "")
519-
if len(signKey) == 0 {
603+
resolvedSignKey := repl.ReplaceAll(signKey, "")
604+
if len(resolvedSignKey) == 0 {
520605
return nil, false, ErrMissingKeys
521606
}
522-
if strings.Contains(signKey, "-----BEGIN PUBLIC KEY-----") {
523-
keyBytes, err = parsePEMFormattedPublicKey(signKey)
607+
if strings.Contains(resolvedSignKey, "-----BEGIN PUBLIC KEY-----") {
608+
keyBytes, err = parsePEMFormattedPublicKey(resolvedSignKey)
524609
return keyBytes, true, err
525610
}
526-
keyBytes, err = base64.StdEncoding.DecodeString(signKey)
611+
keyBytes, err = base64.StdEncoding.DecodeString(resolvedSignKey)
527612
return keyBytes, false, err
528613
}
529614

jwt_test.go

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package caddyjwt
22

33
import (
4+
"context"
45
"crypto/ed25519"
56
"crypto/rand"
67
"crypto/rsa"
78
"encoding/base64"
89
"encoding/json"
10+
"github.com/caddyserver/caddy/v2"
911
"net/http"
1012
"net/http/httptest"
1113
"net/url"
@@ -153,9 +155,20 @@ func issueTokenStringEdDSA(claims MapClaims) string {
153155
return string(tokenBytes)
154156
}
155157

156-
func issueTokenStringJWK(claims MapClaims) string {
158+
func issueTokenStringJWK(claims MapClaims, options ...func(*jwt.SignOption)) string {
157159
token := buildToken(claims)
158-
tokenBytes, err := jwt.Sign(token, jwt.WithKey(jwa.RS256, jwkKey))
160+
161+
// Options par défaut
162+
signOptions := []jwt.SignOption{jwt.WithKey(jwa.RS256, jwkKey)}
163+
164+
// Appliquer les options supplémentaires
165+
for _, opt := range options {
166+
var option jwt.SignOption
167+
opt(&option)
168+
signOptions = append(signOptions, option)
169+
}
170+
171+
tokenBytes, err := jwt.Sign(token, signOptions...)
159172
panicOnError(err)
160173

161174
return string(tokenBytes)
@@ -772,45 +785,83 @@ func TestJWK(t *testing.T) {
772785
time.Sleep(3 * time.Second)
773786
ja := &JWTAuth{JWKURL: TestJWKURL, logger: testLogger}
774787
assert.Nil(t, ja.Validate())
775-
assert.Equal(t, 1, ja.jwkCachedSet.Len())
776788

789+
// Le cache sera créé lors de la première authentification
777790
token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
778791
rw := httptest.NewRecorder()
779792
r, _ := http.NewRequest("GET", "/", nil)
793+
794+
repl := caddy.NewReplacer()
795+
ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
796+
r = r.WithContext(ctx)
797+
780798
r.Header.Add("Authorization", "Bearer "+token)
781799
gotUser, authenticated, err := ja.Authenticate(rw, r)
782800
assert.Nil(t, err)
783801
assert.True(t, authenticated)
784802
assert.Equal(t, User{ID: "ggicci"}, gotUser)
803+
804+
// Vérifier que le cache a bien été créé pour l'URL
805+
assert.NotNil(t, ja.jwkCaches)
806+
cachedEntry, exists := ja.jwkCaches[TestJWKURL]
807+
assert.True(t, exists, "Le cache devrait exister pour l'URL de test")
808+
assert.NotNil(t, cachedEntry)
809+
assert.Equal(t, 1, cachedEntry.CachedSet.Len())
785810
}
786811

787812
func TestJWKSet(t *testing.T) {
788813
time.Sleep(3 * time.Second)
789814
ja := &JWTAuth{JWKURL: TestJWKSetURL, logger: testLogger}
790815
assert.Nil(t, ja.Validate())
791-
assert.Equal(t, 2, ja.jwkCachedSet.Len())
792816

817+
// Authentifier pour créer le cache
793818
token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
794819
rw := httptest.NewRecorder()
795820
r, _ := http.NewRequest("GET", "/", nil)
821+
822+
repl := caddy.NewReplacer()
823+
ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
824+
r = r.WithContext(ctx)
825+
796826
r.Header.Add("Authorization", "Bearer "+token)
797827
gotUser, authenticated, err := ja.Authenticate(rw, r)
798828
assert.Nil(t, err)
799829
assert.True(t, authenticated)
800830
assert.Equal(t, User{ID: "ggicci"}, gotUser)
831+
832+
// Vérifier que le cache a été créé correctement
833+
assert.NotNil(t, ja.jwkCaches)
834+
cachedEntry, exists := ja.jwkCaches[TestJWKSetURL]
835+
assert.True(t, exists, "Le cache devrait exister pour l'URL du set JWK")
836+
assert.NotNil(t, cachedEntry)
837+
assert.Equal(t, 2, cachedEntry.CachedSet.Len())
801838
}
802839

803840
func TestJWKSet_KeyNotFound(t *testing.T) {
804841
time.Sleep(3 * time.Second)
805842
ja := &JWTAuth{JWKURL: TestJWKSetURLInapplicable, logger: testLogger}
806843
assert.Nil(t, ja.Validate())
807-
assert.Equal(t, 2, ja.jwkCachedSet.Len())
808844

845+
// Première requête pour créer le cache
809846
token := issueTokenStringJWK(MapClaims{"sub": "ggicci"})
810847
rw := httptest.NewRecorder()
811848
r, _ := http.NewRequest("GET", "/", nil)
849+
850+
repl := caddy.NewReplacer()
851+
ctx := context.WithValue(r.Context(), caddy.ReplacerCtxKey, repl)
852+
r = r.WithContext(ctx)
853+
812854
r.Header.Add("Authorization", "Bearer "+token)
813855
gotUser, authenticated, err := ja.Authenticate(rw, r)
856+
857+
// Vérifier que le cache a été créé correctement
858+
assert.NotNil(t, ja.jwkCaches)
859+
cachedEntry, exists := ja.jwkCaches[TestJWKSetURLInapplicable]
860+
assert.True(t, exists, "Le cache devrait exister pour l'URL inapplicable")
861+
assert.NotNil(t, cachedEntry)
862+
assert.Equal(t, 2, cachedEntry.CachedSet.Len())
863+
864+
// Vérifier que l'authentification a échoué car la clé n'est pas trouvée
814865
assert.Error(t, err)
815866
assert.False(t, authenticated)
816867
assert.Empty(t, gotUser.ID)

0 commit comments

Comments
 (0)