Skip to content

Commit af22f92

Browse files
committed
More about making sure the crypto algorithms use expect are the ones used.
1 parent f486fd3 commit af22f92

File tree

9 files changed

+79
-62
lines changed

9 files changed

+79
-62
lines changed

src/cryptojwt/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
except ImportError:
1616
pass
1717

18-
__version__ = '0.5.0'
18+
__version__ = '0.6.1'
1919

2020
logger = logging.getLogger(__name__)
2121

src/cryptojwt/jwe/jwe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,12 @@ def alg2keytype(self, alg):
204204
return alg2keytype(alg)
205205

206206

207-
def factory(token, **kwargs):
208-
_jwt = JWEnc().unpack(token, **kwargs)
207+
def factory(token, alg='', enc=''):
208+
try:
209+
_jwt = JWEnc().unpack(token, alg=alg, enc=enc)
210+
except KeyError:
211+
return None
212+
209213
if _jwt.is_jwe():
210214
_jwe = JWE()
211215
_jwe.jwt = _jwt

src/cryptojwt/jws/jws.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def valid(self):
7171

7272

7373
class JWS(JWx):
74+
def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs):
75+
JWx.__init__(self, msg, with_digest, httpc, **kwargs)
76+
if 'alg' not in self:
77+
self['alg'] = "RS256"
78+
7479
def alg_keys(self, keys, use, protected=None):
7580
_alg = self._pick_alg(keys)
7681

