Skip to content

Commit a06db56

Browse files
authored
Return all keys when JWT has no key ID header (#140)
1 parent 83c760a commit a06db56

File tree

2 files changed

+145
-5
lines changed

2 files changed

+145
-5
lines changed

keyfunc.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type Keyfunc interface {
2525
Keyfunc(token *jwt.Token) (any, error)
2626
KeyfuncCtx(ctx context.Context) jwt.Keyfunc
2727
Storage() jwkset.Storage
28+
VerificationKeySet(ctx context.Context) (jwt.VerificationKeySet, error)
2829
}
2930

3031
// Options are used to create a new Keyfunc.
@@ -198,7 +199,7 @@ func (k keyfunc) KeyfuncCtx(ctx context.Context) jwt.Keyfunc {
198199
return func(token *jwt.Token) (any, error) {
199200
kidInter, ok := token.Header[jwkset.HeaderKID]
200201
if !ok {
201-
return nil, fmt.Errorf("%w: could not find kid in JWT header", ErrKeyfunc)
202+
return k.VerificationKeySet(ctx)
202203
}
203204
kid, ok := kidInter.(string)
204205
if !ok {
@@ -236,10 +237,6 @@ func (k keyfunc) KeyfuncCtx(ctx context.Context) jwt.Keyfunc {
236237
}
237238
}
238239

239-
type publicKeyer interface {
240-
Public() crypto.PublicKey
241-
}
242-
243240
key := jwk.Key()
244241
pk, ok := key.(publicKeyer)
245242
if ok {
@@ -256,3 +253,23 @@ func (k keyfunc) Keyfunc(token *jwt.Token) (any, error) {
256253
func (k keyfunc) Storage() jwkset.Storage {
257254
return k.storage
258255
}
256+
func (k keyfunc) VerificationKeySet(ctx context.Context) (jwt.VerificationKeySet, error) {
257+
jwk, err := k.storage.KeyReadAll(ctx)
258+
if err != nil {
259+
return jwt.VerificationKeySet{}, fmt.Errorf("failed to read all JWK from storage: %w", errors.Join(err, ErrKeyfunc))
260+
}
261+
var allKeys jwt.VerificationKeySet
262+
for _, j := range jwk {
263+
key := j.Key()
264+
pk, ok := key.(publicKeyer)
265+
if ok {
266+
key = pk.Public()
267+
}
268+
allKeys.Keys = append(allKeys.Keys, key)
269+
}
270+
return allKeys, nil
271+
}
272+
273+
type publicKeyer interface {
274+
Public() crypto.PublicKey
275+
}

keyfunc_test.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,126 @@ func TestNewJWKSetJSON(t *testing.T) {
175175
t.Fatalf("The token is not valid.")
176176
}
177177
}
178+
179+
func TestVerificationKeySet(t *testing.T) {
180+
ctx := context.Background()
181+
_, priv, err := ed25519.GenerateKey(rand.Reader)
182+
if err != nil {
183+
t.Fatalf("Failed to generate ED25519 key pair: %v", err)
184+
}
185+
jwk, err := jwkset.NewJWKFromKey(priv, jwkset.JWKOptions{})
186+
if err != nil {
187+
t.Fatalf("Failed to create JWK: %v", err)
188+
}
189+
store := jwkset.NewMemoryStorage()
190+
err = store.KeyWrite(ctx, jwk)
191+
if err != nil {
192+
t.Fatalf("Failed to write JWK: %v", err)
193+
}
194+
k, err := New(Options{Ctx: ctx, Storage: store})
195+
if err != nil {
196+
t.Fatalf("Failed to create Keyfunc: %v", err)
197+
}
198+
vks, err := k.VerificationKeySet(ctx)
199+
if err != nil {
200+
t.Fatalf("VerificationKeySet failed: %v", err)
201+
}
202+
if len(vks.Keys) != 1 {
203+
t.Fatalf("Expected 1 key, got %d", len(vks.Keys))
204+
}
205+
}
206+
207+
func TestNoKIDHeaderCallsVerificationKeySet(t *testing.T) {
208+
ctx := context.Background()
209+
210+
// Generate two key pairs.
211+
_, priv1, err := ed25519.GenerateKey(rand.Reader)
212+
if err != nil {
213+
t.Fatalf("Failed to generate ED25519 key pair 1: %v", err)
214+
}
215+
_, priv2, err := ed25519.GenerateKey(rand.Reader)
216+
if err != nil {
217+
t.Fatalf("Failed to generate ED25519 key pair 2: %v", err)
218+
}
219+
220+
jwk1, err := jwkset.NewJWKFromKey(priv1, jwkset.JWKOptions{})
221+
if err != nil {
222+
t.Fatalf("Failed to create JWK 1: %v", err)
223+
}
224+
jwk2, err := jwkset.NewJWKFromKey(priv2, jwkset.JWKOptions{})
225+
if err != nil {
226+
t.Fatalf("Failed to create JWK 2: %v", err)
227+
}
228+
229+
orders := [][]jwkset.JWK{
230+
{jwk1, jwk2},
231+
{jwk2, jwk1},
232+
}
233+
privs := []ed25519.PrivateKey{priv1, priv2}
234+
235+
for i, order := range orders {
236+
store := jwkset.NewMemoryStorage()
237+
for _, jwk := range order {
238+
err = store.KeyWrite(ctx, jwk)
239+
if err != nil {
240+
t.Fatalf("Failed to write JWK: %v", err)
241+
}
242+
}
243+
k, err := New(Options{Ctx: ctx, Storage: store})
244+
if err != nil {
245+
t.Fatalf("Failed to create Keyfunc: %v", err)
246+
}
247+
// Sign a token with the corresponding private key (no KID header)
248+
token := jwt.New(jwt.SigningMethodEdDSA)
249+
tokenString, err := token.SignedString(privs[i])
250+
if err != nil {
251+
t.Fatalf("Failed to sign token: %v", err)
252+
}
253+
parsedToken, err := jwt.Parse(tokenString, k.KeyfuncCtx(ctx))
254+
if err != nil {
255+
t.Fatalf("Parse failed (order %d): %v", i+1, err)
256+
}
257+
if !parsedToken.Valid {
258+
t.Fatalf("Expected token to be valid (order %d)", i+1)
259+
}
260+
}
261+
}
262+
263+
func TestNoKIDHeaderNoMatchingJWK(t *testing.T) {
264+
ctx := context.Background()
265+
266+
_, missingFromSet, err := ed25519.GenerateKey(rand.Reader)
267+
if err != nil {
268+
t.Fatalf("Failed to generate ED25519 key pair: %v", err)
269+
}
270+
271+
_, presentInSet, err := ed25519.GenerateKey(rand.Reader)
272+
if err != nil {
273+
t.Fatalf("Failed to generate other ED25519 key pair: %v", err)
274+
}
275+
jwk, err := jwkset.NewJWKFromKey(presentInSet, jwkset.JWKOptions{})
276+
if err != nil {
277+
t.Fatalf("Failed to create JWK: %v", err)
278+
}
279+
store := jwkset.NewMemoryStorage()
280+
err = store.KeyWrite(ctx, jwk)
281+
if err != nil {
282+
t.Fatalf("Failed to write JWK: %v", err)
283+
}
284+
285+
k, err := New(Options{Ctx: ctx, Storage: store})
286+
if err != nil {
287+
t.Fatalf("Failed to create Keyfunc: %v", err)
288+
}
289+
290+
token := jwt.New(jwt.SigningMethodEdDSA)
291+
tokenString, err := token.SignedString(missingFromSet)
292+
if err != nil {
293+
t.Fatalf("Failed to sign token: %v", err)
294+
}
295+
296+
_, err = jwt.Parse(tokenString, k.KeyfuncCtx(ctx))
297+
if err == nil {
298+
t.Fatalf("Expected error due to no matching JWK, but got none")
299+
}
300+
}

0 commit comments

Comments
 (0)