Skip to content

Commit b945912

Browse files
committed
Use the correct max size for RSA
1 parent eec2a77 commit b945912

File tree

2 files changed

+51
-33
lines changed

2 files changed

+51
-33
lines changed

_test_unstructured_client/unit/test_encryption.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,36 @@ def test_encrypt_rsa_aes(rsa_key_pair):
121121
secret_obj["encrypted_aes_key"],
122122
secret_obj["aes_iv"],
123123
)
124-
assert decrypted_text == plaintext
124+
assert decrypted_text == plaintext
125+
126+
127+
rsa_key_size_bytes = 2048 // 8
128+
max_payload_size = rsa_key_size_bytes - 66 # OAEP SHA256 overhead
129+
130+
@pytest.mark.parametrize(("plaintext", "secret_type"), [
131+
("Short message", "rsa"),
132+
("A" * (max_payload_size), "rsa"), # Just at the RSA limit
133+
("A" * (max_payload_size + 1), "rsa_aes"), # Just over the RSA limit
134+
("A" * 500, "rsa_aes"), # Well over the RSA limit
135+
])
136+
def test_encrypt_around_rsa_size_limit(rsa_key_pair, plaintext, secret_type):
137+
"""
138+
Test that payloads around the RSA size limit choose the correct algorithm.
139+
"""
140+
_, public_key_pem = rsa_key_pair
141+
142+
print(f"Testing plaintext of length {len(plaintext)} with expected type {secret_type}")
143+
144+
# Load the public key
145+
public_key = serialization.load_pem_public_key(
146+
public_key_pem.encode('utf-8'),
147+
backend=default_backend()
148+
)
149+
150+
client = UnstructuredClient()
151+
152+
secret_obj = client.users.encrypt_secret(public_key_pem, plaintext)
153+
154+
# Should still use direct RSA encryption
155+
assert secret_obj["type"] == secret_type
156+
assert secret_obj["encrypted_value"] is not None

src/unstructured_client/users.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -471,18 +471,9 @@ async def store_secret_async(
471471
# region sdk-class-body
472472
def _encrypt_rsa_aes(
473473
self,
474-
encryption_key_pem: str,
474+
public_key: rsa.RSAPublicKey,
475475
plaintext: str,
476476
) -> dict:
477-
# Load public RSA key
478-
public_key = serialization.load_pem_public_key(
479-
encryption_key_pem.encode('utf-8'),
480-
backend=default_backend()
481-
)
482-
483-
if not isinstance(public_key, rsa.RSAPublicKey):
484-
raise TypeError("Public key must be an RSA public key for envelope encryption.")
485-
486477
# Generate a random AES key
487478
aes_key = os.urandom(32) # 256-bit AES key
488479

@@ -516,18 +507,10 @@ def _encrypt_rsa_aes(
516507

517508
def _encrypt_rsa(
518509
self,
519-
encryption_key_pem: str,
510+
public_key: rsa.RSAPublicKey,
520511
plaintext: str,
521512
) -> dict:
522513
# Load public RSA key
523-
public_key = serialization.load_pem_public_key(
524-
encryption_key_pem.encode('utf-8'),
525-
backend=default_backend()
526-
)
527-
528-
if not isinstance(public_key, rsa.RSAPublicKey):
529-
raise TypeError("Public key must be an RSA public key for encryption.")
530-
531514
ciphertext = public_key.encrypt(
532515
plaintext.encode(),
533516
padding.OAEP(
@@ -567,25 +550,28 @@ def encrypt_secret(
567550
encryption_cert_or_key_pem.encode('utf-8'),
568551
)
569552

570-
loaded_key = cert.public_key()
571-
572-
# Serialize back to PEM format for consistency
573-
public_key_pem = loaded_key.public_bytes(
574-
encoding=serialization.Encoding.PEM,
575-
format=serialization.PublicFormat.SubjectPublicKeyInfo
576-
).decode('utf-8')
577-
553+
public_key = cert.public_key()
578554
else:
579-
public_key_pem = encryption_cert_or_key_pem
555+
public_key = serialization.load_pem_public_key(
556+
encryption_cert_or_key_pem.encode('utf-8'),
557+
backend=default_backend()
558+
)
559+
560+
if not isinstance(public_key, rsa.RSAPublicKey):
561+
raise TypeError("Public key must be a RSA public key for encryption.")
580562

581563
# If the plaintext is short, use RSA directly
582564
# Otherwise, use a RSA_AES envelope hybrid
583-
# The length of the public key is a good hueristic
565+
# Use the length of the public key to determine the encryption type
566+
key_size_bytes = public_key.key_size // 8
567+
max_rsa_length = key_size_bytes - 66 # OAEP SHA256 overhead
568+
print(max_rsa_length)
569+
584570
if not encryption_type:
585-
encryption_type = "rsa" if len(plaintext) <= len(public_key_pem) else "rsa_aes"
571+
encryption_type = "rsa" if len(plaintext) <= max_rsa_length else "rsa_aes"
586572

587573
if encryption_type == "rsa":
588-
return self._encrypt_rsa(public_key_pem, plaintext)
574+
return self._encrypt_rsa(public_key, plaintext)
589575

590-
return self._encrypt_rsa_aes(public_key_pem, plaintext)
576+
return self._encrypt_rsa_aes(public_key, plaintext)
591577
# endregion sdk-class-body

0 commit comments

Comments
 (0)