@@ -188,9 +193,16 @@ def verify_compact_verbose(self, jws=None, keys=None, allow_none=False,
188193
else:
189194
raise SignerAlgError("none not allowed")
190195

191-
if "alg" in self and _alg:
192-
if self["alg"] != _alg:
193-
raise SignerAlgError("Wrong signing algorithm")
196+
if "alg" in self and self['alg'] and _alg:
197+
if isinstance(self['alg'], list):
198+
if _alg not in self["alg"] :
199+
raise SignerAlgError(
200+
"Wrong signing algorithm, expected {} go {}".format(
201+
self['alg'], _alg))
202+
elif _alg != self['alg']:
203+
raise SignerAlgError(
204+
"Wrong signing algorithm, expected {} go {}".format(
205+
self['alg'], _alg))
194206

195207
if sigalg and sigalg != _alg:
196208
raise SignerAlgError("Expected {0} got {1}".format(
@@ -419,15 +431,16 @@ def verify_alg(self, alg):
419431
return False
420432

421433

422-
def factory(token):
434+
def factory(token, alg=''):
423435
"""
424436
Instantiate an JWS instance if the token is a signed JWT.
425437
426438
:param token: The token that might be a signed JWT
439+
:param alg: The expected signature algorithm
427440
:return: A JWS instance if the token was a signed JWT, otherwise None
428441
"""
429442

430-
_jw = JWS()
443+
_jw = JWS(alg=alg)
431444
if _jw.is_jws(token):
432445
return _jw
433446
else:

src/cryptojwt/jwt.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -280,18 +280,17 @@ def unpack(self, token):
280280
_jwe_header = _jws_header = None
281281

282282
# Check if it's an encrypted JWT
283-
_decryptor = jwe_factory(token)
284-
if _decryptor:
285-
# check headers
286-
darg = {}
287-
if self.allowed_enc_encs:
288-
darg['enc'] = self.allowed_enc_encs
289-
if self.allowed_enc_algs:
290-
darg['alg'] = self.allowed_enc_algs
291-
292-
if _decryptor.jwt.verify_headers(**darg) is False:
293-
raise HeaderError('Wrong alg or enc')
283+
darg = {}
284+
if self.allowed_enc_encs:
285+
darg['enc'] = self.allowed_enc_encs
286+
if self.allowed_enc_algs:
287+
darg['alg'] = self.allowed_enc_algs
288+
try:
289+
_decryptor = jwe_factory(token, **darg)
290+
except (KeyError, HeaderError):
291+
_decryptor = None
294292

293+
if _decryptor:
295294
# Yes, try to decode
296295
_info = self._decrypt(_decryptor, token)
297296
_jwe_header = _decryptor.jwt.headers
@@ -307,14 +306,12 @@ def unpack(self, token):
307306
# If I have reason to believe the information I have is a signed JWT
308307
if _content_type.lower() == 'jwt':
309308
# Check that is a signed JWT
310-
_verifier = jws_factory(_info)
309+
if self.allowed_sign_algs:
310+
_verifier = jws_factory(_info, alg=self.allowed_sign_algs)
311+
else:
312+
_verifier = jws_factory(_info, alg=self.allowed_sign_algs)
313+
311314
if _verifier:
312-
if self.allowed_sign_algs and not _verifier.jwt.verify_headers(
313-
alg=self.allowed_sign_algs):
314-
raise HeaderError(
315-
'Wrong signing algorithm: "{}" expected "{}"'.format(
316-
_verifier.jwt.headers['alg'],
317-
self.allowed_sign_algs))
318315
_info = self._verify(_verifier, _info)
319316
else:
320317
raise Exception()

src/cryptojwt/simple_jwt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def unpack(self, token, **kwargs):
4444
self.part = [b64d(p) for p in part]
4545
self.headers = json.loads(as_unicode(self.part[0]))
4646
for key,val in kwargs.items():
47+
if not val and key in self.headers:
48+
continue
49+
4750
try:
4851
_ok = self.verify_header(key,val)
4952
except KeyError:

tests/test_06_jws.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_1():
165165
"exp": 1300819380,
166166
"http://example.com/is_root": True}
167167

168-
_jws = JWS(claimset, cty="JWT")
168+
_jws = JWS(claimset, cty="JWT", alg='none')
169169
_jwt = _jws.sign_compact()
170170

171171
_jr = JWS()
@@ -181,7 +181,7 @@ def test_hmac_256():
181181
_jws = JWS(payload, alg="HS256")
182182
_jwt = _jws.sign_compact(keys)
183183

184-
info = JWS().verify_compact(_jwt, keys)
184+
info = JWS(alg="HS256").verify_compact(_jwt, keys)
185185

186186
assert info == payload
187187

@@ -192,7 +192,7 @@ def test_hmac_384():
192192
_jws = JWS(payload, alg="HS384")
193193
_jwt = _jws.sign_compact(keys)
194194

195-
_rj = JWS()
195+
_rj = JWS(alg="HS384")
196196
info = _rj.verify_compact(_jwt, keys)
197197

198198
assert info == payload
@@ -204,7 +204,7 @@ def test_hmac_512():
204204
_jws = JWS(payload, alg="HS512")
205205
_jwt = _jws.sign_compact(keys)
206206

207-
_rj = JWS()
207+
_rj = JWS(alg="HS512")
208208
info = _rj.verify_compact(_jwt, keys)
209209
assert info == payload
210210

@@ -215,7 +215,7 @@ def test_hmac_from_keyrep():
215215
_jws = JWS(payload, alg="HS512")
216216
_jwt = _jws.sign_compact(symkeys)
217217

218-
_rj = JWS()
218+
_rj = JWS(alg="HS512")
219219
info = _rj.verify_compact(_jwt, symkeys)
220220
assert info == payload
221221

@@ -239,7 +239,7 @@ def test_rs256():
239239
_jwt = _jws.sign_compact(skeys)
240240

241241
vkeys = [RSAKey(pub_key=_pkey.public_key())]
242-
_rj = JWS()
242+
_rj = JWS(alg="RS256")
243243
info = _rj.verify_compact(_jwt, vkeys)
244244

245245
assert info == payload
@@ -254,7 +254,7 @@ def test_rs384():
254254
_jwt = _jws.sign_compact(keys)
255255

256256
vkeys = [RSAKey(pub_key=_pkey.public_key())]
257-
_rj = JWS()
257+
_rj = JWS(alg="RS384")
258258
info = _rj.verify_compact(_jwt, vkeys)
259259
assert info == payload
260260

@@ -268,7 +268,7 @@ def test_rs512():
268268
_jwt = _jws.sign_compact(keys)
269269

270270
vkeys = [RSAKey(pub_key=_pkey.public_key())]
271-
_rj = JWS()
271+
_rj = JWS(alg="RS512")
272272
info = _rj.verify_compact(_jwt, vkeys)
273273
assert info == payload
274274

@@ -304,8 +304,8 @@ def test_a_1_3a():
304304
"HAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnV"
305305
"lfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk")
306306

307-
# keycol = {"hmac": cryptojwt.intarr2bin(HMAC_KEY)}
308-
jwt = JWSig().unpack(_jwt)
307+
# alg == '' means I'm fine with whatever I get
308+
jwt = JWSig(alg='').unpack(_jwt)
309309
assert jwt.valid()
310310

311311
hmac = intarr2bin(HMAC_KEY)
@@ -318,7 +318,7 @@ def test_a_1_3b():
318318
"eHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0c"
319319
"nVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk")
320320
keys = [SYMKey(key=intarr2bin(HMAC_KEY))]
321-
_jws2 = JWS()
321+
_jws2 = JWS(alg='')
322322
_jws2.verify_compact(_jwt, keys)
323323

