Skip to content

Commit d8ad5f7

Browse files
committed
Replace Cryptodome library with cryptography
- Use builtin provided data instead of hardcoded magic. - Make code python2 and python3 compatible. - Improve exception messages. - Fix test. Fixes #389 Signed-off-by: Ivan Kanakarakis <[email protected]>
1 parent 066a7ae commit d8ad5f7

File tree

2 files changed

+65
-39
lines changed

2 files changed

+65
-39
lines changed

src/saml2/aes.py

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22
from base64 import b64decode
33
from base64 import b64encode
44

5-
from Cryptodome import Random
6-
from Cryptodome.Cipher import AES
5+
from cryptography.hazmat.backends import default_backend
6+
from cryptography.hazmat.primitives.ciphers import Cipher
7+
from cryptography.hazmat.primitives.ciphers import algorithms
8+
from cryptography.hazmat.primitives.ciphers import modes
79

810

911
POSTFIX_MODE = {
10-
'cbc': AES.MODE_CBC,
11-
'cfb': AES.MODE_CFB,
12-
'ecb': AES.MODE_CFB,
12+
'cbc': modes.CBC,
13+
'cfb': modes.CFB,
14+
'ecb': modes.ECB,
1315
}
1416

15-
BLOCK_SIZE = 16
17+
AES_BLOCK_SIZE = int(algorithms.AES.block_size / 8)
1618

1719

1820
class AESCipher(object):
@@ -31,29 +33,38 @@ def build_cipher(self, iv=None, alg='aes_128_cbc'):
3133
:param alg: cipher algorithm
3234
:return: A Cipher instance
3335
"""
34-
typ, bits, cmode = alg.split('_')
36+
typ, bits, cmode = alg.lower().split('_')
37+
bits = int(bits)
3538

3639
if not iv:
37-
iv = self.iv if self.iv else Random.new().read(AES.block_size)
40+
if self.iv:
41+
iv = self.iv
42+
else:
43+
iv = os.urandom(AES_BLOCK_SIZE)
3844

39-
if len(iv) != AES.block_size:
40-
raise Exception('Wrong iv size')
45+
if len(iv) != AES_BLOCK_SIZE:
46+
raise Exception('Wrong iv size: {}'.format(len(iv)))
4147

42-
if bits not in ['128', '192', '256']:
43-
raise Exception('Unsupported key length')
48+
if bits not in algorithms.AES.key_sizes:
49+
raise Exception('Unsupported key length: {}'.format(bits))
4450

45-
if len(self.key) != int(bits) >> 3:
46-
raise Exception('Wrong Key length')
51+
if len(self.key) != bits / 8:
52+
raise Exception('Wrong Key length: {}'.format(len(self.key)))
4753

4854
try:
49-
result = AES.new(self.key, POSTFIX_MODE[cmode], iv)
55+
mode = POSTFIX_MODE[cmode]
5056
except KeyError:
51-
raise Exception('Unsupported chaining mode')
52-
else:
53-
return result, iv
57+
raise Exception('Unsupported chaining mode: {}'.format(cmode))
58+
59+
cipher = Cipher(
60+
algorithms.AES(self.key),
61+
mode(iv),
62+
backend=default_backend())
63+
64+
return cipher, iv
5465

5566
def encrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7',
56-
b64enc=True, block_size=BLOCK_SIZE):
67+
b64enc=True, block_size=AES_BLOCK_SIZE):
5768
"""
5869
:param key: The encryption key
5970
:param iv: init vector
@@ -73,11 +84,12 @@ def encrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7',
7384

7485
if _block_size:
7586
plen = _block_size - (len(msg) % _block_size)
76-
c = chr(plen)
87+
c = chr(plen).encode()
7788
msg += c * plen
7889

7990
cipher, iv = self.build_cipher(iv, alg)
80-
cmsg = iv + cipher.encrypt(msg)
91+
encryptor = cipher.encryptor()
92+
cmsg = iv + encryptor.update(msg) + encryptor.finalize()
8193

