Skip to content

Commit 689e003

Browse files
committed
Make all import of keys to a KeyBundle go through do_keys().
1 parent 0e011bd commit 689e003

File tree

4 files changed

+49
-49
lines changed

4 files changed

+49
-49
lines changed

src/cryptojwt/jws/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def alg2keytype(alg):
4545
elif alg.startswith("RS") or alg.startswith("PS"):
4646
return "RSA"
4747
elif alg.startswith("HS") or alg.startswith("A"):
48-
return "oct"
48+
return "OCT"
4949
elif alg.startswith("ES") or alg.startswith("ECDH-ES"):
5050
return "EC"
5151
else:

src/cryptojwt/key_bundle.py

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from .exception import DeSerializationNotPossible
1313
from .exception import JWKException
1414
from .exception import UnknownKeyType
15+
from .exception import UnsupportedAlgorithm
16+
from .exception import UnsupportedECurve
1517
from .exception import UpdateFailed
1618
from .jwk.ec import ECKey
1719
from .jwk.ec import import_private_key_from_file
@@ -36,11 +38,11 @@
3638
# _err = json.dumps({'error': error, 'error_description': descr})
3739
# raise excep(_err, 'application/json')
3840

39-
41+
# Make sure the keys are all uppercase
4042
K2C = {
4143
"RSA": RSAKey,
4244
"EC": ECKey,
43-
"oct": SYMKey,
45+
"OCT": SYMKey,
4446
}
4547

4648
MAP = {'dec': 'enc', 'enc': 'enc', 'ver': 'sig', 'sig': 'sig'}
@@ -242,33 +244,39 @@ def do_keys(self, keys):
242244
:return:
243245
"""
244246
for inst in keys:
245-
typ = inst["kty"]
247+
inst['kty'] = inst["kty"].upper()
248+
_typ = inst['kty']
246249
try:
247250
_usage = harmonize_usage(inst['use'])
248251
except KeyError:
249252
_usage = ['']
250253
else:
251254
del inst['use']
252255

253-
flag = 0
256+
_error = ''
254257
for _use in _usage:
255-
for _typ in [typ, typ.lower(), typ.upper()]:
256-
try:
257-
_key = K2C[_typ](use=_use, **inst)
258-
except KeyError:
259-
continue
260-
except JWKException as err:
261-
LOGGER.warning('While loading keys: %s', err)
262-
else:
263-
if _key not in self._keys:
264-
if not _key.kid:
265-
_key.add_kid()
266-
self._keys.append(_key)
267-
flag = 1
268-
break
269-
if not flag:
270-
LOGGER.warning(
271-
'While loading keys, UnknownKeyType: %s', typ)
258+
try:
259+
_key = K2C[_typ](use=_use, **inst)
260+
except KeyError:
261+
_error = 'UnknownKeyType: {}'.format(_typ)
262+
continue
263+
except (UnsupportedECurve, UnsupportedAlgorithm) as err:
264+
_error = str(err)
265+
break
266+
except JWKException as err:
267+
LOGGER.warning('While loading keys: %s', err)
268+
_error = str(err)
269+
else:
270+
if _key not in self._keys:
271+
if not _key.kid:
272+
_key.add_kid()
273+
self._keys.append(_key)
274+
_error = ''
275+
break
276+
if _error:
277+
LOGGER.warning('While loading keys, %s', _error)
278+
279+
self.last_updated = time.time()
272280

273281
def do_local_jwk(self, filename):
274282
"""
@@ -282,8 +290,6 @@ def do_local_jwk(self, filename):
282290
else:
283291
self.do_keys([_info])
284292

