Skip to content

Commit 56d60ec

Browse files
authored
Fix prepdocs compatibility with openai key and add test (#605)
* Remove defaults for getenv * Remove print * missing output * Add tests and fix prepdocs issue * rm unneeded print
1 parent 2f792eb commit 56d60ec

File tree

2 files changed

+81
-14
lines changed

2 files changed

+81
-14
lines changed

scripts/prepdocs.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,14 @@
3030
)
3131
from azure.storage.blob import BlobServiceClient
3232
from pypdf import PdfReader, PdfWriter
33-
from tenacity import retry, stop_after_attempt, wait_random_exponential
33+
from tenacity import (
34+
retry,
35+
retry_if_exception_type,
36+
stop_after_attempt,
37+
wait_random_exponential,
38+
)
39+
40+
args = argparse.Namespace(verbose=False)
3441

3542
MAX_SECTION_LENGTH = 1000
3643
SENTENCE_SEARCH_LIMIT = 100
@@ -225,7 +232,7 @@ def filename_to_id(filename):
225232
filename_hash = base64.b16encode(filename.encode('utf-8')).decode('ascii')
226233
return f"file-{filename_ascii}-{filename_hash}"
227234

228-
def create_sections(filename, page_map, use_vectors):
235+
def create_sections(filename, page_map, use_vectors, embedding_deployment: str = None):
229236
file_id = filename_to_id(filename)
230237
for i, (content, pagenum) in enumerate(split_text(page_map, filename)):
231238
section = {
@@ -236,16 +243,16 @@ def create_sections(filename, page_map, use_vectors):
236243
"sourcefile": filename
237244
}
238245
if use_vectors:
239-
section["embedding"] = compute_embedding(content)
246+
section["embedding"] = compute_embedding(content, embedding_deployment)
240247
yield section
241248

242249
def before_retry_sleep(retry_state):
    """tenacity ``before_sleep`` hook: in verbose mode, note that we are backing off."""
    if args.verbose:
        print("Rate limited on the OpenAI embeddings API, sleeping before retrying...")
244251

245-
@retry(retry=retry_if_exception_type(openai.error.RateLimitError), wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(15), before_sleep=before_retry_sleep)
def compute_embedding(text, embedding_deployment):
    """Return the embedding vector for `text` from the given Azure OpenAI deployment.

    Retries with random exponential backoff, but only on rate-limit errors;
    any other OpenAI error propagates to the caller immediately.
    """
    refresh_openai_token()
    embedding_response = openai.Embedding.create(engine=embedding_deployment, input=text)
    return embedding_response["data"][0]["embedding"]
249256

250257
@retry(wait=wait_random_exponential(min=15, max=60), stop=stop_after_attempt(15), before_sleep=before_retry_sleep)
251258
def compute_embedding_in_batch(texts):
@@ -314,7 +321,7 @@ def update_embeddings_in_batch(sections):
314321
if args.verbose: print(f"Batch Completed. Batch size {len(batch_queue)} Token count {token_count}")
315322
for emb, item in zip(emb_responses, batch_queue):
316323
batch_response[item["id"]] = emb
317-
324+
318325
for s in copy_s:
319326
s["embedding"] = batch_response[s["id"]]
320327
yield s
@@ -355,14 +362,18 @@ def remove_from_index(filename):
355362
# It can take a few seconds for search results to reflect changes, so wait a bit
356363
time.sleep(2)
357364

358-
# refresh open ai token every 5 minutes
365+
359366
def refresh_openai_token():
    """Refresh the cached OpenAI token when it is an Azure AD token older than 5 minutes."""
    uses_azure_ad = (
        CACHE_KEY_TOKEN_TYPE in open_ai_token_cache
        and open_ai_token_cache[CACHE_KEY_TOKEN_TYPE] == 'azure_ad'
    )
    if uses_azure_ad and open_ai_token_cache[CACHE_KEY_CREATED_TIME] + 300 < time.time():
        credential = open_ai_token_cache[CACHE_KEY_TOKEN_CRED]
        openai.api_key = credential.get_token("https://cognitiveservices.azure.com/.default").token
        open_ai_token_cache[CACHE_KEY_CREATED_TIME] = time.time()
364374

