Skip to content

Commit 3a80531

Browse files
committed
Using teh jwk claim in a JWT for signing / verifying signature.
1 parent ec508bc commit 3a80531

File tree

7 files changed

+148
-32
lines changed

7 files changed

+148
-32
lines changed

src/cryptojwt/jwe/jwe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22

3+
from cryptojwt.jwk.jwk import key_from_jwk_dict
34
from ..jwk.asym import AsymmetricKey
45
from ..jwx import JWx
56

@@ -147,6 +148,11 @@ def decrypt(self, token=None, keys=None, alg=None, cek=None):
147148
else:
148149
keys = self.pick_keys(self._get_keys(), use="enc", alg=_alg)
149150

151+
try:
152+
keys.append(key_from_jwk_dict(_jwe.headers['jwk']))
153+
except KeyError:
154+
pass
155+
150156
if not keys and not cek:
151157
raise NoSuitableDecryptionKey(_alg)
152158

src/cryptojwt/jwe/jwe_ec.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,18 @@ def dec_setup(self, token, key=None, **kwargs):
177177

178178
return self.cek
179179

180-
def encrypt(self, iv="", cek="", **kwargs):
180+
def encrypt(self, key=None, iv="", cek="", **kwargs):
181+
"""
181182
183+
:param key: *Not used>, only there to present the same API as
184+
JWE_RSA and JWE_SYM
185+
:param iv:
186+
:param cek:
187+
:param kwargs:
188+
:return:
189+
"""
182190
_msg = as_bytes(self.msg)
191+
183192
_args = self._dict
184193
try:
185194
_args["kid"] = kwargs["kid"]

src/cryptojwt/jwk/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ def __init__(self, kty="", alg="", use="", kid="", x5c=None,
4444
if not isinstance(alg, str):
4545
alg = as_unicode(alg)
4646

47-
if alg not in [
48-
"HS256", "HS384", "HS512", "RS256", "RS384", "RS512", "ES256",
49-
"ES384","ES512", "PS256", "PS384", "PS512"
50-
]:
47+
# The list comes from https://tools.ietf.org/html/rfc7518#page-6
48+
# Should map against SIGNER_ALGS in cryptojwt.jws.jws
49+
if alg not in ["HS256", "HS384", "HS512", "RS256", "RS384",
50+
"RS512", "ES256", "ES384","ES512", "PS256",
51+
"PS384", "PS512", "none"]:
5152
raise UnsupportedAlgorithm("Unknown algorithm: {}".format(alg))
5253

5354
self.alg = alg
@@ -148,7 +149,7 @@ def verify(self):
148149

149150
if self.kid:
150151
if not isinstance(self.kid, str):
151-
raise HeaderError("kid of wrong value type")
152+
raise ValueError("kid of wrong value type")
152153
return True
153154

154155
def __eq__(self, other):

src/cryptojwt/jwx.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import logging
33

44
import requests
5+
6+
from cryptojwt.jwk import JWK
57
from cryptojwt.key_bundle import KeyBundle
68

79
from .jwk.jwk import key_from_jwk_dict
@@ -67,28 +69,36 @@ def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs):
6769
continue
6870

6971
if key == "jwk":
72+
# value MUST be a string
7073
if isinstance(_val, dict):
71-
self._dict["jwk"] = key_from_jwk_dict(_val)
74+
_k = key_from_jwk_dict(_val)
75+
self._dict["jwk"] = _val
7276
elif isinstance(_val, str):
73-
self._dict["jwk"] = key_from_jwk_dict(json.loads(_val))
77+
# verify that it's a real JWK
78+
_val = json.loads(_val)
79+
_j = key_from_jwk_dict(_val)
80+
self._dict["jwk"] = _val
81+
elif isinstance(_val, JWK):
82+
self._dict['jwk'] = _val.to_dict()
7483
else:
7584
raise ValueError(
76-
'JWK must be a string or a JSON object')
85+
'JWK must be a string a JSON object or a JWK '
86+
'instance')
7787
self._jwk = self._dict['jwk']
7888
elif key == "x5c":
7989
self._dict["x5c"] = _val
8090
_pub_key = import_rsa_key(_val)
81-
self._jwk = RSAKey(_pub_key)
91+
self._jwk = RSAKey(pub_key=_pub_key).to_dict()
8292
elif key == "jku":
8393
self._jwks = KeyBundle(source=_val, httpc=self.httpc)
8494
self._dict['jku'] = _val
8595
elif "x5u" in self:
8696
try:
8797
_spec = load_x509_cert(self["x5u"], self.httpc, {})
88-
self._jwk = RSAKey(pub_key=_spec['rsa'])
98+
self._jwk = RSAKey(pub_key=_spec['rsa']).to_dict()
8999
except Exception:
90100
# ca_chain = load_x509_cert_chain(self["x5u"])
91-
pass
101+
raise ValueError('x5u')
92102
else:
93103
self._dict[key] = _val
94104

