Skip to content

Commit 3c527df

Browse files
Allow setting REQUESTS_CA_BUNDLE to use custom certs for OpenAI embedders (#561)
If your embedding models are hosted within an enterprise/private PKI environment, you may need to trust a custom certificate bundle when interacting with them. This change allows setting REQUESTS_CA_BUNDLE to a file which will be used by the HTTP client in OpenAI and Azure OpenAI embedding interactions. --------- Co-authored-by: ryannikolaidis <[email protected]>
1 parent 7c204cf commit 3c527df

File tree

9 files changed

+179
-12
lines changed

9 files changed

+179
-12
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 1.1.0
2+
3+
- **Feature**: Embedding with OpenAI (or Azure OpenAI) can trust a custom certificate authority by setting the environment variable `REQUESTS_CA_BUNDLE`.
4+
15
## 1.0.59
26

37
* **o11y: Downgrade OTEL logs to `DEBUG` by default, make it configurable**

requirements/base.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pydantic>=2.7
44
tqdm
55
click
66
opentelemetry-sdk
7+
certifi>=2025.7.14

test/integration/embedders/conftest.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
1+
import http.server
2+
import json
3+
import os
4+
import ssl
5+
import tempfile
6+
import threading
7+
from collections import Counter
8+
from datetime import datetime, timedelta, timezone
19
from pathlib import Path
10+
from typing import Generator
211

312
import pytest
13+
from cryptography import x509
14+
from cryptography.hazmat.primitives import hashes, serialization
15+
from cryptography.hazmat.primitives.asymmetric import rsa
16+
from cryptography.x509.oid import NameOID
417

518

619
@pytest.fixture
@@ -11,3 +24,76 @@ def embedder_file() -> Path:
1124
assert embedder_file.exists()
1225
assert embedder_file.is_file()
1326
return embedder_file
27+
28+
29+
_EMBEDDINGS_CALLS = Counter()
30+
31+
32+
class MockOpenAIEmbeddingsHandler(http.server.SimpleHTTPRequestHandler):
    """
    Minimal mock of the OpenAI embeddings endpoint.

    Every POST is tallied in the module-level ``_EMBEDDINGS_CALLS`` counter and answered
    with a fixed embeddings-list payload, regardless of the request path or body.
    """

    def do_POST(self):
        # Item assignment on the shared Counter does not rebind the name, so no
        # ``global`` declaration is needed.
        _EMBEDDINGS_CALLS["POST"] += 1

        # Drain the request body so the connection is not closed with unread data
        # still pending, which a client can observe as a connection reset.
        request_length = int(self.headers.get("Content-Length") or 0)
        if request_length > 0:
            self.rfile.read(request_length)

        body = {
            "data": [{"object": "embedding", "embedding": [], "index": 0}],
            "object": "list",
            "model": "text-embedding-ada-002",
            "usage": {"prompt_tokens": 1, "total_tokens": 2},
        }
        payload = json.dumps(body).encode("utf-8")

        self.send_response(200)
        self.send_header("Content-type", "application/json")
        # Tell the client exactly how many bytes to expect instead of relying on
        # connection close to delimit the response.
        self.send_header("Content-Length", str(len(payload)))
        self.end_headers()
        self.wfile.write(payload)
50+
51+
52+
@pytest.fixture(scope="module")
def mock_embeddings_server() -> Generator[tuple[int, str, Counter], None, None]:
    """
    Runs a dead-simple HTTPS server on a random port, in a thread, with a custom TLS certificate.

    Yields:
        A ``(port, cert_path, counter)`` tuple: the port the server is listening on, the
        path of the PEM file holding the self-signed certificate, and the shared counter
        that the request handler increments on every POST.

    Teardown always stops the server, closes its listening socket and removes the
    temporary certificate/key files, even if setup or a dependent test fails.
    """
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "localhost")])
    now = datetime.now(timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(private_key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + timedelta(days=1))
        .sign(private_key, hashes.SHA256())
    )

    # delete=False: the files must outlive the ``with`` block so the SSL context (and the
    # tests, via REQUESTS_CA_BUNDLE) can read them by path; they are unlinked in ``finally``.
    with (
        tempfile.NamedTemporaryFile(delete=False, suffix=".pem") as cert_file,
        tempfile.NamedTemporaryFile(delete=False, suffix=".pem") as key_file,
    ):
        cert_fpath = cert_file.name
        privkey_fpath = key_file.name
        cert_file.write(cert.public_bytes(serialization.Encoding.PEM))
        key_file.write(
            private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )
        )

    try:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        context.load_cert_chain(certfile=cert_fpath, keyfile=privkey_fpath, password="")
        # Port 0 lets the OS pick a free port; the real one is read from ``server_port``.
        server_address = "127.0.0.1", 0

        httpd = http.server.HTTPServer(server_address, MockOpenAIEmbeddingsHandler)
        try:
            httpd.socket = context.wrap_socket(httpd.socket, server_side=True)
            thread = threading.Thread(target=httpd.serve_forever, daemon=True)
            thread.start()
            try:
                yield httpd.server_port, cert_fpath, _EMBEDDINGS_CALLS
            finally:
                # shutdown() is only safe once serve_forever() is running, hence it
                # lives in the innermost ``finally`` after the thread was started.
                httpd.shutdown()
                thread.join()
        finally:
            # Release the listening socket; shutdown() alone does not close it.
            httpd.server_close()
    finally:
        os.unlink(cert_fpath)
        os.unlink(privkey_fpath)