285-
self.last_updated = time.time()
286-
287293
def do_local_der(self, filename, keytype, keyusage=None, kid=''):
288294
"""
289295
Load a DER encoded file amd create a key from it.
@@ -292,29 +298,25 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=''):
292298
:param keytype: Presently 'rsa' and 'ec' supported
293299
:param keyusage: encryption ('enc') or signing ('sig') or both
294300
"""
295-
if keytype.lower() == 'rsa':
296-
_bkey = import_private_rsa_key_from_file(filename)
297-
_key = RSAKey().load_key(_bkey)
298-
elif keytype.lower() == 'ec':
299-
_bkey = import_private_key_from_file(filename)
300-
_key = ECKey().load_key(_bkey)
301+
key_args = {}
302+
_kty = keytype.lower()
303+
if _kty in ['rsa', 'ec']:
304+
key_args["kty"] = _kty
305+
_key = import_private_rsa_key_from_file(filename)
306+
key_args["priv_key"] = _key
307+
key_args["pub_key"] = _key.public_key()
301308
else:
302-
raise NotImplementedError('No support for DER decoding of that key type')
309+
raise NotImplementedError('No support for DER decoding of key type {}'.format(_kty))
303310

304311
if not keyusage:
305-
keyusage = ["enc", "sig"]
312+
key_args["use"] = ["enc", "sig"]
306313
else:
307-
keyusage = harmonize_usage(keyusage)
314+
key_args["use"] = harmonize_usage(keyusage)
308315

309-
for use in keyusage:
310-
_key.use = use
311-
if kid:
312-
_key.kid = kid
313-
if not _key.kid:
314-
_key.add_kid()
315-
self._keys.append(_key)
316+
if kid:
317+
key_args['kid'] = kid
316318

317-
self.last_updated = time.time()
319+
self.do_keys([key_args])
318320

319321
def do_remote(self):
320322
"""

tests/test_03_key_bundle.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -430,18 +430,18 @@ def test_outdated():
430430

431431

432432
def test_dump_jwks():
433-
kb1 = rsa_init(
434-
{'use': ['enc', 'sig'], 'size': 1024, 'name': 'rsa', 'path': 'keys'})
435433
a = {"kty": "oct", "key": "highestsupersecret", "use": "sig"}
436434
b = {"kty": "oct", "key": "highestsupersecret", "use": "enc"}
437435
kb2 = KeyBundle([a, b])
436+
437+
kb1 = rsa_init({'use': ['enc', 'sig'], 'size': 1024, 'name': 'rsa', 'path': 'keys'})
438438
dump_jwks([kb1, kb2], 'jwks_combo')
439439

440440
# Now read it
441441

442442
nkb = KeyBundle(source='file://jwks_combo', fileformat='jwks')
443443

444-
assert len(nkb) == 2
444+
assert len(nkb) == 4
445445
# both RSA keys
446446
assert len(nkb.get('rsa')) == 2
447447

@@ -656,10 +656,8 @@ def test_keys():
656656

657657
EXPECTED = [
658658
b'iA7PvG_DfJIeeqQcuXFmvUGjqBkda8In_uMpZrcodVA',
659-
b'kLsuyGef1kfw5-t-N9CJLIHx_dpZ79-KemwqjwdrvTI',
660-
b'8w34j9PLyCVC7VOZZb1tFVf0MOa2KZoy87lICMeD5w8',
661-
b'nKzalL5pJOtVAdCtBAU8giNRNimE-XbylWZ4vq6ZlF8',
662-
b'akXzyGlXg8yLhsCczKb_r8VERLx7-iZBUMIVgg2K7p4'
659+
b'akXzyGlXg8yLhsCczKb_r8VERLx7-iZBUMIVgg2K7p4',
660+
b'Rdy8n5h0fo2q9USHJ6HQKnNZFynN1pWN_X6Bc_Tx-lE'
663661
]
664662

665663

tests/test_06_jws.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def test_hmac_512():
225225

226226
def test_hmac_from_keyrep():
227227
payload = "Please take a moment to register today"
228-
symkeys = [k for k in SIGJWKS if k.kty == "oct"]
228+
symkeys = [k for k in SIGJWKS if k.kty == "OCT"]
229229
_jws = JWS(payload, alg="HS512")
230230
_jwt = _jws.sign_compact(symkeys)
231231

0 commit comments

Comments
 (0)