Skip to content

Commit 8ae9896

Browse files
authored
refactor: Improve query performance by adding filtered IDs (#137)
* feat: Improved query performance by adding filters * added atlas mongo function
1 parent dfdce6a commit 8ae9896

File tree

6 files changed

+29
-6
lines changed

6 files changed

+29
-6
lines changed

app/routes/document_routes.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ async def health_check():
8383
async def get_documents_by_ids(ids: list[str] = Query(...)):
8484
try:
8585
if isinstance(vector_store, AsyncPgVector):
86-
existing_ids = await vector_store.get_all_ids()
86+
existing_ids = await vector_store.get_filtered_ids(ids)
8787
documents = await vector_store.get_documents_by_ids(ids)
8888
else:
89-
existing_ids = vector_store.get_all_ids()
89+
existing_ids = vector_store.get_filtered_ids(ids)
9090
documents = vector_store.get_documents_by_ids(ids)
9191

9292
# Ensure all requested ids exist
@@ -121,10 +121,10 @@ async def get_documents_by_ids(ids: list[str] = Query(...)):
121121
async def delete_documents(document_ids: List[str] = Body(...)):
122122
try:
123123
if isinstance(vector_store, AsyncPgVector):
124-
existing_ids = await vector_store.get_all_ids()
124+
existing_ids = await vector_store.get_filtered_ids(document_ids)
125125
await vector_store.delete(ids=document_ids)
126126
else:
127-
existing_ids = vector_store.get_all_ids()
127+
existing_ids = vector_store.get_filtered_ids(document_ids)
128128
vector_store.delete(ids=document_ids)
129129

130130
if not all(id in existing_ids for id in document_ids):
@@ -456,10 +456,10 @@ async def load_document_context(id: str):
456456
ids = [id]
457457
try:
458458
if isinstance(vector_store, AsyncPgVector):
459-
existing_ids = await vector_store.get_all_ids()
459+
existing_ids = await vector_store.get_filtered_ids(ids)
460460
documents = await vector_store.get_documents_by_ids(ids)
461461
else:
462-
existing_ids = vector_store.get_all_ids()
462+
existing_ids = vector_store.get_filtered_ids(ids)
463463
documents = vector_store.get_documents_by_ids(ids)
464464

465465
# Ensure the requested id exists

app/services/vector_store/async_pg_vector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
class AsyncPgVector(ExtendedPgVector):
77
async def get_all_ids(self) -> list[str]:
88
return await run_in_executor(None, super().get_all_ids)
9+
10+
async def get_filtered_ids(self, ids: list[str]) -> list[str]:
11+
return await run_in_executor(None, super().get_filtered_ids, ids)
912

1013
async def get_documents_by_ids(self, ids: list[str]) -> list[Document]:
1114
return await run_in_executor(None, super().get_documents_by_ids, ids)

app/services/vector_store/atlas_mongo_vector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ def similarity_search_with_score_by_vector(
4444
def get_all_ids(self) -> list[str]:
4545
# Return unique file_id fields in self._collection
4646
return self._collection.distinct("file_id")
47+
48+
def get_filtered_ids(self, ids: list[str]) -> list[str]:
49+
# Return unique file_id fields filtered by the provided ids
50+
return self._collection.distinct("file_id", {"file_id": {"$in": ids}})
4751

4852
def get_documents_by_ids(self, ids: list[str]) -> list[Document]:
4953
# Return documents filtered by file_id

app/services/vector_store/extended_pg_vector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ def get_all_ids(self) -> list[str]:
1010
with Session(self._bind) as session:
1111
results = session.query(self.EmbeddingStore.custom_id).all()
1212
return [result[0] for result in results if result[0] is not None]
13+
14+
def get_filtered_ids(self, ids: list[str]) -> list[str]:
15+
with Session(self._bind) as session:
16+
query = session.query(self.EmbeddingStore.custom_id).filter(self.EmbeddingStore.custom_id.in_(ids))
17+
results = query.all()
18+
return [result[0] for result in results if result[0] is not None]
1319

1420
def get_documents_by_ids(self, ids: list[str]) -> list[Document]:
1521
with Session(self._bind) as session:

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ def dummy_post_init(self):
2626
class DummyVectorStore:
2727
def get_all_ids(self) -> list[str]:
2828
return ["testid1", "testid2"]
29+
30+
def get_filtered_ids(self, ids) -> list[str]:
31+
dummy_ids = ["testid1", "testid2"]
32+
return [id for id in dummy_ids if id in ids]
2933

3034
async def get_documents_by_ids(self, ids: list[str]) -> list[Document]:
3135
return [

tests/test_main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ async def dummy_get_all_ids():
2929
return ["testid1", "testid2"]
3030
monkeypatch.setattr(vector_store, "get_all_ids", dummy_get_all_ids)
3131

32+
# Override get_filtered_ids as an async function.
33+
async def dummy_get_filtered_ids(ids):
34+
dummy_ids = ["testid1", "testid2"]
35+
return [id for id in dummy_ids if id in ids]
36+
monkeypatch.setattr(vector_store, "get_filtered_ids", dummy_get_filtered_ids)
37+
3238
# Override get_documents_by_ids as an async function.
3339
async def dummy_get_documents_by_ids(ids):
3440
return [

0 commit comments

Comments
 (0)