324324

@@ -372,7 +372,7 @@ def test_signer_es(ec_func, alg):
372372
_jwt = _jws.sign_compact(keys)
373373

374374
_pubkey = ECKey().load_key(eck.public_key())
375-
_rj = JWS()
375+
_rj = JWS(alg=alg)
376376
info = _rj.verify_compact(_jwt, [_pubkey])
377377
assert info == payload
378378

@@ -386,7 +386,7 @@ def test_signer_es256_verbose():
386386
_jwt = _jws.sign_compact(keys)
387387

388388
_pubkey = ECKey().load_key(eck.public_key())
389-
_rj = JWS()
389+
_rj = JWS(alg="ES256")
390390
info = _rj.verify_compact_verbose(_jwt, [_pubkey])
391391
assert info['msg'] == payload
392392
assert info['key'] == _pubkey
@@ -401,7 +401,7 @@ def test_signer_ps256():
401401
_jwt = _jws.sign_compact(keys)
402402

403403
vkeys = [RSAKey(pub_key=_pkey.public_key())]
404-
_rj = JWS()
404+
_rj = JWS(alg="PS256")
405405
info = _rj.verify_compact(_jwt, vkeys)
406406
assert info == payload
407407

@@ -415,7 +415,7 @@ def test_signer_ps256_fail():
415415
_jwt = _jws.sign_compact(keys)[:-5] + 'abcde'
416416

417417
vkeys = [RSAKey(pub_key=_pkey.public_key())]
418-
_rj = JWS()
418+
_rj = JWS(alg="PS256")
419419
try:
420420
_rj.verify_compact(_jwt, vkeys)
421421
except BadSignature:
@@ -433,7 +433,7 @@ def test_signer_ps384():
433433
_jwt = _jws.sign_compact(keys)
434434

435435
vkeys = [RSAKey(pub_key=_pkey.public_key())]
436-
_rj = JWS()
436+
_rj = JWS(alg="PS384")
437437
info = _rj.verify_compact(_jwt, vkeys)
438438
assert info == payload
439439

@@ -448,7 +448,7 @@ def test_signer_ps512():
448448
_jwt = _jws.sign_compact(keys)
449449

450450
vkeys = [RSAKey(pub_key=_pkey.public_key())]
451-
_rj = factory(_jwt)
451+
_rj = factory(_jwt, alg="PS512")
452452
info = _rj.verify_compact(_jwt, vkeys)
453453
assert info == payload
454454
assert _rj.verify_alg('PS512')
@@ -462,7 +462,7 @@ def test_no_alg_and_alg_none_same():
462462
_jwt0 = _jws.sign_compact([])
463463

464464
# The class instance that sets up the signing operation
465-
_jws = JWS(payload)
465+
_jws = JWS(payload, alg="none")
466466