test/integration/embedders/test_azure_openai.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from dataclasses import dataclass
44
from pathlib import Path
55

6+
import pydantic
7+
import pytest
8+
69
from test.integration.embedders.utils import validate_embedding_output, validate_raw_embedder
710
from test.integration.utils import requires_env
811
from unstructured_ingest.embed.azure_openai import (
@@ -57,3 +60,43 @@ def test_raw_azure_openai_embedder(embedder_file: Path):
5760
)
5861
embedder.precheck()
5962
validate_raw_embedder(embedder=embedder, embedder_file=embedder_file, expected_dimension=1536)
63+
64+
65+
def test_openai_custom_tls_no_override_should_fail(mock_embeddings_server, embedder_file: Path):
    """
    With no CA override configured, embedding against the mock HTTPS server must raise
    ``APIConnectionError`` and must not register any POST on the server side.
    """
    from openai import APIConnectionError

    port, _certificate_path, counter = mock_embeddings_server
    calls_before = counter["POST"]
    endpoint = f"https://localhost:{port}"

    # Construction, precheck and run all stay inside the context manager so that
    # whichever step first touches the network is the one allowed to raise.
    with pytest.raises(APIConnectionError):
        embedder = Embedder(
            config=EmbedderConfig(
                embedding_provider="azure-openai",
                embedding_api_key=pydantic.SecretStr("foo"),
                embedding_azure_endpoint=endpoint,
            )
        )
        embedder.precheck()
        embedder.run(elements_filepath=embedder_file)

    assert counter["POST"] == calls_before, (
        f"Expected to see no change to POST calls toward embedder, got {counter} != {calls_before}"
    )
83+
84+
85+
def test_openai_custom_tls_with_override_should_succeed(
    mock_embeddings_server, monkeypatch, embedder_file: Path
):
    """
    With REQUESTS_CA_BUNDLE pointing at the mock server's certificate, embedding must
    return results and the server must see at least one additional POST.
    """
    port, certificate_path, counter = mock_embeddings_server
    calls_before = counter["POST"]
    endpoint = f"https://localhost:{port}"

    # monkeypatch restores the environment after the test.
    monkeypatch.setenv("REQUESTS_CA_BUNDLE", certificate_path)

    embedder = Embedder(
        config=EmbedderConfig(
            embedding_provider="azure-openai",
            embedding_api_key=pydantic.SecretStr("foo"),
            embedding_azure_endpoint=endpoint,
        )
    )
    embedder.precheck()
    results = embedder.run(elements_filepath=embedder_file)

    assert results
    assert counter["POST"] > calls_before, (
        f"Expected to see more POST calls to embedder, got {counter} from {calls_before}"
    )

unstructured_ingest/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.59" # pragma: no cover
1+
__version__ = "1.1.0" # pragma: no cover

