Skip to content

Commit 6009134

Browse files
committed
Add encryption tests
1 parent 3e9b0de commit 6009134

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from cryptography import x509
2+
from cryptography.hazmat.primitives import serialization, hashes
3+
from cryptography.hazmat.primitives.asymmetric import padding, rsa
4+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
5+
from cryptography.hazmat.backends import default_backend
6+
import os
7+
import base64
8+
from typing import Optional
9+
10+
import pytest
11+
12+
from unstructured_client import UnstructuredClient
13+
14+
@pytest.fixture
15+
def rsa_key_pair():
16+
private_key = rsa.generate_private_key(
17+
public_exponent=65537,
18+
key_size=2048,
19+
backend=default_backend()
20+
)
21+
public_key = private_key.public_key()
22+
23+
private_key_pem = private_key.private_bytes(
24+
encoding=serialization.Encoding.PEM,
25+
format=serialization.PrivateFormat.TraditionalOpenSSL,
26+
encryption_algorithm=serialization.NoEncryption()
27+
).decode('utf-8')
28+
29+
public_key_pem = public_key.public_bytes(
30+
encoding=serialization.Encoding.PEM,
31+
format=serialization.PublicFormat.SubjectPublicKeyInfo
32+
).decode('utf-8')
33+
34+
return private_key_pem, public_key_pem
35+
36+
37+
def decrypt_secret(
38+
private_key_pem: str,
39+
encrypted_value: str,
40+
type: str,
41+
encrypted_aes_key: str,
42+
aes_iv: str,
43+
) -> str:
44+
private_key = serialization.load_pem_private_key(
45+
private_key_pem.encode('utf-8'),
46+
password=None,
47+
backend=default_backend()
48+
)
49+
50+
if type == 'rsa':
51+
ciphertext = base64.b64decode(encrypted_value)
52+
plaintext = private_key.decrypt(
53+
ciphertext,
54+
padding.OAEP(
55+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
56+
algorithm=hashes.SHA256(),
57+
label=None
58+
)
59+
)
60+
return plaintext.decode('utf-8')
61+
else:
62+
encrypted_aes_key = base64.b64decode(encrypted_aes_key)
63+
iv = base64.b64decode(aes_iv)
64+
ciphertext = base64.b64decode(encrypted_value)
65+
66+
aes_key = private_key.decrypt(
67+
encrypted_aes_key,
68+
padding.OAEP(
69+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
70+
algorithm=hashes.SHA256(),
71+
label=None
72+
)
73+
)
74+
cipher = Cipher(
75+
algorithms.AES(aes_key),
76+
modes.CFB(iv),
77+
)
78+
decryptor = cipher.decryptor()
79+
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
80+
return plaintext.decode('utf-8')
81+
82+
83+
def test_encrypt_rsa(rsa_key_pair):
84+
private_key_pem, public_key_pem = rsa_key_pair
85+
86+
client = UnstructuredClient()
87+
88+
plaintext = "This is a secret message."
89+
90+
secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)
91+
92+
# A short payload should use direct RSA encryption
93+
assert secret_obj["type"] == 'rsa'
94+
95+
decrypted_text = decrypt_secret(
96+
private_key_pem,
97+
secret_obj["encrypted_value"],
98+
secret_obj["type"],
99+
"",
100+
"",
101+
)
102+
assert decrypted_text == plaintext
103+
104+
assert True
105+
106+
107+
def test_encrypt_rsa_aes(rsa_key_pair):
108+
private_key_pem, public_key_pem = rsa_key_pair
109+
110+
client = UnstructuredClient()
111+
112+
plaintext = "This is a secret message." * 100
113+
114+
secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)
115+
116+
# A longer payload uses hybrid RSA-AES encryption
117+
assert secret_obj["type"] == 'rsa_aes'
118+
119+
decrypted_text = decrypt_secret(
120+
private_key_pem,
121+
secret_obj["encrypted_value"],
122+
secret_obj["type"],
123+
secret_obj["encrypted_aes_key"],
124+
secret_obj["aes_iv"],
125+
)
126+
assert decrypted_text == plaintext
127+
128+
assert True

0 commit comments

Comments
 (0)