1+ from cryptography import x509
2+ from cryptography .hazmat .primitives import serialization , hashes
3+ from cryptography .hazmat .primitives .asymmetric import padding , rsa
4+ from cryptography .hazmat .primitives .ciphers import Cipher , algorithms , modes
5+ from cryptography .hazmat .backends import default_backend
6+ import os
7+ import base64
8+ from typing import Optional
9+
10+ import pytest
11+
12+ from unstructured_client import UnstructuredClient
13+
14+ @pytest .fixture
15+ def rsa_key_pair ():
16+ private_key = rsa .generate_private_key (
17+ public_exponent = 65537 ,
18+ key_size = 2048 ,
19+ backend = default_backend ()
20+ )
21+ public_key = private_key .public_key ()
22+
23+ private_key_pem = private_key .private_bytes (
24+ encoding = serialization .Encoding .PEM ,
25+ format = serialization .PrivateFormat .TraditionalOpenSSL ,
26+ encryption_algorithm = serialization .NoEncryption ()
27+ ).decode ('utf-8' )
28+
29+ public_key_pem = public_key .public_bytes (
30+ encoding = serialization .Encoding .PEM ,
31+ format = serialization .PublicFormat .SubjectPublicKeyInfo
32+ ).decode ('utf-8' )
33+
34+ return private_key_pem , public_key_pem
35+
36+
37+ def decrypt_secret (
38+ private_key_pem : str ,
39+ encrypted_value : str ,
40+ type : str ,
41+ encrypted_aes_key : str ,
42+ aes_iv : str ,
43+ ) -> str :
44+ private_key = serialization .load_pem_private_key (
45+ private_key_pem .encode ('utf-8' ),
46+ password = None ,
47+ backend = default_backend ()
48+ )
49+
50+ if type == 'rsa' :
51+ ciphertext = base64 .b64decode (encrypted_value )
52+ plaintext = private_key .decrypt (
53+ ciphertext ,
54+ padding .OAEP (
55+ mgf = padding .MGF1 (algorithm = hashes .SHA256 ()),
56+ algorithm = hashes .SHA256 (),
57+ label = None
58+ )
59+ )
60+ return plaintext .decode ('utf-8' )
61+ else :
62+ encrypted_aes_key = base64 .b64decode (encrypted_aes_key )
63+ iv = base64 .b64decode (aes_iv )
64+ ciphertext = base64 .b64decode (encrypted_value )
65+
66+ aes_key = private_key .decrypt (
67+ encrypted_aes_key ,
68+ padding .OAEP (
69+ mgf = padding .MGF1 (algorithm = hashes .SHA256 ()),
70+ algorithm = hashes .SHA256 (),
71+ label = None
72+ )
73+ )
74+ cipher = Cipher (
75+ algorithms .AES (aes_key ),
76+ modes .CFB (iv ),
77+ )
78+ decryptor = cipher .decryptor ()
79+ plaintext = decryptor .update (ciphertext ) + decryptor .finalize ()
80+ return plaintext .decode ('utf-8' )
81+
82+
83+ def test_encrypt_rsa (rsa_key_pair ):
84+ private_key_pem , public_key_pem = rsa_key_pair
85+
86+ client = UnstructuredClient ()
87+
88+ plaintext = "This is a secret message."
89+
90+ secret_obj = client .users .encrypt_secret (public_key_pem , plaintext )
91+
92+ # A short payload should use direct RSA encryption
93+ assert secret_obj ["type" ] == 'rsa'
94+
95+ decrypted_text = decrypt_secret (
96+ private_key_pem ,
97+ secret_obj ["encrypted_value" ],
98+ secret_obj ["type" ],
99+ "" ,
100+ "" ,
101+ )
102+ assert decrypted_text == plaintext
103+
104+ assert True
105+
106+
107+ def test_encrypt_rsa_aes (rsa_key_pair ):
108+ private_key_pem , public_key_pem = rsa_key_pair
109+
110+ client = UnstructuredClient ()
111+
112+ plaintext = "This is a secret message." * 100
113+
114+ secret_obj = client .users .encrypt_secret (public_key_pem , plaintext )
115+
116+ # A longer payload uses hybrid RSA-AES encryption
117+ assert secret_obj ["type" ] == 'rsa_aes'
118+
119+ decrypted_text = decrypt_secret (
120+ private_key_pem ,
121+ secret_obj ["encrypted_value" ],
122+ secret_obj ["type" ],
123+ secret_obj ["encrypted_aes_key" ],
124+ secret_obj ["aes_iv" ],
125+ )
126+ assert decrypted_text == plaintext
127+
128+ assert True
0 commit comments