Skip to content

Commit dc617fe

Browse files
committed
If it's an EC key you ant make sure you get one with the correct curve.
1 parent 566921d commit dc617fe

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

src/cryptojwt/jwt.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,29 @@ def pick_key(keys, use, alg='', key_type='', kid=''):
4949
key_type = jws_alg2keytype(alg)
5050
else:
5151
key_type = jwe_alg2keytype(alg)
52+
5253
for key in keys:
5354
if key.use and key.use != use:
5455
continue
5556

5657
if key.kty == key_type:
57-
if key.alg == '' or alg == '' or key.alg == alg:
58-
if key.kid == '' or kid == '' or key.kid == kid:
58+
if key.kid and kid:
59+
if key.kid == kid:
5960
res.append(key)
61+
else:
62+
continue
63+
64+
if key.alg == '':
65+
if alg:
66+
if key_type == 'EC':
67+
if key.crv == 'P-{}'.format(alg[2:]):
68+
res.append(key)
69+
continue
70+
res.append(key)
71+
elif alg and key.alg == alg:
72+
res.append(key)
73+
else:
74+
res.append(key)
6075
return res
6176

6277

tests/test_09_jwt.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import os
22

3-
from cryptojwt.jwt import JWT
3+
from cryptojwt.jwt import JWT, pick_key
44
from cryptojwt.key_bundle import KeyBundle
5-
from cryptojwt.key_jar import KeyJar
5+
from cryptojwt.key_jar import KeyJar, init_key_jar
66

77
__author__ = 'Roland Hedberg'
88

@@ -164,3 +164,33 @@ def test_msg_cls():
164164
bob.msg_cls = DummyMsg
165165
info = bob.unpack(_jwt)
166166
assert isinstance(info, DummyMsg)
167+
168+
169+
KEY_DEFS = [
170+
{"type": "RSA", "use": ["sig"]},
171+
{"type": "RSA", "use": ["enc"]},
172+
{"type": "EC", "crv": "P-256", "use": ["sig"]},
173+
{"type": "EC", "crv": "P-384", "use": ["sig"]}
174+
]
175+
176+
kj = init_key_jar(key_defs=KEY_DEFS)
177+
178+
179+
def test_pick_key():
180+
keys = kj.get_issuer_keys('')
181+
182+
_k = pick_key(keys, 'sig', 'RS256')
183+
assert len(_k) == 1
184+
185+
_k = pick_key(keys, 'sig', 'ES256')
186+
assert len(_k) == 1
187+
188+
_k = pick_key(keys, 'sig', 'ES384')
189+
assert len(_k) == 1
190+
191+
_k = pick_key(keys, 'enc', "RSA-OAEP-256")
192+
assert len(_k) == 1
193+
194+
_k = pick_key(keys, 'enc', "ECDH-ES")
195+
assert len(_k) == 0
196+

0 commit comments

Comments
 (0)