Skip to content

Commit 3e9b0de

Browse files
committed
Add encrypt_secret helper function
1 parent dbf1c07 commit 3e9b0de

File tree

2 files changed

+127
-1
lines changed

2 files changed

+127
-1
lines changed

gen.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ python:
3939
clientServerStatusCodesAsErrors: true
4040
defaultErrorName: SDKError
4141
description: Python Client SDK for Unstructured API
42-
enableCustomCodeRegions: false
42+
enableCustomCodeRegions: true
4343
enumFormat: enum
4444
fixFlags:
4545
responseRequiredSep2024: false

src/unstructured_client/users.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@
77
from unstructured_client.models import errors, operations, shared
88
from unstructured_client.types import BaseModel, OptionalNullable, UNSET
99

10+
# region imports
11+
from cryptography import x509
12+
from cryptography.hazmat.primitives import serialization, hashes
13+
from cryptography.hazmat.primitives.serialization import load_pem_public_key
14+
from cryptography.hazmat.primitives.asymmetric import padding, rsa
15+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
16+
from cryptography.hazmat.backends import default_backend
17+
import os
18+
import base64
19+
# endregion imports
1020

1121
class Users(BaseSDK):
1222
def retrieve(
@@ -458,3 +468,119 @@ async def store_secret_async(
458468
http_res_text,
459469
http_res,
460470
)
471+
472+
# region sdk-class-body
473+
def _encrypt_rsa_aes(
474+
self,
475+
encryption_key_pem: str,
476+
plaintext: str,
477+
) -> dict:
478+
# Load public RSA key
479+
public_key = serialization.load_pem_public_key(
480+
encryption_key_pem.encode('utf-8'),
481+
backend=default_backend()
482+
)
483+
484+
# Generate a random AES key
485+
aes_key = os.urandom(32) # 256-bit AES key
486+
487+
# Generate a random IV
488+
iv = os.urandom(16)
489+
490+
# Encrypt using AES-CFB
491+
cipher = Cipher(
492+
algorithms.AES(aes_key),
493+
modes.CFB(iv),
494+
)
495+
encryptor = cipher.encryptor()
496+
ciphertext = encryptor.update(plaintext.encode('utf-8')) + encryptor.finalize()
497+
498+
# Encrypt the AES key using the RSA public key
499+
encrypted_key = public_key.encrypt(
500+
aes_key,
501+
padding.OAEP(
502+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
503+
algorithm=hashes.SHA256(),
504+
label=None
505+
)
506+
)
507+
508+
return {
509+
'encrypted_aes_key': base64.b64encode(encrypted_key).decode('utf-8'),
510+
'aes_iv': base64.b64encode(iv).decode('utf-8'),
511+
'encrypted_value': base64.b64encode(ciphertext).decode('utf-8'),
512+
'type': 'rsa_aes',
513+
}
514+
515+
def _encrypt_rsa(
516+
self,
517+
encryption_key_pem: str,
518+
plaintext: str,
519+
) -> dict:
520+
# Load public RSA key
521+
public_key = serialization.load_pem_public_key(
522+
encryption_key_pem.encode('utf-8'),
523+
backend=default_backend()
524+
)
525+
526+
ciphertext = public_key.encrypt(
527+
plaintext.encode(),
528+
padding.OAEP(
529+
mgf=padding.MGF1(algorithm=hashes.SHA256()),
530+
algorithm=hashes.SHA256(),
531+
label=None
532+
),
533+
)
534+
return {
535+
'encrypted_value': base64.b64encode(ciphertext).decode('utf-8'),
536+
'type': 'rsa',
537+
'encrypted_aes_key': "",
538+
'aes_iv': "",
539+
}
540+
541+
542+
def encrypt_secret(
543+
self,
544+
encryption_cert_or_key_pem: str,
545+
plaintext: str,
546+
type: Optional[str] = None,
547+
) -> dict:
548+
"""
549+
Encrypts a plaintext string for securely sending to the Unstructured API.
550+
551+
Args:
552+
encryption_cert_or_key_pem (str): A PEM-encoded RSA public key or certificate.
553+
plaintext (str): The string to encrypt.
554+
type (str, optional): Encryption type, either "rsa" or "rsa_aes".
555+
556+
Returns:
557+
dict: A dictionary with encrypted AES key, iv, and ciphertext (all base64-encoded).
558+
"""
559+
# If a cert is provided, extract the public key
560+
if "BEGIN CERTIFICATE" in encryption_cert_or_key_pem:
561+
cert = x509.load_pem_x509_certificate(
562+
encryption_cert_or_key_pem.encode('utf-8'),
563+
)
564+
565+
loaded_key = cert.public_key()
566+
567+
# Serialize back to PEM format for consistency
568+
public_key_pem = loaded_key.public_bytes(
569+
encoding=serialization.Encoding.PEM,
570+
format=serialization.PublicFormat.SubjectPublicKeyInfo
571+
).decode('utf-8')
572+
573+
else:
574+
public_key_pem = encryption_cert_or_key_pem
575+
576+
# If the plaintext is short, use RSA directly
577+
# Otherwise, use a RSA_AES envelope hybrid
578+
# The length of the public key is a good hueristic
579+
if not type:
580+
type = "rsa" if len(plaintext) <= len(public_key_pem) else "rsa_aes"
581+
582+
if type == "rsa":
583+
return self._encrypt_rsa(public_key_pem, plaintext)
584+
else:
585+
return self._encrypt_rsa_aes(public_key_pem, plaintext)
586+
# endregion sdk-class-body

0 commit comments

Comments
 (0)