Skip to content

Commit dbbee89

Browse files
committed
Stricter treatment of JWT headers.
Signature algorithm should always be checked, the same for encryption alg and enc.
1 parent 5d2efb2 commit dbbee89

File tree

7 files changed

+117
-51
lines changed

7 files changed

+117
-51
lines changed

src/cryptojwt/jwe/jwe.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
6161
"""
6262
6363
:param keys: A set of possibly usable keys
64-
:param context: If the other party's public or my private key should be
65-
used for encryption
6664
:param cek: Content master key
6765
:param iv: Initialization vector
6866
:param kwargs: Extra key word arguments
@@ -206,8 +204,8 @@ def alg2keytype(self, alg):
206204
return alg2keytype(alg)
207205

208206

209-
def factory(token):
210-
_jwt = JWEnc().unpack(token)
207+
def factory(token, **kwargs):
208+
_jwt = JWEnc().unpack(token, **kwargs)
211209
if _jwt.is_jwe():
212210
_jwe = JWE()
213211
_jwe.jwt = _jwt

src/cryptojwt/jws/jws.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -410,33 +410,24 @@ def verify_alg(self, alg):
410410
Specifically check that the 'alg' claim has a specific value
411411
412412
:param alg: The expected alg value
413-
:raises: KeyError if the 'alg' is not present in the header
414413
:return: True if the alg value in the header is the same as the one
415-
given.
414+
given. Returns False if no 'alg' claim exists in the header.
416415
"""
417-
if alg == self.jwt.headers['alg']:
418-
return True
419-
else:
416+
try:
417+
return self.jwt.verify_header('alg', alg)
418+
except KeyError:
420419
return False
421420

422-
def verify_header(self, key, val):
423-
"""
424-
Check that a particular header claim is present as a has specific value
425421

426-
:param key: The claim
427-
:param val: The value of the claim
428-
:raises: KeyError if the claim is not present in the header
429-
:return: True if the claim exists in the header and has the prescribed
430-
value
431-
"""
432-
if val == self.jwt.headers[key]:
433-
return True
434-
else:
435-
return False
422+
def factory(token):
423+
"""
424+
Instantiate an JWS instance if the token is a signed JWT.
436425
426+
:param token: The token that might be a signed JWT
427+
:return: A JWS instance if the token was a signed JWT, otherwise None
428+
"""
437429

438-
def factory(token, **kwargs):
439-
_jw = JWS(**kwargs)
430+
_jw = JWS()
440431
if _jw.is_jws(token):
441432
return _jw
442433
else:

src/cryptojwt/jwt.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from json import JSONDecodeError
66

77