unstructured_ingest/embed/azure_openai.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
OpenAIEmbeddingEncoder,
1010
)
1111
from unstructured_ingest.utils.dep_check import requires_dependencies
12+
from unstructured_ingest.utils.tls import ssl_context_with_optional_ca_override
1213

1314
if TYPE_CHECKING:
1415
from openai import AsyncAzureOpenAI, AzureOpenAI
@@ -23,19 +24,23 @@ class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
2324

2425
@requires_dependencies(["openai"], extras="openai")
2526
def get_client(self) -> "AzureOpenAI":
26-
from openai import AzureOpenAI
27+
from openai import AzureOpenAI, DefaultHttpxClient
2728

29+
client = DefaultHttpxClient(verify=ssl_context_with_optional_ca_override())
2830
return AzureOpenAI(
31+
http_client=client,
2932
api_key=self.api_key.get_secret_value(),
3033
api_version=self.api_version,
3134
azure_endpoint=self.azure_endpoint,
3235
)
3336

3437
@requires_dependencies(["openai"], extras="openai")
3538
def get_async_client(self) -> "AsyncAzureOpenAI":
36-
from openai import AsyncAzureOpenAI
39+
from openai import AsyncAzureOpenAI, DefaultAsyncHttpxClient
3740

41+
client = DefaultAsyncHttpxClient(verify=ssl_context_with_optional_ca_override())
3842
return AsyncAzureOpenAI(
43+
http_client=client,
3944
api_key=self.api_key.get_secret_value(),
4045
api_version=self.api_version,
4146
azure_endpoint=self.azure_endpoint,

unstructured_ingest/embed/openai.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from unstructured_ingest.logger import logger
2020
from unstructured_ingest.utils.dep_check import requires_dependencies
21+
from unstructured_ingest.utils.tls import ssl_context_with_optional_ca_override
2122

2223
if TYPE_CHECKING:
2324
from openai import AsyncOpenAI, OpenAI
@@ -86,15 +87,21 @@ def run_precheck(self) -> None:
8687

8788
@requires_dependencies(["openai"], extras="openai")
8889
def get_client(self) -> "OpenAI":
89-
from openai import OpenAI
90+
from openai import DefaultHttpxClient, OpenAI
9091

91-
return OpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)
92+
client = DefaultHttpxClient(verify=ssl_context_with_optional_ca_override())
93+
return OpenAI(
94+
api_key=self.api_key.get_secret_value(), http_client=client, base_url=self.base_url
95+
)
9296

9397
@requires_dependencies(["openai"], extras="openai")
9498
def get_async_client(self) -> "AsyncOpenAI":
95-
from openai import AsyncOpenAI
99+
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
96100

97-
return AsyncOpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)
101+
client = DefaultAsyncHttpxClient(verify=ssl_context_with_optional_ca_override())
102+
return AsyncOpenAI(
103+
api_key=self.api_key.get_secret_value(), http_client=client, base_url=self.base_url
104+
)
98105

99106

100107
@dataclass

unstructured_ingest/utils/tls.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
import ssl
3+
4+
import certifi
5+
6+
7+
def ssl_context_with_optional_ca_override():
    """
    Build a default client SSL context, optionally trusting a custom CA bundle.

    The bundle location is read from the ``REQUESTS_CA_BUNDLE`` environment variable:

    * unset or empty -> the certifi bundle is used;
    * a directory    -> it is passed as ``capath``;
    * anything else  -> it is passed as ``cafile``.

    https://www.python-httpx.org/advanced/ssl/#working-with-ssl_cert_file-and-ssl_cert_dir
    We choose REQUESTS_CA_BUNDLE because that works with many other Python packages.

    Returns:
        ssl.SSLContext: context suitable for the ``verify=`` argument of an HTTP client.
    """
    ca_override = os.environ.get("REQUESTS_CA_BUNDLE")
    if not ca_override:
        # Treat an empty value like an unset one so the certifi fallback still applies.
        return ssl.create_default_context(cafile=certifi.where())
    if os.path.isdir(ca_override):
        # A directory cannot be loaded as ``cafile``; it must go through ``capath``.
        return ssl.create_default_context(capath=ca_override)
    return ssl.create_default_context(cafile=ca_override)

uv.lock

Lines changed: 11 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)