Skip to content

Commit e1981fe

Browse files
committed
Refactored. Made init_key_jat update the files on disc if the key_defs didn't correspond to what was there.
1 parent dc617fe commit e1981fe

File tree

4 files changed

+583
-122
lines changed

4 files changed

+583
-122
lines changed

src/cryptojwt/key_bundle.py

Lines changed: 247 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import json
22
import logging
33
import os
4+
from functools import cmp_to_key
5+
46
import requests
57
import time
68

7-
from .exception import JWKException
9+
from .exception import JWKException, DeSerializationNotPossible
810
from .exception import UnknownKeyType
911
from .exception import UpdateFailed
1012
from .jwk.hmac import SYMKey
@@ -609,3 +611,247 @@ def dump_jwks(kbl, target, private=False):
609611
_txt = json.dumps(res)
610612
f.write(_txt)
611613
f.close()
614+
615+
616+
def build_key_bundle(key_conf, kid_template=""):
617+
"""
618+
Builds a :py:class:`oidcmsg.key_bundle.KeyBundle` instance based on a key
619+
specification.
620+
621+
An example of such a specification::
622+
623+
keys = [
624+
{"type": "RSA", "key": "cp_keys/key.pem", "use": ["enc", "sig"]},
625+
{"type": "EC", "crv": "P-256", "use": ["sig"], "kid": "ec.1"},
626+
{"type": "EC", "crv": "P-256", "use": ["enc"], "kid": "ec.2"}
627+
]
628+
629+
Keys in this specification are:
630+
631+
type
632+
The type of key. Presently only 'rsa' and 'ec' supported.
633+
634+
key
635+
A name of a file where a key can be found. Only works with PEM encoded
636+
RSA keys
637+
638+
use
639+
What the key should be used for
640+
641+
crv
642+
The elliptic curve that should be used. Only applies to elliptic curve
643+
keys :-)
644+
645+
kid
646+
Key ID, can only be used with one usage type is specified. If there
647+
are more the one usage type specified 'kid' will just be ignored.
648+
649+
:param key_conf: The key configuration
650+
:param kid_template: A template by which to build the key IDs. If no
651+
kid_template is given then the built-in function add_kid() will be used.
652+
:return: A KeyBundle instance
653+
"""
654+
655+
kid = 0
656+
657+
tot_kb = KeyBundle()
658+
for spec in key_conf:
659+
typ = spec["type"].upper()
660+
661+
if typ == "RSA":
662+
if "key" in spec:
663+
error_to_catch = (OSError, IOError,
664+
DeSerializationNotPossible)
665+
try:
666+
kb = KeyBundle(source="file://%s" % spec["key"],
667+
fileformat="der",
668+
keytype=typ, keyusage=spec["use"])
669+
except error_to_catch:
670+
kb = rsa_init(spec)
671+
except Exception:
672+
raise
673+
else:
674+
kb = rsa_init(spec)
675+
elif typ == "EC":
676+
kb = ec_init(spec)
677+
else:
678+
continue
679+
680+
if 'kid' in spec and len(kb) == 1:
681+
ks = kb.keys()
682+
ks[0].kid = spec['kid']
683+
else:
684+
for k in kb.keys():
685+
if kid_template:
686+
k.kid = kid_template % kid
687+
kid += 1
688+
else:
689+
k.add_kid()
690+
691+
tot_kb.extend(kb.keys())
692+
693+
return tot_kb
694+
695+
696+
def _cmp(kd1, kd2):
697+
if kd1 == kd2:
698+
return 0
699+
elif kd1< kd2:
700+
return -1
701+
elif kd1 > kd2:
702+
return 1
703+
704+
705+
def sort_func(kd1, kd2):
706+
_l = _cmp(kd1['type'], kd2['type'])
707+
if _l:
708+
return _l
709+
710+
if kd1['type'] == 'EC':
711+
_l = _cmp(kd1['crv'], kd2['crv'])
712+
if _l:
713+
return _l
714+
715+
_l = _cmp(kd1['type'], kd2['type'])
716+
if _l:
717+
return _l
718+
719+
_l = _cmp(kd1['use'][0], kd2['use'][0])
720+
if _l:
721+
return _l
722+
723+
try:
724+
_kid1 = kd1['kid']
725+
except KeyError:
726+
_kid1 = None
727+
728+
try:
729+
_kid2 = kd2['kid']
730+
except KeyError:
731+
_kid2 = None
732+
733+
if _kid1 and _kid2:
734+
return _cmp(_kid1, _kid2)
735+
elif _kid1:
736+
return -1
737+
elif _kid2:
738+
return 1
739+
740+
return 0
741+
742+
743+
def order_key_defs(key_def):
744+
"""
745+
746+
:param key_def:
747+
:return:
748+
"""
749+
_int = []
750+
# First make sure all defs only reference one usage
751+
for kd in key_def:
752+
if len(kd['use']) > 1:
753+
for _use in kd['use']:
754+
_kd = kd.copy()
755+
_kd['use'] = _use
756+
_int.append(_kd)
757+
else:
758+
_int.append(kd)
759+
760+
_int.sort(key=cmp_to_key(sort_func))
761+
762+
return _int
763+
764+
765+
def key_diff(key_bundle, key_defs, owner=''):
766+
"""
767+
Compares a KeyJar instance with a key specification and returns
768+
what new keys should be created and added to the key_jar and which should be
769+
removed from the key_jar.
770+
771+
:param key_jar:
772+
:param key_defs:
773+
:return:
774+
"""
775+
776+
keys = key_bundle.get()
777+
diff = {}
778+
779+
# My own sorted copy
780+
key_defs = order_key_defs(key_defs)[:]
781+
used = []
782+
783+
for key in keys:
784+
match = False
785+
for kd in key_defs:
786+
if key.use not in kd['use']:
787+
continue
788+
789+
if key.kty != kd['type']:
790+
continue
791+
792+
if key.kty == 'EC':
793+
# special test only for EC keys
794+
if key.crv != kd['crv']:
795+
continue
796+
797+
try:
798+
_kid = kd['kid']
799+
except KeyError:
800+
pass
801+
else:
802+
if key.kid != _kid:
803+
continue
804+
805+
match = True
806+
used.append(kd)
807+
key_defs.remove(kd)
808+
break
809+
810+
if not match:
811+
try:
812+
diff['del'].append(key)
813+
except KeyError:
814+
diff['del'] = [key]
815+
816+
if key_defs:
817+
_kb = build_key_bundle(key_defs)
818+
diff['add'] = _kb.keys()
819+
820+
return diff
821+
822+
823+
def update_key_bundle(key_bundle, diff):
824+
try:
825+
_add = diff['add']
826+
except KeyError:
827+
pass
828+
else:
829+
key_bundle.extend(_add)
830+
831+
try:
832+
_del = diff['del']
833+
except KeyError:
834+
pass
835+
else:
836+
_now = time.time()
837+
for k in _del:
838+
k.inactive_since = _now
839+
840+
841+
def key_rollover(kb):
842+
key_spec = []
843+
for key in kb.get():
844+
_spec = {'type': key.kty, 'use':[key.use]}
845+
if key.kid:
846+
_spec['kid'] = key.kid
847+
if key.kty == 'EC':
848+
_spec['crv'] = key.crv
849+
850+
key_spec.append(_spec)
851+
852+
diff = {'del': kb.get()}
853+
_kb = build_key_bundle(key_spec)
854+
diff['add'] = _kb.keys()
855+
856+
update_key_bundle(kb, diff)
857+
return kb