8+
from .exception import HeaderError
89
from .exception import VerificationError
910
from .utils import as_unicode
1011
from .jwe.utils import alg2keytype as jwe_alg2keytype
@@ -69,7 +70,7 @@ def __init__(self, key_jar=None, iss='', lifetime=0,
6970
self.iss = iss # My identifier
7071
self.lifetime = lifetime # default life time of the signature
7172
self.sign = sign # default signing or not
72-
self.sign_alg = sign_alg # default signing algorithm
73+
self.alg = sign_alg # default signing algorithm
7374
self.encrypt = encrypt # default encrypting or not
7475
self.enc_alg = enc_alg # CEK encryption algorithm
7576
self.enc_enc = enc_enc # content encryption algorithm
@@ -118,11 +119,11 @@ def put_together_aud(recv, aud=None):
118119
:return: A possibly extended audience set
119120
"""
120121
if aud:
121-
if recv in aud:
122-
_aud = aud
123-
elif recv:
122+
if recv and recv not in aud:
124123
_aud = [recv]
125124
_aud.extend(aud)
125+
else:
126+
_aud = aud
126127
elif recv:
127128
_aud = [recv]
128129
else:
@@ -154,7 +155,7 @@ def pack_key(self, owner_id='', kid=''):
154155
:param kid: Key ID
155156
:return: One key
156157
"""
157-
keys = pick_key(self.my_keys(owner_id, 'sig'), 'sig', alg=self.sign_alg,
158+
keys = pick_key(self.my_keys(owner_id, 'sig'), 'sig', alg=self.alg,
158159
kid=kid)
159160

160161
if not keys:
@@ -198,13 +199,13 @@ def pack(self, payload=None, kid='', owner='', recv='', aud=None, **kwargs):
198199
owner = self.iss
199200

200201
if self.sign:
201-
if self.sign_alg != 'none':
202+
if self.alg != 'none':
202203
_key = self.pack_key(owner, kid)
203204
_args['kid'] = _key.kid
204205
else:
205206
_key = None
206207

207-
_jws = JWS(json.dumps(_args), alg=self.sign_alg)
208+
_jws = JWS(json.dumps(_args), alg=self.alg)
208209
_sjwt = _jws.sign_compact([_key])
209210
else:
210211
_sjwt = json.dumps(_args)
@@ -242,7 +243,8 @@ def _decrypt(self, rj, token):
242243
keys = self.key_jar.get_jwt_decrypt_keys(rj.jwt)
243244
return rj.decrypt(token, keys=keys)
244245

245-
def verify_profile(self, msg_cls, info, **kwargs):
246+
@staticmethod
247+
def verify_profile(msg_cls, info, **kwargs):
246248
"""
247249
If a message type is known for this JSON document. Verify that the
248250
document complies with the message specifications.
@@ -273,14 +275,24 @@ def unpack(self, token):
273275
_jwe_header = _jws_header = None
274276

275277
# Check if it's an encrypted JWT
276-
_rj = jwe_factory(token)
277-
if _rj:
278+
_decryptor = jwe_factory(token)
279+
if _decryptor:
280+
# check headers
281+
darg = {}
282+
if self.enc_enc:
283+
darg['enc'] = self.enc_enc
284+
if self.enc_alg:
285+
darg['alg'] = self.enc_alg
286+
287+
if _decryptor.jwt.verify_headers(**darg) is False:
288+
raise HeaderError('Wrong alg or enc')
289+
278290
# Yes, try to decode
279-
_info = self._decrypt(_rj, token)
280-
_jwe_header = _rj.jwt.headers
291+
_info = self._decrypt(_decryptor, token)
292+
_jwe_header = _decryptor.jwt.headers
281293
# Try to find out if the information encrypted was a signed JWT
282294
try:
283-
_content_type = _rj.jwt.headers['cty']
295+
_content_type = _decryptor.jwt.headers['cty']
284296
except KeyError:
285297
pass
286298
else:
@@ -289,12 +301,16 @@ def unpack(self, token):
289301
# If I have reason to believe the information I have is a signed JWT
290302
if _content_type.lower() == 'jwt':
291303
# Check that is a signed JWT
292-
_rj = jws_factory(_info)
293-
if _rj:
294-
_info = self._verify(_rj, _info)
304+
_verifier = jws_factory(_info)
305+
if _verifier:
306+
if self.alg and not _verifier.jwt.verify_headers(alg=self.alg):
307+
raise HeaderError(
308+
'Wrong signing algorithm: "{}" expected "{}"'.format(
309+
_verifier.jwt.headers['alg'], self.alg))
310+
_info = self._verify(_verifier, _info)
295311
else:
296312
raise Exception()
297-
_jws_header = _rj.jwt.headers
313+
_jws_header = _verifier.jwt.headers
298314
else:
299315
# So, not a signed JWT
300316
try:

src/cryptojwt/simple_jwt.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33

4+
from cryptojwt.exception import HeaderError
45
from .utils import as_unicode
56
from .utils import b64d
67
from .utils import b64encode_item
@@ -23,12 +24,14 @@ def __init__(self, **headers):
2324
self.b64part = [b64encode_item(headers)]
2425
self.part = [b64d(self.b64part[0])]
2526

26-
def unpack(self, token):
27+
def unpack(self, token, **kwargs):
2728
"""
2829
Unpacks a JWT into its parts and base64 decodes the parts
2930
individually
3031
3132
:param token: The JWT
33+
:param kwargs: A possible empty set of claims to verify the header
34+
against.
3235
"""
3336
if isinstance(token, str):
3437
try:
@@ -39,8 +42,18 @@ def unpack(self, token):
3942
part = split_token(token)
4043
self.b64part = part
4144
self.part = [b64d(p) for p in part]
42-
#self.headers = json.loads(self.part[0].decode())
4345
self.headers = json.loads(as_unicode(self.part[0]))
46+
for key,val in kwargs.items():
47+
try:
48+
_ok = self.verify_header(key,val)
49+
except KeyError:
50+
raise
51+
else:
52+
if not _ok:
53+
raise HeaderError(
54+
'Expected "{}" to be "{}", was "{}"'.format(
55+
key, val, self.headers[key]))
56+
4457
return self
4558

4659
def pack(self, parts=None, headers=None):
@@ -87,4 +100,43 @@ def payload(self):
87100
except ValueError:
88101
pass
89102

90-
return _msg
103+
return _msg
104+
105+
def verify_header(self, key, val):
106+
"""
107+
Check that a particular header claim is present and has a specific value
108+
109+
:param key: The claim
110+
:param val: The value of the claim
111+
:raises: KeyError if the claim is not present in the header
112+
:return: True if the claim exists in the header and has the prescribed
113+
value
114+
"""
115+
116+
if self.headers[key] == val:
117+
return True
118+
else:
119+
return False
120+
121+
def verify_headers(self, check_presence=True, **kwargs):
122+
"""
123+
Check that a set of particular header claim are present and has
124+
specific values
125+
126+
:param kwargs: The claim/value sets as a dictionary
127+
:return: True if the claim that appears in the header has the
128+
prescribed values. If a claim is not present in the header and
129+
check_presence is True then False is returned.
130+
"""
131+
for key, val in kwargs.items():
132+
try:
133+
_ok = self.verify_header(key, val)
134+
except KeyError:
135+
if check_presence:
136+
return False
137+
else:
138+
pass
139+
else:
140+
if not _ok:
141+
return False
142+
return True

tests/test_06_jws.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,5 @@ def test_factory_verify_alg():
828828
_signer = JWS(payload, alg='RS256')
829829
_signer.set_header_claim('foo', 'bar')
830830
_jws = _signer.sign_compact(keys, abc=123)
831-
_verifier = factory(_jws, alg='RS512')
832-
with pytest.raises(SignerAlgError):
833-
_verifier.verify_compact(_jws, keys)
831+
_verifier = factory(_jws)
832+
assert _verifier.jwt.verify_headers(alg='RS512') is False

tests/test_07_jwe.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,13 @@ def test_ecdh_no_setup_dynamic_epk():
368368
ret_jwe = factory(jwt)
369369
res = ret_jwe.decrypt(jwt, [eck_bob])
370370
assert res == plain
371+
372+
373+
def test_verify_headers():
374+
jwenc = JWE(plain, alg="ECDH-ES", enc="A128GCM")
375+
jwt = jwenc.encrypt([eck_bob])
376+
assert jwt
377+
decryptor = factory(jwt)
378+
assert decryptor.jwt.verify_headers(alg='ECDH-ES', enc='A128GCM')
379+
assert decryptor.jwt.verify_headers(alg='RS256') is False
380+
assert decryptor.jwt.verify_headers(kid='RS256') is False

tests/test_09_jwt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_jwt_pack_unpack_sym():
9393

9494
_kj = KeyJar()
9595
_kj.add_symmetric(ALICE, 'hemligt ordsprak', usage=['sig'])
96-
bob = JWT(key_jar=_kj, iss=BOB)
96+
bob = JWT(key_jar=_kj, iss=BOB, sign_alg="HS256")
9797
info = bob.unpack(_jwt)
9898
assert info
9999

@@ -115,7 +115,7 @@ def test_jwt_pack_and_unpack_with_alg():
115115
payload = {'sub': 'sub'}
116116
_jwt = alice.pack(payload=payload)
117117

118-
bob = JWT(BOB_KEY_JAR)
118+
bob = JWT(BOB_KEY_JAR, sign_alg='RS384')
119119
info = bob.unpack(_jwt)
120120

121121
assert set(info.keys()) == {'iat', 'iss', 'sub', 'kid'}
@@ -138,7 +138,7 @@ def test_with_jti():
138138
payload = {'sub': 'sub2'}
139139
_jwt = alice.pack(payload=payload)
140140

141-
bob = JWT(key_jar=_kj, iss=BOB)
141+
bob = JWT(key_jar=_kj, iss=BOB, sign_alg="HS256")
142142
info = bob.unpack(_jwt)
143143
assert 'jti' in info
144144

@@ -160,7 +160,7 @@ def test_msg_cls():
160160
payload = {'sub': 'sub2'}
161161
_jwt = alice.pack(payload=payload)
162162

163-
bob = JWT(key_jar=_kj, iss=BOB)
163+
bob = JWT(key_jar=_kj, iss=BOB, sign_alg="HS256")
164164
bob.msg_cls = DummyMsg
165165
info = bob.unpack(_jwt)
166166
assert isinstance(info, DummyMsg)

0 commit comments

Comments
 (0)