365-
def read_files(path_pattern: str, use_vectors: bool, vectors_batch_support: bool):
375+
376+
def read_files(path_pattern: str, use_vectors: bool, vectors_batch_support: bool, embedding_deployment: str = None):
366377
"""
367378
Recursively read directory structure under `path_pattern`
368379
and execute indexing for the individual files
@@ -380,8 +391,7 @@ def read_files(path_pattern: str, use_vectors: bool, vectors_batch_support: bool
380391
if not args.skipblobs:
381392
upload_blobs(filename)
382393
page_map = get_document_text(filename)
383-
sections = create_sections(os.path.basename(filename), page_map, use_vectors and not vectors_batch_support)
384-
print (use_vectors and vectors_batch_support)
394+
sections = create_sections(os.path.basename(filename), page_map, use_vectors and not vectors_batch_support, embedding_deployment)
385395
if use_vectors and vectors_batch_support:
386396
sections = update_embeddings_in_batch(sections)
387397
index_sections(os.path.basename(filename), sections)
@@ -456,4 +466,4 @@ def read_files(path_pattern: str, use_vectors: bool, vectors_batch_support: bool
456466
create_search_index()
457467

458468
print("Processing files...")
459-
read_files(args.files, use_vectors, compute_vectors_in_batch)
469+
read_files(args.files, use_vectors, compute_vectors_in_batch, args.openaideployment)

tests/test_prepdocs.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from scripts.prepdocs import filename_to_id
1+
import openai
2+
import pytest
3+
import tenacity
4+
from scripts.prepdocs import args, compute_embedding, filename_to_id
25

36

47
def test_filename_to_id():
@@ -8,3 +11,57 @@ def test_filename_to_id():
811
assert filename_to_id("foo\u00A9.txt") == "file-foo__txt-666F6FC2A92E747874"
912
# test filenaming starting with unicode
1013
assert filename_to_id("ファイル名.pdf") == "file-______pdf-E38395E382A1E382A4E383ABE5908D2E706466"
14+
15+
16+
def test_compute_embedding_success(monkeypatch):
    """compute_embedding should unwrap just the embedding vector from the API response.

    Fixes: drop the unused `capsys` fixture, and rename the mock's `*args`
    so it no longer shadows the module-level `args` namespace imported above.
    """
    monkeypatch.setattr(args, "verbose", True)

    def mock_create(*call_args, **call_kwargs):
        # Response shape from https://platform.openai.com/docs/api-reference/embeddings/create
        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [
                        0.0023064255,
                        -0.009327292,
                        -0.0028842222,
                    ],
                    "index": 0
                }
            ],
            "model": "text-embedding-ada-002",
            "usage": {
                "prompt_tokens": 8,
                "total_tokens": 8
            }
        }

    monkeypatch.setattr(openai.Embedding, "create", mock_create)
    # The function must return only the embedding list, not the full response dict.
    assert compute_embedding("foo", "ada") == [
        0.0023064255,
        -0.009327292,
        -0.0028842222,
    ]
46+
47+
48+
def test_compute_embedding_ratelimiterror(monkeypatch, capsys):
    """A persistent RateLimitError is retried until tenacity gives up with RetryError."""
    monkeypatch.setattr(args, "verbose", True)

    def mock_create(*call_args, **call_kwargs):
        raise openai.error.RateLimitError

    monkeypatch.setattr(openai.Embedding, "create", mock_create)
    # Skip the real exponential backoff so the test runs instantly.
    monkeypatch.setattr(tenacity.nap.time, "sleep", lambda x: None)

    with pytest.raises(tenacity.RetryError):
        compute_embedding("foo", "ada")

    captured = capsys.readouterr()
    # 15 attempts means 14 sleeps, each preceded by the before_sleep log line.
    assert captured.out.count("Rate limited on the OpenAI embeddings API") == 14
58+
59+
60+
def test_compute_embedding_autherror(monkeypatch):
    """Non-rate-limit errors (e.g. a bad API key) must not be retried.

    Fixes: drop the unused `capsys` fixture, and rename the mock's `*args`
    so it no longer shadows the module-level `args` namespace imported above.
    """
    monkeypatch.setattr(args, "verbose", True)

    def mock_create(*call_args, **call_kwargs):
        raise openai.error.AuthenticationError

    monkeypatch.setattr(openai.Embedding, "create", mock_create)
    # Guard against accidental sleeping if the retry predicate ever regresses.
    monkeypatch.setattr(tenacity.nap.time, "sleep", lambda x: None)

    # The error should propagate directly instead of being wrapped in a
    # tenacity.RetryError, proving only RateLimitError triggers retries.
    with pytest.raises(openai.error.AuthenticationError):
        compute_embedding("foo", "ada")

0 commit comments

Comments
 (0)