Skip to content

Commit 82bd8aa

Browse files
authored
Merge pull request #188 from blag/support-loading-jwks
Support loading JWKs directly
2 parents eb7c5fe + 60fa95d commit 82bd8aa

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

jose/jws.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections import Mapping, Iterable # Python 2, will be deprecated in Python 3.8
1010

1111
from jose import jwk
12+
from jose.backends.base import Key
1213
from jose.constants import ALGORITHMS
1314
from jose.exceptions import JWSError
1415
from jose.exceptions import JWSSignatureError
@@ -163,10 +164,11 @@ def _encode_payload(payload):
163164
return base64url_encode(payload)
164165

165166

166-
def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key_data):
167+
def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key):
167168
signing_input = b'.'.join([encoded_header, encoded_claims])
168169
try:
169-
key = jwk.construct(key_data, algorithm)
170+
if not isinstance(key, Key):
171+
key = jwk.construct(key, algorithm)
170172
signature = key.sign(signing_input)
171173
except Exception as e:
172174
raise JWSError(e)
@@ -213,7 +215,8 @@ def _load(jwt):
213215

214216
def _sig_matches_keys(keys, signing_input, signature, alg):
215217
for key in keys:
216-
key = jwk.construct(key, alg)
218+
if not isinstance(key, Key):
219+
key = jwk.construct(key, alg)
217220
try:
218221
if key.verify(signing_input, signature):
219222
return True
@@ -224,6 +227,9 @@ def _sig_matches_keys(keys, signing_input, signature, alg):
224227

225228
def _get_keys(key):
226229

230+
if isinstance(key, Key):
231+
return (key,)
232+
227233
try:
228234
key = json.loads(key, parse_int=str, parse_float=str)
229235
except Exception:

tests/test_jws.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ def test_round_trip_with_different_key_types(self, key):
8989
assert verified_data['testkey'] == 'testvalue'
9090

9191

92+
class TestJWK(object):
93+
def test_jwk(self, payload):
94+
key_data = 'key'
95+
key = jwk.construct(key_data, algorithm='HS256')
96+
token = jws.sign(payload, key, algorithm=ALGORITHMS.HS256)
97+
assert jws.verify(token, key_data, ALGORITHMS.HS256) == payload
98+
99+
92100
class TestHMAC(object):
93101

94102
def testHMAC256(self, payload):
@@ -272,6 +280,10 @@ def test_tuple(self):
272280
def test_list(self):
273281
assert ['test', 'key'] == jws._get_keys(['test', 'key'])
274282

283+
def test_jwk(self):
284+
jwkey = jwk.construct('key', algorithm='HS256')
285+
assert (jwkey,) == jws._get_keys(jwkey)
286+
275287

276288
class TestRSA(object):
277289

0 commit comments

Comments
 (0)