467467
# Create a JWS (signed JWT)
468468
_jwt1 = _jws.sign_compact([])
@@ -506,7 +506,7 @@ def test_signer_protected_headers():
506506
assert b64d(enc_payload.encode("utf-8")).decode("utf-8") == payload
507507

508508
_pub_key = ECKey().load_key(eck.public_key())
509-
_rj = JWS()
509+
_rj = JWS(alg='ES256')
510510
info = _rj.verify_compact(_jwt, [_pub_key])
511511
assert info == payload
512512

@@ -699,7 +699,7 @@ def test_pick_alg_assume_alg_from_single_key():
699699
expected_alg = "HS256"
700700
keys = [SYMKey(key="foobar subdued thought", alg=expected_alg)]
701701

702-
alg = JWS()._pick_alg(keys)
702+
alg = JWS(alg=expected_alg)._pick_alg(keys)
703703
assert alg == expected_alg
704704

705705

tests/test_07_jwe.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def test_encrypt_decrypt_rsa_cbc():
276276

277277
jwt = _jwe0.encrypt([_key])
278278

279-
_jwe1 = factory(jwt)
279+
_jwe1 = factory(jwt, alg="RSA1_5", enc="A128CBC-HS256")
280280
_dkey = RSAKey(priv_key=priv_key)
281281
_dkey._keytype = "private"
282282
msg = _jwe1.decrypt(jwt, [_dkey])
@@ -318,7 +318,7 @@ def test_ecdh_encrypt_decrypt_direct_key():
318318
jwt = jwenc.encrypt(**kwargs)
319319

320320
# Bob decrypts
321-
ret_jwe = factory(jwt)
321+
ret_jwe = factory(jwt, alg="ECDH-ES", enc="A128GCM")
322322
jwdec = JWE_EC()
323323
jwdec.dec_setup(ret_jwe.jwt, key=bob)
324324
msg = jwdec.decrypt(ret_jwe.jwt)
@@ -342,7 +342,7 @@ def test_ecdh_encrypt_decrypt_keywrapped_key():
342342

343343
jwt = jwenc.encrypt(**kwargs)
344344

345-
ret_jwe = factory(jwt)
345+
ret_jwe = factory(jwt, alg="ECDH-ES+A128KW", enc="A128GCM")
346346
jwdec = JWE_EC()
347347
jwdec.dec_setup(ret_jwe.jwt, key=bob)
348348
msg = jwdec.decrypt(ret_jwe.jwt)
@@ -365,7 +365,7 @@ def test_ecdh_no_setup_dynamic_epk():
365365
jwenc = JWE(plain, alg="ECDH-ES", enc="A128GCM")
366366
jwt = jwenc.encrypt([eck_bob])
367367
assert jwt
368-
ret_jwe = factory(jwt)
368+
ret_jwe = factory(jwt, alg="ECDH-ES", enc="A128GCM")
369369
res = ret_jwe.decrypt(jwt, [eck_bob])
370370
assert res == plain
371371

@@ -374,7 +374,7 @@ def test_verify_headers():
374374
jwenc = JWE(plain, alg="ECDH-ES", enc="A128GCM")
375375
jwt = jwenc.encrypt([eck_bob])
376376
assert jwt
377-
decryptor = factory(jwt)
377+
decryptor = factory(jwt, alg="ECDH-ES", enc="A128GCM")
378378
assert decryptor.jwt.verify_headers(alg='ECDH-ES', enc='A128GCM')
379379
assert decryptor.jwt.verify_headers(alg='RS256') is False
380380
assert decryptor.jwt.verify_headers(kid='RS256') is False

tests/test_09_jwt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ def test_jwt_pack():
5252

5353

5454
def test_jwt_pack_and_unpack():
55-
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE)
55+
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg='RS256')
5656
payload = {'sub': 'sub'}
5757
_jwt = alice.pack(payload=payload)
5858

59-
bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB)
59+
bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"])
6060
info = bob.unpack(_jwt)
6161

6262
assert set(info.keys()) == {'iat', 'iss', 'sub', 'kid'}

0 commit comments

Comments
 (0)