src/cryptojwt/key_jar.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
from .jws.utils import alg2keytype as jws_alg2keytype
88
from .jwe.jwe import alg2keytype as jwe_alg2keytype
99

10-
from .exception import DeSerializationNotPossible
10+
from .key_bundle import build_key_bundle, key_diff, update_key_bundle
1111
from .key_bundle import KeyBundle
12-
from .key_bundle import ec_init
13-
from .key_bundle import rsa_init
1412

1513

1614
__author__ = 'Roland Hedberg'
@@ -287,7 +285,7 @@ def __contains__(self, item):
287285
else:
288286
return False
289287

290-
def __getitem__(self, owner):
288+
def __getitem__(self, owner=''):
291289
"""
292290
Get all the key bundles that belong to an entity.
293291
@@ -651,8 +649,8 @@ def build_keyjar(key_conf, kid_template="", keyjar=None, owner=''):
651649
652650
keys = [
653651
{"type": "RSA", "key": "cp_keys/key.pem", "use": ["enc", "sig"]},
654-
{"type": "EC", "crv": "P-256", "use": ["sig"]},
655-
{"type": "EC", "crv": "P-256", "use": ["enc"]}
652+
{"type": "EC", "crv": "P-256", "use": ["sig"], "kid": "ec.1"},
653+
{"type": "EC", "crv": "P-256", "use": ["enc"], "kid": "ec.2"}
656654
]
657655
658656
Keys in this specification are:
@@ -671,6 +669,10 @@ def build_keyjar(key_conf, kid_template="", keyjar=None, owner=''):
671669
The elliptic curve that should be used. Only applies to elliptic curve
672670
keys :-)
673671
672+
kid
673+
Key ID, can only be used with one usage type is specified. If there
674+
are more the one usage type specified 'kid' will just be ignored.
675+
674676
:param key_conf: The key configuration
675677
:param kid_template: A template by which to build the key IDs. If no
676678
kid_template is given then the built-in function add_kid() will be used.
@@ -682,39 +684,7 @@ def build_keyjar(key_conf, kid_template="", keyjar=None, owner=''):
682684
if keyjar is None:
683685
keyjar = KeyJar()
684686

685-
kid = 0
686-
687-
tot_kb = KeyBundle()
688-
for spec in key_conf:
689-
typ = spec["type"].upper()
690-
691-
if typ == "RSA":
692-
if "key" in spec:
693-
error_to_catch = (OSError, IOError,
694-
DeSerializationNotPossible)
695-
try:
696-
kb = KeyBundle(source="file://%s" % spec["key"],
697-
fileformat="der",
698-
keytype=typ, keyusage=spec["use"])
699-
except error_to_catch:
700-
kb = rsa_init(spec)
701-
except Exception:
702-
raise
703-
else:
704-
kb = rsa_init(spec)
705-
elif typ == "EC":
706-
kb = ec_init(spec)
707-
else:
708-
continue
709-
710-
for k in kb.keys():
711-
if kid_template:
712-
k.kid = kid_template % kid
713-
kid += 1
714-
else:
715-
k.add_kid()
716-
717-
tot_kb.extend(kb.keys())
687+
tot_kb = build_key_bundle(key_conf, kid_template)
718688

719689
keyjar.add_kb(owner, tot_kb)
720690

@@ -804,6 +774,16 @@ def init_key_jar(public_path='', private_path='', key_defs='', owner=''):
804774
_jwks = open(private_path, 'r').read()
805775
_kj = KeyJar()
806776
_kj.import_jwks(json.loads(_jwks), owner)
777+
if key_defs:
778+
_kb = _kj.issuer_keys[owner][0]
779+
_diff = key_diff(_kb, key_defs)
780+
if _diff:
781+
update_key_bundle(_kb, _diff)
782+
_kj.issuer_keys[owner] = [_kb]
783+
jwks = _kj.export_jwks(private=True, issuer=owner)
784+
fp = open(private_path, 'w')
785+
fp.write(json.dumps(jwks))
786+
fp.close()
807787
else:
808788
_kj = build_keyjar(key_defs, owner=owner)
809789
jwks = _kj.export_jwks(private=True, issuer=owner)
@@ -824,6 +804,16 @@ def init_key_jar(public_path='', private_path='', key_defs='', owner=''):
824804
_jwks = open(public_path, 'r').read()
825805
_kj = KeyJar()
826806
_kj.import_jwks(json.loads(_jwks), owner)
807+
if key_defs:
808+
_kb = _kj.issuer_keys[owner][0]
809+
_diff = key_diff(_kb, key_defs)
810+
if _diff:
811+
update_key_bundle(_kb, _diff)
812+
_kj.issuer_keys[owner] = [_kb]
813+
jwks = _kj.export_jwks(issuer=owner)
814+
fp = open(private_path, 'w')
815+
fp.write(json.dumps(jwks))
816+
fp.close()
827817
else:
828818
_kj = build_keyjar(key_defs, owner=owner)
829819
_jwks = _kj.export_jwks(issuer=owner)

0 commit comments

Comments
 (0)