@@ -47,6 +47,26 @@ def base64_to_long(data):
4747 return int_arr_to_long (struct .unpack ('%sB' % len (_d ), _d ))
4848
4949
50+ def get_key (algorithm ):
51+ if algorithm in ALGORITHMS .KEYS :
52+ return ALGORITHMS .KEYS [algorithm ]
53+ elif algorithm in ALGORITHMS .HMAC :
54+ return HMACKey
55+ elif algorithm in ALGORITHMS .RSA :
56+ return RSAKey
57+ elif algorithm in ALGORITHMS .EC :
58+ return ECKey
59+ return None
60+
61+
62+ def register_key (algorithm , key_class ):
63+ if not issubclass (key_class , Key ):
64+ raise TypeError ("Key class not a subclass of jwk.Key" )
65+ ALGORITHMS .KEYS [algorithm ] = key_class
66+ ALGORITHMS .SUPPORTED .add (algorithm )
67+ return True
68+
69+
5070def construct (key_data , algorithm = None ):
5171 """
5272 Construct a Key object for the given algorithm with the given
@@ -60,14 +80,10 @@ def construct(key_data, algorithm=None):
6080 if not algorithm :
6181 raise JWKError ('Unable to find a algorithm for key: %s' % key_data )
6282
63- if algorithm in ALGORITHMS .HMAC :
64- return HMACKey (key_data , algorithm )
65-
66- if algorithm in ALGORITHMS .RSA :
67- return RSAKey (key_data , algorithm )
68-
69- if algorithm in ALGORITHMS .EC :
70- return ECKey (key_data , algorithm )
83+ key_class = get_key (algorithm )
84+ if not key_class :
85+ raise JWKError ('Unable to find a algorithm for key: %s' % key_data )
86+ return key_class (key_data , algorithm )
7187
7288
7389def get_algorithm_object (algorithm ):
@@ -91,11 +107,8 @@ class Key(object):
91107 """
92108 A simple interface for implementing JWK keys.
93109 """
94- prepared_key = None
95- hash_alg = None
96-
97- def _process_jwk (self , jwk_dict ):
98- raise NotImplementedError ()
110+ def __init__ (self , key , algorithm ):
111+ pass
99112
100113 def sign (self , msg ):
101114 raise NotImplementedError ()
@@ -112,13 +125,9 @@ class HMACKey(Key):
112125 SHA256 = hashlib .sha256
113126 SHA384 = hashlib .sha384
114127 SHA512 = hashlib .sha512
115- valid_hash_algs = ALGORITHMS .HMAC
116-
117- prepared_key = None
118- hash_alg = None
119128
120129 def __init__ (self , key , algorithm ):
121- if algorithm not in self . valid_hash_algs :
130+ if algorithm not in ALGORITHMS . HMAC :
122131 raise JWKError ('hash_alg: %s is not a valid hash algorithm' % algorithm )
123132 self .hash_alg = get_algorithm_object (algorithm )
124133
@@ -174,14 +183,10 @@ class RSAKey(Key):
174183 SHA256 = Crypto .Hash .SHA256
175184 SHA384 = Crypto .Hash .SHA384
176185 SHA512 = Crypto .Hash .SHA512
177- valid_hash_algs = ALGORITHMS .RSA
178-
179- prepared_key = None
180- hash_alg = None
181186
182187 def __init__ (self , key , algorithm ):
183188
184- if algorithm not in self . valid_hash_algs :
189+ if algorithm not in ALGORITHMS . RSA :
185190 raise JWKError ('hash_alg: %s is not a valid hash algorithm' % algorithm )
186191 self .hash_alg = get_algorithm_object (algorithm )
187192
@@ -242,7 +247,7 @@ def verify(self, msg, sig):
242247 try :
243248 return PKCS1_v1_5 .new (self .prepared_key ).verify (self .hash_alg .new (msg ), sig )
244249 except Exception as e :
245- raise JWKError ( e )
250+ return False
246251
247252
248253class ECKey (Key ):
@@ -257,24 +262,19 @@ class ECKey(Key):
257262 SHA256 = hashlib .sha256
258263 SHA384 = hashlib .sha384
259264 SHA512 = hashlib .sha512
260- valid_hash_algs = ALGORITHMS .EC
261265
262- curve_map = {
266+ CURVE_MAP = {
263267 SHA256 : ecdsa .curves .NIST256p ,
264268 SHA384 : ecdsa .curves .NIST384p ,
265269 SHA512 : ecdsa .curves .NIST521p ,
266270 }
267271
268- prepared_key = None
269- hash_alg = None
270- curve = None
271-
272272 def __init__ (self , key , algorithm ):
273- if algorithm not in self . valid_hash_algs :
273+ if algorithm not in ALGORITHMS . EC :
274274 raise JWKError ('hash_alg: %s is not a valid hash algorithm' % algorithm )
275275 self .hash_alg = get_algorithm_object (algorithm )
276276
277- self .curve = self .curve_map .get (self .hash_alg )
277+ self .curve = self .CURVE_MAP .get (self .hash_alg )
278278
279279 if isinstance (key , (ecdsa .SigningKey , ecdsa .VerifyingKey )):
280280 self .prepared_key = key
0 commit comments