Commit ff50774

[Bug fix] Fix prepdocs logic in uploading docs with more than 1000 sections (#971)
* Fix batch id logic
* Add more tests
* Add embeddings test
* 100 percent coverage
* Remove print
1 parent 59b8cbd commit ff50774

File tree

2 files changed: +287 -3 lines

scripts/prepdocslib/searchmanager.py

Lines changed: 3 additions & 3 deletions
@@ -122,10 +122,10 @@ async def update_content(self, sections: List[Section]):
         section_batches = [sections[i : i + MAX_BATCH_SIZE] for i in range(0, len(sections), MAX_BATCH_SIZE)]
 
         async with self.search_info.create_search_client() as search_client:
-            for batch in section_batches:
+            for batch_index, batch in enumerate(section_batches):
                 documents = [
                     {
-                        "id": f"{section.content.filename_to_id()}-page-{i}",
+                        "id": f"{section.content.filename_to_id()}-page-{section_index + batch_index * MAX_BATCH_SIZE}",
                         "content": section.split_page.text,
                         "category": section.category,
                         "sourcepage": BlobManager.sourcepage_from_file_page(
@@ -134,7 +134,7 @@ async def update_content(self, sections: List[Section]):
                         "sourcefile": section.content.filename(),
                         **section.content.acls,
                     }
-                    for i, section in enumerate(batch)
+                    for section_index, section in enumerate(batch)
                 ]
                 if self.embeddings:
                     embeddings = await self.embeddings.create_embeddings(
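
Why the batch id logic needed fixing: the old comprehension reused the per-batch index i, which restarts at zero for every batch, so once a file produced more than MAX_BATCH_SIZE (1000) sections, the first document of each batch got the same "-page-0" style suffix, and uploads with duplicate keys overwrite one another in the search index. A minimal standalone sketch of the arithmetic (doc_id is a hypothetical helper for illustration, not the repo's API):

MAX_BATCH_SIZE = 1000

def doc_id(file_id: str, batch_index: int, section_index: int) -> str:
    # Fixed scheme: offset the per-batch index by the batch's start
    # position so ids stay unique across all batches of one file.
    return f"{file_id}-page-{section_index + batch_index * MAX_BATCH_SIZE}"

# The old scheme used the bare per-batch index, so section 0 of batch 0
# and section 0 of batch 1 would both have ended up as "...-page-0".
assert doc_id("file-foo_pdf", batch_index=0, section_index=0) == "file-foo_pdf-page-0"
assert doc_id("file-foo_pdf", batch_index=1, section_index=0) == "file-foo_pdf-page-1000"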

tests/test_searchmanager.py

Lines changed: 284 additions & 0 deletions
@@ -0,0 +1,284 @@
+import io
+
+import openai
+import pytest
+from azure.core.credentials import AzureKeyCredential
+from azure.search.documents.aio import SearchClient
+from azure.search.documents.indexes.aio import SearchIndexClient
+
+from scripts.prepdocslib.embeddings import AzureOpenAIEmbeddingService
+from scripts.prepdocslib.listfilestrategy import File
+from scripts.prepdocslib.searchmanager import SearchManager, Section
+from scripts.prepdocslib.strategy import SearchInfo
+from scripts.prepdocslib.textsplitter import SplitPage
+
+
+@pytest.fixture
+def search_info():
+    return SearchInfo(
+        endpoint="https://testsearchclient.blob.core.windows.net",
+        credential=AzureKeyCredential("test"),
+        index_name="test",
+        verbose=True,
+    )
+
+
+@pytest.mark.asyncio
+async def test_create_index_doesnt_exist_yet(monkeypatch, search_info):
+    indexes = []
+
+    async def mock_create_index(self, index):
+        indexes.append(index)
+
+    async def mock_list_index_names(self):
+        for index in []:
+            yield index
+
+    monkeypatch.setattr(SearchIndexClient, "create_index", mock_create_index)
+    monkeypatch.setattr(SearchIndexClient, "list_index_names", mock_list_index_names)
+
+    manager = SearchManager(
+        search_info,
+    )
+    await manager.create_index()
+    assert len(indexes) == 1, "It should have created one index"
+    assert indexes[0].name == "test"
+    assert len(indexes[0].fields) == 6
+
+
+@pytest.mark.asyncio
+async def test_create_index_does_exist(monkeypatch, search_info):
+    indexes = []
+
+    async def mock_create_index(self, index):
+        indexes.append(index)
+
+    async def mock_list_index_names(self):
+        yield "test"
+
+    monkeypatch.setattr(SearchIndexClient, "create_index", mock_create_index)
+    monkeypatch.setattr(SearchIndexClient, "list_index_names", mock_list_index_names)
+
+    manager = SearchManager(
+        search_info,
+    )
+    await manager.create_index()
+    assert len(indexes) == 0, "It should not have created a new index"
+
+
+@pytest.mark.asyncio
+async def test_create_index_acls(monkeypatch, search_info):
+    indexes = []
+
+    async def mock_create_index(self, index):
+        indexes.append(index)
+
+    async def mock_list_index_names(self):
+        for index in []:
+            yield index
+
+    monkeypatch.setattr(SearchIndexClient, "create_index", mock_create_index)
+    monkeypatch.setattr(SearchIndexClient, "list_index_names", mock_list_index_names)
+
+    manager = SearchManager(
+        search_info,
+        use_acls=True,
+    )
+    await manager.create_index()
+    assert len(indexes) == 1, "It should have created one index"
+    assert indexes[0].name == "test"
+    assert len(indexes[0].fields) == 8
+
+
+@pytest.mark.asyncio
+async def test_update_content(monkeypatch, search_info):
+    async def mock_upload_documents(self, documents):
+        assert len(documents) == 1
+        assert documents[0]["id"] == "file-foo_pdf-666F6F2E706466-page-0"
+        assert documents[0]["content"] == "test content"
+        assert documents[0]["category"] == "test"
+        assert documents[0]["sourcepage"] == "foo.pdf#page=1"
+        assert documents[0]["sourcefile"] == "foo.pdf"
+
+    monkeypatch.setattr(SearchClient, "upload_documents", mock_upload_documents)
+
+    manager = SearchManager(
+        search_info,
+    )
+
+    test_io = io.BytesIO(b"test content")
+    test_io.name = "test/foo.pdf"
+    file = File(test_io)
+
+    await manager.update_content(
+        [
+            Section(
+                split_page=SplitPage(
+                    page_num=0,
+                    text="test content",
+                ),
+                content=file,
+                category="test",
+            )
+        ]
+    )
+
+
+@pytest.mark.asyncio
+async def test_update_content_many(monkeypatch, search_info):
+    ids = []
+
+    async def mock_upload_documents(self, documents):
+        ids.extend([doc["id"] for doc in documents])
+
+    monkeypatch.setattr(SearchClient, "upload_documents", mock_upload_documents)
+
+    manager = SearchManager(
+        search_info,
+    )
+
+    # create 1500 sections for 500 pages
+    sections = []
+    test_io = io.BytesIO(b"test page")
+    test_io.name = "test/foo.pdf"
+    file = File(test_io)
+    for page_num in range(500):
+        for page_section_num in range(3):
+            sections.append(
+                Section(
+                    split_page=SplitPage(
+                        page_num=page_num,
+                        text=f"test section {page_section_num}",
+                    ),
+                    content=file,
+                    category="test",
+                )
+            )
+
+    await manager.update_content(sections)
+
+    assert len(ids) == 1500, "Wrong number of documents uploaded"
+    assert len(set(ids)) == 1500, "Document ids are not unique"
+
+
+@pytest.mark.asyncio
+async def test_update_content_with_embeddings(monkeypatch, search_info):
+    async def mock_create(*args, **kwargs):
+        # From https://platform.openai.com/docs/api-reference/embeddings/create
+        return {
+            "object": "list",
+            "data": [
+                {
+                    "object": "embedding",
+                    "embedding": [
+                        0.0023064255,
+                        -0.009327292,
+                        -0.0028842222,
+                    ],
+                    "index": 0,
+                }
+            ],
+            "model": "text-embedding-ada-002",
+            "usage": {"prompt_tokens": 8, "total_tokens": 8},
+        }
+
+    monkeypatch.setattr(openai.Embedding, "acreate", mock_create)
+
+    documents_uploaded = []
+
+    async def mock_upload_documents(self, documents):
+        documents_uploaded.extend(documents)
+
+    monkeypatch.setattr(SearchClient, "upload_documents", mock_upload_documents)
+
+    manager = SearchManager(
+        search_info,
+        embeddings=AzureOpenAIEmbeddingService(
+            open_ai_service="x",
+            open_ai_deployment="x",
+            open_ai_model_name="text-ada-003",
+            credential=AzureKeyCredential("test"),
+            disable_batch=True,
+        ),
+    )
+
+    test_io = io.BytesIO(b"test content")
+    test_io.name = "test/foo.pdf"
+    file = File(test_io)
+
+    await manager.update_content(
+        [
+            Section(
+                split_page=SplitPage(
+                    page_num=0,
+                    text="test content",
+                ),
+                content=file,
+                category="test",
+            )
+        ]
+    )
+
+    assert len(documents_uploaded) == 1, "It should have uploaded one document"
+    assert documents_uploaded[0]["embedding"] == [
+        0.0023064255,
+        -0.009327292,
+        -0.0028842222,
+    ]
+
+
+@pytest.mark.asyncio
+async def test_remove_content(monkeypatch, search_info):
+    class AsyncSearchResultsIterator:
+        def __init__(self):
+            self.results = [
+                {
+                    "@search.score": 1,
+                    "id": "file-foo_pdf-666F6F2E706466-page-0",
+                    "content": "test content",
+                    "category": "test",
+                    "sourcepage": "foo.pdf#page=1",
+                    "sourcefile": "foo.pdf",
+                }
+            ]
+
+        def __aiter__(self):
+            return self
+
+        async def __anext__(self):
+            if len(self.results) == 0:
+                raise StopAsyncIteration
+            return self.results.pop()
+
+        async def get_count(self):
+            return len(self.results)
+
+    search_results = AsyncSearchResultsIterator()
+
+    searched_filters = []
+
+    async def mock_search(self, *args, **kwargs):
+        self.filter = kwargs.get("filter")
+        searched_filters.append(self.filter)
+        return search_results
+
+    monkeypatch.setattr(SearchClient, "search", mock_search)
+
+    deleted_documents = []
+
+    async def mock_delete_documents(self, documents):
+        deleted_documents.extend(documents)
+        return documents
+
+    monkeypatch.setattr(SearchClient, "delete_documents", mock_delete_documents)
+
+    manager = SearchManager(
+        search_info,
+    )
+
+    await manager.remove_content("foo.pdf")
+
+    assert len(searched_filters) == 2, "It should have searched twice (with no results on second try)"
+    assert searched_filters[0] == "sourcefile eq 'foo.pdf'"
+    assert len(deleted_documents) == 1, "It should have deleted one document"
+    assert deleted_documents[0]["id"] == "file-foo_pdf-666F6F2E706466-page-0"