diff --git a/_test_unstructured_client/unit/test_encryption.py b/_test_unstructured_client/unit/test_encryption.py
new file mode 100644
index 00000000..3ec5d6dc
--- /dev/null
+++ b/_test_unstructured_client/unit/test_encryption.py
@@ -0,0 +1,121 @@
+from cryptography import x509
+from cryptography.hazmat.primitives import serialization, hashes
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.backends import default_backend
+import os
+import base64
+from typing import Optional
+
+import pytest
+
+from unstructured_client import UnstructuredClient
+
+@pytest.fixture
+def rsa_key_pair():
+    """Generate a throwaway 2048-bit RSA key pair as (private_pem, public_pem)."""
+    private_key = rsa.generate_private_key(
+        public_exponent=65537,
+        key_size=2048,
+        backend=default_backend()
+    )
+    public_key = private_key.public_key()
+
+    private_key_pem = private_key.private_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PrivateFormat.TraditionalOpenSSL,
+        encryption_algorithm=serialization.NoEncryption()
+    ).decode('utf-8')
+
+    public_key_pem = public_key.public_bytes(
+        encoding=serialization.Encoding.PEM,
+        format=serialization.PublicFormat.SubjectPublicKeyInfo
+    ).decode('utf-8')
+
+    return private_key_pem, public_key_pem
+
+def test_encrypt_rsa(rsa_key_pair):
+    private_key_pem, public_key_pem = rsa_key_pair
+
+    client = UnstructuredClient()
+
+    plaintext = "This is a secret message."
+
+    secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)
+
+    # A short payload should use direct RSA encryption
+    assert secret_obj["type"] == 'rsa'
+
+    decrypted_text = client.users.decrypt_secret(
+        private_key_pem,
+        secret_obj["encrypted_value"],
+        secret_obj["type"],
+        "",
+        "",
+    )
+    assert decrypted_text == plaintext
+
+
+def test_encrypt_rsa_aes(rsa_key_pair):
+    private_key_pem, public_key_pem = rsa_key_pair
+
+    client = UnstructuredClient()
+
+    plaintext = "This is a secret message." * 100
+
+    secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)
+
+    # A longer payload uses hybrid RSA-AES encryption
+    assert secret_obj["type"] == 'rsa_aes'
+
+    decrypted_text = client.users.decrypt_secret(
+        private_key_pem,
+        secret_obj["encrypted_value"],
+        secret_obj["type"],
+        secret_obj["encrypted_aes_key"],
+        secret_obj["aes_iv"],
+    )
+    assert decrypted_text == plaintext
+
+
+rsa_key_size_bytes = 2048 // 8
+max_payload_size = rsa_key_size_bytes - 66  # OAEP SHA256 overhead
+
+@pytest.mark.parametrize(("plaintext", "secret_type"), [
+    ("Short message", "rsa"),
+    ("A" * (max_payload_size), "rsa"),  # Just at the RSA limit
+    ("A" * (max_payload_size + 1), "rsa_aes"),  # Just over the RSA limit
+    ("A" * 500, "rsa_aes"),  # Well over the RSA limit
+    ("\u00e9" * (max_payload_size // 2), "rsa"),  # Multi-byte text at the byte limit
+    ("\u00e9" * max_payload_size, "rsa_aes"),  # Fits in characters, not in bytes
+])
+def test_encrypt_around_rsa_size_limit(rsa_key_pair, plaintext, secret_type):
+    """
+    Test that payloads around the RSA size limit choose the correct algorithm.
+    """
+    private_key_pem, public_key_pem = rsa_key_pair
+
+    client = UnstructuredClient()
+
+    secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)
+
+    assert secret_obj["type"] == secret_type
+    assert secret_obj["encrypted_value"] is not None
+
+    decrypted_text = client.users.decrypt_secret(
+        private_key_pem,
+        secret_obj["encrypted_value"],
+        secret_obj["type"],
+        secret_obj["encrypted_aes_key"],
+        secret_obj["aes_iv"],
+    )
+    assert decrypted_text == plaintext
+
+
+def test_unsupported_encryption_type(rsa_key_pair):
+    _, public_key_pem = rsa_key_pair
+
+    client = UnstructuredClient()
+
+    with pytest.raises(ValueError):
+        client.users.encrypt_secret(public_key_pem, "secret", "aes")
diff --git a/gen.yaml b/gen.yaml
index fa03dd01..dc366447 100644
--- a/gen.yaml
+++ b/gen.yaml
@@ -39,7 +39,7 @@ python:
   clientServerStatusCodesAsErrors: true
   defaultErrorName: SDKError
   description: Python Client SDK for Unstructured API
-  enableCustomCodeRegions: false
+  enableCustomCodeRegions: true
   enumFormat: enum
   fixFlags:
     responseRequiredSep2024: false
diff --git a/src/unstructured_client/users.py b/src/unstructured_client/users.py
index 1a4c5ab3..6ce52bf7 100644
--- a/src/unstructured_client/users.py
+++ b/src/unstructured_client/users.py
@@ -7,6 +7,15 @@
 from unstructured_client.models import errors, operations, shared
 from unstructured_client.types import BaseModel, OptionalNullable, UNSET
 
+# region imports
+from cryptography import x509
+from cryptography.hazmat.primitives import serialization, hashes
+from cryptography.hazmat.primitives.asymmetric import padding, rsa
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.backends import default_backend
+import os
+import base64
+# endregion imports
 
 class Users(BaseSDK):
     def retrieve(
@@ -458,3 +467,203 @@ async def store_secret_async(
             http_res_text,
             http_res,
         )
+
+    # region sdk-class-body
+    def _encrypt_rsa_aes(
+        self,
+        public_key: rsa.RSAPublicKey,
+        plaintext: str,
+    ) -> dict:
+        """
+        Encrypt plaintext with a fresh AES-256 key and wrap that key with RSA-OAEP.
+
+        Returns a dict with base64-encoded `encrypted_aes_key`, `aes_iv` and
+        `encrypted_value`, and `type` set to "rsa_aes".
+        """
+        # Generate a random AES key
+        aes_key = os.urandom(32)  # 256-bit AES key
+
+        # Generate a random IV
+        iv = os.urandom(16)
+
+        # Encrypt using AES-CFB
+        cipher = Cipher(
+            algorithms.AES(aes_key),
+            modes.CFB(iv),
+        )
+        encryptor = cipher.encryptor()
+        ciphertext = encryptor.update(plaintext.encode('utf-8')) + encryptor.finalize()
+
+        # Encrypt the AES key using the RSA public key
+        encrypted_key = public_key.encrypt(
+            aes_key,
+            padding.OAEP(
+                mgf=padding.MGF1(algorithm=hashes.SHA256()),
+                algorithm=hashes.SHA256(),
+                label=None
+            )
+        )
+
+        return {
+            'encrypted_aes_key': base64.b64encode(encrypted_key).decode('utf-8'),
+            'aes_iv': base64.b64encode(iv).decode('utf-8'),
+            'encrypted_value': base64.b64encode(ciphertext).decode('utf-8'),
+            'type': 'rsa_aes',
+        }
+
+    def _encrypt_rsa(
+        self,
+        public_key: rsa.RSAPublicKey,
+        plaintext: str,
+    ) -> dict:
+        """
+        Encrypt plaintext directly with RSA-OAEP (SHA-256).
+
+        The UTF-8 encoded plaintext must fit in a single RSA block. The AES
+        fields of the returned dict are empty strings and `type` is "rsa".
+        """
+        ciphertext = public_key.encrypt(
+            plaintext.encode('utf-8'),
+            padding.OAEP(
+                mgf=padding.MGF1(algorithm=hashes.SHA256()),
+                algorithm=hashes.SHA256(),
+                label=None
+            ),
+        )
+        return {
+            'encrypted_value': base64.b64encode(ciphertext).decode('utf-8'),
+            'type': 'rsa',
+            'encrypted_aes_key': "",
+            'aes_iv': "",
+        }
+
+    def decrypt_secret(
+        self,
+        private_key_pem: str,
+        encrypted_value: str,
+        secret_type: str,
+        encrypted_aes_key: str,
+        aes_iv: str,
+    ) -> str:
+        """
+        Decrypts a value produced by `encrypt_secret`.
+
+        Args:
+            private_key_pem (str): A PEM-encoded, unencrypted RSA private key.
+            encrypted_value (str): The base64-encoded ciphertext.
+            secret_type (str): Either "rsa" or "rsa_aes".
+            encrypted_aes_key (str): The base64-encoded wrapped AES key (unused for "rsa").
+            aes_iv (str): The base64-encoded AES IV (unused for "rsa").
+
+        Returns:
+            str: The decrypted plaintext.
+
+        Raises:
+            TypeError: If the private key is not an RSA key.
+            ValueError: If secret_type is not "rsa" or "rsa_aes".
+        """
+        if secret_type not in ("rsa", "rsa_aes"):
+            raise ValueError(f"Unsupported secret type: {secret_type!r}")
+
+        private_key = serialization.load_pem_private_key(
+            private_key_pem.encode('utf-8'),
+            password=None,
+            backend=default_backend()
+        )
+
+        if not isinstance(private_key, rsa.RSAPrivateKey):
+            raise TypeError("Private key must be a RSA private key for decryption.")
+
+        if secret_type == 'rsa':
+            ciphertext = base64.b64decode(encrypted_value)
+            plaintext = private_key.decrypt(
+                ciphertext,
+                padding.OAEP(
+                    mgf=padding.MGF1(algorithm=hashes.SHA256()),
+                    algorithm=hashes.SHA256(),
+                    label=None
+                )
+            )
+            return plaintext.decode('utf-8')
+
+        # rsa_aes: unwrap the AES key with RSA, then decrypt the value with AES-CFB
+        encrypted_aes_key_decoded = base64.b64decode(encrypted_aes_key)
+        iv = base64.b64decode(aes_iv)
+        ciphertext = base64.b64decode(encrypted_value)
+
+        aes_key = private_key.decrypt(
+            encrypted_aes_key_decoded,
+            padding.OAEP(
+                mgf=padding.MGF1(algorithm=hashes.SHA256()),
+                algorithm=hashes.SHA256(),
+                label=None
+            )
+        )
+        cipher = Cipher(
+            algorithms.AES(aes_key),
+            modes.CFB(iv),
+        )
+        decryptor = cipher.decryptor()
+        plaintext = decryptor.update(ciphertext) + decryptor.finalize()
+        return plaintext.decode('utf-8')
+
+    def encrypt_secret(
+        self,
+        encryption_cert_or_key_pem: str,
+        plaintext: str,
+        encryption_type: Optional[str] = None,
+    ) -> dict:
+        """
+        Encrypts a plaintext string for securely sending to the Unstructured API.
+
+        Args:
+            encryption_cert_or_key_pem (str): A PEM-encoded RSA public key or certificate.
+            plaintext (str): The string to encrypt.
+            encryption_type (str, optional): Encryption type, either "rsa" or "rsa_aes".
+                When omitted, it is chosen from the size of the encoded plaintext.
+
+        Returns:
+            dict: A dictionary with `encrypted_value`, `type`, `encrypted_aes_key`
+                and `aes_iv`. Binary values are base64-encoded; the AES fields are
+                empty strings for type "rsa".
+
+        Raises:
+            TypeError: If the supplied key is not an RSA public key.
+            ValueError: If encryption_type is not "rsa" or "rsa_aes".
+        """
+        # If a cert is provided, extract the public key
+        if "BEGIN CERTIFICATE" in encryption_cert_or_key_pem:
+            cert = x509.load_pem_x509_certificate(
+                encryption_cert_or_key_pem.encode('utf-8'),
+            )
+
+            public_key = cert.public_key()  # type: ignore[assignment]
+        else:
+            public_key = serialization.load_pem_public_key(
+                encryption_cert_or_key_pem.encode('utf-8'),
+                backend=default_backend()
+            )  # type: ignore[assignment]
+
+        if not isinstance(public_key, rsa.RSAPublicKey):
+            raise TypeError("Public key must be a RSA public key for encryption.")
+
+        # If the plaintext is short, use RSA directly
+        # Otherwise, use a RSA_AES envelope hybrid
+        # Use the length of the public key to determine the encryption type
+        key_size_bytes = public_key.key_size // 8
+        max_rsa_length = key_size_bytes - 66  # OAEP SHA256 overhead
+
+        if not encryption_type:
+            # RSA limits the number of encoded bytes, not characters, so measure
+            # the UTF-8 encoding: non-ASCII text is longer than len(plaintext).
+            plaintext_size = len(plaintext.encode('utf-8'))
+            encryption_type = "rsa" if plaintext_size <= max_rsa_length else "rsa_aes"
+
+        if encryption_type == "rsa":
+            return self._encrypt_rsa(public_key, plaintext)
+
+        if encryption_type == "rsa_aes":
+            return self._encrypt_rsa_aes(public_key, plaintext)
+
+        raise ValueError(f"Unsupported encryption type: {encryption_type!r}")
+    # endregion sdk-class-body