@@ -110,12 +120,11 @@ def __getattr__(self, item):
110120
def keys(self):
111121
return list(self._dict.keys())
112122

113-
def headers(self, extra=None):
114-
_extra = extra or {}
123+
def headers(self, **kwargs):
115124
_header = self._header.copy()
116125
for param in self.args:
117126
try:
118-
_header[param] = _extra[param]
127+
_header[param] = kwargs[param]
119128
except KeyError:
120129
try:
121130
if self._dict[param]:
@@ -124,9 +133,27 @@ def headers(self, extra=None):
124133
pass
125134

126135
if "jwk" in self:
127-
_header["jwk"] = self["jwk"].serialize()
128-
elif "jwk" in _extra:
129-
_header["jwk"] = extra["jwk"].serialize()
136+
_header["jwk"] = self["jwk"]
137+
else:
138+
try:
139+
_jwk = kwargs['jwk']
140+
except KeyError:
141+
pass
142+
else:
143+
try:
144+
_header["jwk"] = _jwk.serialize() # JWK instance
145+
except AttributeError:
146+
if isinstance(_jwk, dict):
147+
_header['jwk'] = _jwk # dictionary
148+
else:
149+
try:
150+
_d = json.loads(_jwk) # JSON
151+
# Verify that it's a valid JWK
152+
_k = key_from_jwk_dict(_d)
153+
except Exception:
154+
raise
155+
else:
156+
_header['jwk'] = _d
130157

131158
if "kid" in self:
132159
if not isinstance(self["kid"], str):
@@ -135,12 +162,9 @@ def headers(self, extra=None):
135162
return _header
136163