8294
if b64enc:
8395
enc_msg = b64encode(cmsg)
@@ -96,26 +108,38 @@ def decrypt(self, msg, iv=None, alg='aes_128_cbc', padding='PKCS#7',
96108
"""
97109
data = b64decode(msg) if b64dec else msg
98110

99-
_iv = data[:AES.block_size]
111+
_iv = data[:AES_BLOCK_SIZE]
100112
if iv:
101113
assert iv == _iv
102114
cipher, iv = self.build_cipher(iv, alg=alg)
103-
res = cipher.decrypt(data)[AES.block_size:]
115+
decryptor = cipher.decryptor()
116+
res = decryptor.update(data)[AES_BLOCK_SIZE:] + decryptor.finalize()
104117
if padding in ['PKCS#5', 'PKCS#7']:
105-
res = res[:-ord(res[-1])]
118+
idx = bytearray(res)[-1]
119+
res = res[:-idx]
106120
return res
107121

108122

109-
if __name__ == '__main__':
110-
key_ = '1234523451234545' # 16 byte key
123+
def run_test():
124+
key = b'1234523451234545' # 16 byte key
125+
iv = os.urandom(AES_BLOCK_SIZE)
111126
# Iff padded, the message doesn't have to be multiple of 16 in length
112-
msg_ = 'ToBeOrNotTobe W.S.'
113-
aes = AESCipher(key_)
114-
iv_ = os.urandom(16)
115-
encrypted_msg = aes.encrypt(key_, msg_, iv_)
116-
txt = aes.decrypt(key_, encrypted_msg, iv_)
117-
assert txt == msg_
118-
119-
encrypted_msg = aes.encrypt(key_, msg_, 0)
120-
txt = aes.decrypt(key_, encrypted_msg, 0)
121-
assert txt == msg_
127+
original_msg = b'ToBeOrNotTobe W.S.'
128+
aes = AESCipher(key)
129+
130+
encrypted_msg = aes.encrypt(original_msg, iv)
131+
decrypted_msg = aes.decrypt(encrypted_msg, iv)
132+
assert decrypted_msg == original_msg
133+
134+
encrypted_msg = aes.encrypt(original_msg)
135+
decrypted_msg = aes.decrypt(encrypted_msg)
136+
assert decrypted_msg == original_msg
137+
138+
aes = AESCipher(key, iv)
139+
encrypted_msg = aes.encrypt(original_msg)
140+
decrypted_msg = aes.decrypt(encrypted_msg)
141+
assert decrypted_msg == original_msg
142+
143+
144+
if __name__ == '__main__':
145+
run_test()

src/saml2/authn.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __init__(self, srv, mako_template, template_lookup, pwd, return_to):
120120
self.return_to = return_to
121121
self.active = {}
122122
self.query_param = "upm_answer"
123-
self.aes = AESCipher(self.srv.symkey, srv.iv)
123+
self.aes = AESCipher(self.srv.symkey.encode(), srv.iv)
124124

125125
def __call__(self, cookie=None, policy_url=None, logo_url=None,
126126
query="", **kwargs):
@@ -171,7 +171,8 @@ def verify(self, request, **kwargs):
171171
try:
172172
self._verify(_dict["password"][0], _dict["login"][0])
173173
timestamp = str(int(time.mktime(time.gmtime())))
174-
info = self.aes.encrypt("::".join([_dict["login"][0], timestamp]))
174+
msg = "::".join([_dict["login"][0], timestamp])
175+
info = self.aes.encrypt(msg.encode())
175176
self.active[info] = timestamp
176177
cookie = make_cookie(self.cookie_name, info, self.srv.seed)
177178
return_to = create_return_url(self.return_to, _dict["query"][0],
@@ -191,7 +192,8 @@ def authenticated_as(self, cookie=None, **kwargs):
191192
info, timestamp = parse_cookie(self.cookie_name,
192193
self.srv.seed, cookie)
193194
if self.active[info] == timestamp:
194-
uid, _ts = self.aes.decrypt(info).split("::")
195+
msg = self.aes.decrypt(info).decode()
196+
uid, _ts = msg.split("::")
195197
if timestamp == _ts:
196198
return {"uid": uid}
197199
except Exception:

0 commit comments

Comments
 (0)