1+ import http .server
2+ import json
3+ import os
4+ import ssl
5+ import tempfile
6+ import threading
7+ from collections import Counter
8+ from datetime import datetime , timedelta , timezone
19from pathlib import Path
10+ from typing import Generator
211
312import pytest
13+ from cryptography import x509
14+ from cryptography .hazmat .primitives import hashes , serialization
15+ from cryptography .hazmat .primitives .asymmetric import rsa
16+ from cryptography .x509 .oid import NameOID
417
518
619@pytest .fixture
@@ -11,3 +24,76 @@ def embedder_file() -> Path:
1124 assert embedder_file .exists ()
1225 assert embedder_file .is_file ()
1326 return embedder_file
27+
28+
29+ _EMBEDDINGS_CALLS = Counter ()
30+
31+
32+ class MockOpenAIEmbeddingsHandler (http .server .SimpleHTTPRequestHandler ):
33+ """
34+ Minimal OpenAPI Completions mock server
35+ """
36+
37+ def do_POST (self ):
38+ global _EMBEDDINGS_CALLS
39+ _EMBEDDINGS_CALLS ["POST" ] += 1
40+ self .send_response (200 )
41+ self .send_header ("Content-type" , "application/json" )
42+ self .end_headers ()
43+ body = {
44+ "data" : [{"object" : "embedding" , "embedding" : [], "index" : 0 }],
45+ "object" : "list" ,
46+ "model" : "text-embedding-ada-002" ,
47+ "usage" : {"prompt_tokens" : 1 , "total_tokens" : 2 },
48+ }
49+ self .wfile .write (json .dumps (body ).encode ("utf-8" ))
50+
51+
52+ @pytest .fixture (scope = "module" )
53+ def mock_embeddings_server () -> Generator [tuple [int , str , Counter ], None , None ]:
54+ """
55+ Runs a dead-simple HTTPS server on a random port, in a thread, with a custom TLS certificate.
56+ """
57+
58+ private_key = rsa .generate_private_key (public_exponent = 65537 , key_size = 2048 )
59+ subject = issuer = x509 .Name ([x509 .NameAttribute (NameOID .COMMON_NAME , "localhost" )])
60+ cert = (
61+ x509 .CertificateBuilder ()
62+ .subject_name (subject )
63+ .issuer_name (issuer )
64+ .public_key (private_key .public_key ())
65+ .serial_number (x509 .random_serial_number ())
66+ .not_valid_before (datetime .now (timezone .utc ))
67+ .not_valid_after (datetime .now (timezone .utc ) + timedelta (days = 1 ))
68+ .sign (private_key , hashes .SHA256 ())
69+ )
70+
71+ with (
72+ tempfile .NamedTemporaryFile (delete = False , suffix = ".pem" ) as cert_file ,
73+ tempfile .NamedTemporaryFile (delete = False , suffix = ".pem" ) as key_file ,
74+ ):
75+ cert_file .write (cert .public_bytes (serialization .Encoding .PEM ))
76+ key_file .write (
77+ private_key .private_bytes (
78+ encoding = serialization .Encoding .PEM ,
79+ format = serialization .PrivateFormat .PKCS8 ,
80+ encryption_algorithm = serialization .NoEncryption (),
81+ )
82+ )
83+ cert_fpath = cert_file .name
84+ privkey_fpath = key_file .name
85+
86+ context = ssl .SSLContext (ssl .PROTOCOL_TLS_SERVER )
87+ context .load_cert_chain (certfile = cert_fpath , keyfile = privkey_fpath , password = "" )
88+ server_address = "127.0.0.1" , 0
89+
90+ httpd = http .server .HTTPServer (server_address , MockOpenAIEmbeddingsHandler )
91+ httpd .socket = context .wrap_socket (httpd .socket , server_side = True )
92+ thread = threading .Thread (target = httpd .serve_forever )
93+ thread .daemon = True
94+ thread .start ()
95+ yield httpd .server_port , cert_fpath , _EMBEDDINGS_CALLS
96+ httpd .shutdown ()
97+ thread .join ()
98+ os .unlink (cert_fpath )
99+ os .unlink (privkey_fpath )
0 commit comments