@@ -58,32 +58,31 @@ def construct(key_data, algorithm=None):
5858 if not algorithm :
5959 raise JWKError ('Unable to find a algorithm for key: %s' % key_data )
6060
61- if algorithm == ALGORITHMS .HS256 :
62- return HMACKey (key_data , HMACKey . SHA256 )
61+ if algorithm in ALGORITHMS .HMAC :
62+ return HMACKey (key_data , algorithm )
6363
64- if algorithm == ALGORITHMS .HS384 :
65- return HMACKey (key_data , HMACKey . SHA384 )
64+ if algorithm in ALGORITHMS .RSA :
65+ return RSAKey (key_data , algorithm )
6666
67- if algorithm == ALGORITHMS .HS512 :
68- return HMACKey (key_data , HMACKey . SHA512 )
67+ if algorithm in ALGORITHMS .EC :
68+ return ECKey (key_data , algorithm )
6969
70- if algorithm == ALGORITHMS .RS256 :
71- return RSAKey (key_data , RSAKey .SHA256 )
7270
73- if algorithm == ALGORITHMS .RS384 :
74- return RSAKey (key_data , RSAKey .SHA384 )
71+ def get_algorithm_object (algorithm ):
7572
76- if algorithm == ALGORITHMS .RS512 :
77- return RSAKey (key_data , RSAKey .SHA512 )
78-
79- if algorithm == ALGORITHMS .ES256 :
80- return ECKey (key_data , ECKey .SHA256 )
81-
82- if algorithm == ALGORITHMS .ES384 :
83- return ECKey (key_data , ECKey .SHA384 )
73+ algorithms = {
74+ ALGORITHMS .HS256 : HMACKey .SHA256 ,
75+ ALGORITHMS .HS384 : HMACKey .SHA384 ,
76+ ALGORITHMS .HS512 : HMACKey .SHA512 ,
77+ ALGORITHMS .RS256 : RSAKey .SHA256 ,
78+ ALGORITHMS .RS384 : RSAKey .SHA384 ,
79+ ALGORITHMS .RS512 : RSAKey .SHA512 ,
80+ ALGORITHMS .ES256 : ECKey .SHA256 ,
81+ ALGORITHMS .ES384 : ECKey .SHA384 ,
82+ ALGORITHMS .ES512 : ECKey .SHA512 ,
83+ }
8484
85- if algorithm == ALGORITHMS .ES512 :
86- return ECKey (key_data , ECKey .SHA512 )
85+ return algorithms .get (algorithm , None )
8786
8887
8988class Key (object ):
@@ -111,15 +110,15 @@ class HMACKey(Key):
111110 SHA256 = hashlib .sha256
112111 SHA384 = hashlib .sha384
113112 SHA512 = hashlib .sha512
114- valid_hash_algs = ( SHA256 , SHA384 , SHA512 )
113+ valid_hash_algs = ALGORITHMS . HMAC
115114
116115 prepared_key = None
117116 hash_alg = None
118117
119- def __init__ (self , key , hash_alg ):
120- if hash_alg not in self .valid_hash_algs :
121- raise JWKError ('hash_alg: %s is not a valid hash algorithm' % hash_alg )
122- self .hash_alg = hash_alg
118+ def __init__ (self , key , algorithm ):
119+ if algorithm not in self .valid_hash_algs :
120+ raise JWKError ('hash_alg: %s is not a valid hash algorithm' % algorithm )
121+ self .hash_alg = get_algorithm_object ( algorithm )
123122
124123 if isinstance (key , dict ):
125124 self .prepared_key = self ._process_jwk (key )
@@ -173,16 +172,16 @@ class RSAKey(Key):
173172 SHA256 = Crypto .Hash .SHA256
174173 SHA384 = Crypto .Hash .SHA384
175174 SHA512 = Crypto .Hash .SHA512
176- valid_hash_algs = ( SHA256 , SHA384 , SHA512 )
175+ valid_hash_algs = ALGORITHMS . RSA
177176
178177 prepared_key = None
179178 hash_alg = None
180179
181- def __init__ (self , key , hash_alg ):
180+ def __init__ (self , key , algorithm ):
182181
183- if hash_alg not in self .valid_hash_algs :
184- raise JWKError ('hash_alg: %s is not a valid hash algorithm' % hash_alg )
185- self .hash_alg = hash_alg
182+ if algorithm not in self .valid_hash_algs :
183+ raise JWKError ('hash_alg: %s is not a valid hash algorithm' % algorithm )
184+ self .hash_alg = get_algorithm_object ( algorithm )
186185
187186 if isinstance (key , _RSAKey ):
188187 self .prepared_key = key
@@ -240,7 +239,7 @@ class ECKey(Key):
240239 SHA256 = hashlib .sha256
241240 SHA384 = hashlib .sha384
242241 SHA512 = hashlib .sha512
243- valid_hash_algs = ( SHA256 , SHA384 , SHA512 )
242+ valid_hash_algs = ALGORITHMS . EC
244243
245244 curve_map = {
246245 SHA256 : ecdsa .curves .NIST256p ,
@@ -252,10 +251,10 @@ class ECKey(Key):
252251 hash_alg = None
253252 curve = None
254253
255- def __init__ (self , key , hash_alg ):
256- if hash_alg not in self .valid_hash_algs :
257- raise JWKError ('hash_alg: %s is not a valid hash algorithm' % hash_alg )
258- self .hash_alg = hash_alg
254+ def __init__ (self , key , algorithm ):
255+ if algorithm not in self .valid_hash_algs :
256+ raise JWKError ('hash_alg: %s is not a valid hash algorithm' % algorithm )
257+ self .hash_alg = get_algorithm_object ( algorithm )
259258
260259 self .curve = self .curve_map .get (self .hash_alg )
261260
0 commit comments