137164
def _get_keys(self):
138-
logger.debug("_get_keys(): self._dict.keys={0}".format(
139-
self._dict.keys()))
140-
141165
_keys = []
142166
if self._jwk:
143-
_keys.append(self._jwk)
167+
_keys.append(key_from_jwk_dict(self._jwk))
144168
if self._jwks is not None:
145169
_keys.extend(self._jwks.keys())
146170
return _keys
@@ -153,8 +177,8 @@ def pick_keys(self, keys, use="", alg=""):
153177
The assumption is that upper layer has made certain you only get
154178
keys you can use.
155179
156-
:param alg:
157-
:param use:
180+
:param alg: The crypto algorithm
181+
:param use: What the key should be used for
158182
:param keys: A list of JWK instances
159183
:return: A list of JWK instances that fulfill the requirements
160184
"""

tests/test_02_jwk.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515

1616
import os.path
1717

18-
from cryptojwt.exception import DeSerializationNotPossible
18+
from cryptojwt.exception import DeSerializationNotPossible, UnsupportedAlgorithm
1919
from cryptojwt.exception import WrongUsage
2020

21-
from cryptojwt.utils import as_unicode
21+
from cryptojwt.utils import as_unicode, as_bytes
2222
from cryptojwt.utils import b64e
2323
from cryptojwt.utils import long2intarr
2424
from cryptojwt.utils import base64url_to_long
2525
from cryptojwt.utils import base64_to_long
26+
from cryptojwt.jwk import JWK
2627
from cryptojwt.jwk.ec import ECKey
2728
from cryptojwt.jwk.rsa import import_private_rsa_key_from_file
2829
from cryptojwt.jwk.rsa import import_public_rsa_key_from_file
@@ -50,7 +51,7 @@ def full_path(local_file):
5051
N = 'wf-wiusGhA-gleZYQAOPQlNUIucPiqXdPVyieDqQbXXOPBe3nuggtVzeq7pVFH1dZz4dY2Q2LA5DaegvP8kRvoSB_87ds3dy3Rfym_GUSc5B0l1TgEobcyaep8jguRoHto6GWHfCfKqoUYZq4N8vh4LLMQwLR6zi6Jtu82nB5k8'
5152
E = 'AQAB'
5253

53-
JWK = {"keys": [
54+
JWK_0 = {"keys": [
5455
{'kty': 'RSA', 'use': 'foo', 'e': E, 'kid': "abc",
5556
'n': N}
5657
]}
@@ -98,8 +99,8 @@ def test_kspec():
9899

99100
jwk = _key.serialize()
100101
assert jwk["kty"] == "RSA"
101-
assert jwk["e"] == JWK["keys"][0]["e"]
102-
assert jwk["n"] == JWK["keys"][0]["n"]
102+
assert jwk["e"] == JWK_0["keys"][0]["e"]
103+
assert jwk["n"] == JWK_0["keys"][0]["n"]
103104

104105
assert not _key.has_private_key()
105106

@@ -467,3 +468,54 @@ def test_key_from_jwk_dict_sym():
467468
assert isinstance(_key, SYMKey)
468469
jwk = _key.serialize()
469470
assert jwk == {'kty': 'oct', 'k': 'YWJjZGVmZ2hpamtsbW5vcHE'}
471+
472+
473+
def test_jwk_wrong_alg():
474+
with pytest.raises(UnsupportedAlgorithm):
475+
_j = JWK(alg='xyz')
476+
477+
478+
def test_jwk_conversion():
479+
_j = JWK(use=b'sig', kid=b'1', alg=b'RS512')
480+
assert _j.use == 'sig'
481+
args = _j.common()
482+
assert set(args.keys()) == {'kty', 'use', 'kid', 'alg'}
483+
484+
485+
def test_str():
486+
_j = RSAKey(alg="RS512", use='sig', n=N, e=E)
487+
s = '{}'.format(_j)
488+
assert s.startswith("{") and s.endswith("}")
489+
sp = s.replace("'", '"')
490+
_d = json.loads(sp)
491+
assert set(_d.keys()) == {'alg', 'use', 'n', 'e', 'kty'}
492+
493+
494+
def test_verify():
495+
_j = RSAKey(alg=b"RS512", use=b'sig', n=as_bytes(N), e=E)
496+
assert _j.verify()
497+
498+
499+
def test_verify_wrong_kid():
500+
_j = RSAKey(alg=b"RS512", use=b'sig', n=as_bytes(N), e=E, kid=1)
501+
with pytest.raises(ValueError):
502+
_j.verify()
503+
504+
505+
def test_cmp():
506+
_j1 = RSAKey(alg="RS256", use="sig", n=N, e=E)
507+
_j2 = RSAKey(alg="RS256", use="sig", n=N, e=E)
508+
assert _j1 == _j2
509+
510+
511+
def test_cmp_jwk():
512+
_j1 = JWK(use='sig', kid='1', alg='RS512')
513+
_j2 = JWK(use='sig', kid='1', alg='RS512')
514+
515+
assert _j1 == _j2
516+
517+
def test_appropriate():
518+
_j1 = JWK(use='sig', kid='1', alg='RS512')
519+
520+
assert _j1.appropriate_for('sign')
521+
assert _j1.appropriate_for('encrypt') is False

tests/test_05_jwx.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_get_headers():
7777
_headers = jwx.headers()
7878
assert set(_headers.keys()) == {'jwk', 'alg'}
7979

80-
_headers = jwx.headers({'kid': '123'})
80+
_headers = jwx.headers(kid='123')
8181
assert set(_headers.keys()) == {'jwk', 'alg', 'kid'}
8282

8383

@@ -92,3 +92,9 @@ def test_decode():
9292
jwx = JWx(cty='JWT')
9393
_msg = jwx._decode(b'eyJmb28iOiJiYXIifQ')
9494
assert _msg == {'foo':'bar'}
95+
96+
97+
def test_extra_headers():
98+
jwx = JWx()
99+
headers = jwx.headers(jwk=JSON_RSA_PUB_KEY, alg="RS256")
100+
assert set(headers.keys()) == {'jwk', 'alg'}

tests/test_07_jwe.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from cryptojwt.exception import MissingKey
1111
from cryptojwt.exception import Unsupported
1212
from cryptojwt.exception import VerificationError
13-
from cryptojwt.jwe.exception import UnsupportedBitLength
13+
from cryptojwt.jwe.exception import UnsupportedBitLength, \
14+
NoSuitableEncryptionKey
1415

1516
from cryptojwt.utils import b64e
1617

@@ -375,4 +376,21 @@ def test_verify_headers():
375376
decryptor = factory(jwt, alg="ECDH-ES", enc="A128GCM")
376377
assert decryptor.jwt.verify_headers(alg='ECDH-ES', enc='A128GCM')
377378
assert decryptor.jwt.verify_headers(alg='RS256') is False
378-
assert decryptor.jwt.verify_headers(kid='RS256') is False
379+
assert decryptor.jwt.verify_headers(kid='RS256') is False
380+
381+
382+
def test_encrypt_no_keys():
383+
jwenc = JWE(plain, alg="ECDH-ES", enc="A128GCM")
384+
with pytest.raises(NoSuitableEncryptionKey):
385+
jwenc.encrypt()
386+
387+
388+
def test_encrypt_jwk_key():
389+
# This is a weird case. Signing the JWT with a key that is
390+
# published in the JWT. Still it should be possible.
391+
jwenc = JWE(plain, alg="ECDH-ES", enc="A128GCM", jwk=eck_bob)
392+
_enc = jwenc.encrypt()
393+
assert _enc
394+
decryptor = factory(_enc, alg="ECDH-ES", enc="A128GCM")
395+
res = decryptor.decrypt()
396+
assert res == plain

0 commit comments

Comments
 (0)