Skip to content

Commit 21860ba

Browse files
feat: change vector db call to async
2 parents fdc7a91 + f6def2e commit 21860ba

13 files changed

+2033
-1899
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@ dependencies = [
1717
"uvicorn>=0.35.0",
1818
"watchdog>=6.0.0",
1919
"weaviate-client[agents]>=4.16.5",
20-
"pymilvus>=2.5.14",
20+
"pymilvus>=2.6.0",
2121
"pyyaml>=6.0.2",
2222
"jsonschema>=4.25.0",
2323
"fastmcp>=2.11.0",
2424
"six>=1.17.0",
2525
"sentence-transformers>=2.5.1",
2626
"scikit-learn>=1.5.0",
2727
"numpy>=1.26.0",
28+
"pytest-asyncio>=1.1.0",
2829
]
2930

3031
[tool.ruff.lint]

src/db/vector_db_base.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def supported_embeddings(self) -> list[str]:
4444
pass
4545

4646
@abstractmethod
47-
def setup(
47+
async def setup(
4848
self,
4949
embedding: str = "default",
5050
collection_name: str = None,
@@ -61,7 +61,7 @@ def setup(
6161
pass
6262

6363
@abstractmethod
64-
def write_documents(
64+
async def write_documents(
6565
self,
6666
documents: list[dict[str, Any]],
6767
embedding: str = "default",
@@ -77,7 +77,7 @@ def write_documents(
7777
"""
7878
pass
7979

80-
def write_documents_to_collection(
80+
async def write_documents_to_collection(
8181
self,
8282
documents: list[dict[str, Any]],
8383
collection_name: str,
@@ -91,9 +91,9 @@ def write_documents_to_collection(
9191
For Milvus, documents may also include a 'vector' field.
9292
collection_name: Name of the collection to write to
9393
"""
94-
return self.write_documents(documents, embedding, collection_name)
94+
return await self.write_documents(documents, embedding, collection_name)
9595

96-
def write_document(
96+
async def write_document(
9797
self,
9898
document: dict[str, Any],
9999
embedding: str = "default",
@@ -107,10 +107,12 @@ def write_document(
107107
For Milvus, document may also include a 'vector' field.
108108
collection_name: Name of the collection to write to (optional)
109109
"""
110-
return self.write_documents([document], embedding, collection_name)
110+
return await self.write_documents([document], embedding, collection_name)
111111

112112
@abstractmethod
113-
def list_documents(self, limit: int = 10, offset: int = 0) -> list[dict[str, Any]]:
113+
async def list_documents(
114+
self, limit: int = 10, offset: int = 0
115+
) -> list[dict[str, Any]]:
114116
"""
115117
List documents from the vector database.
116118
@@ -123,7 +125,7 @@ def list_documents(self, limit: int = 10, offset: int = 0) -> list[dict[str, Any
123125
"""
124126
pass
125127

126-
def list_documents_in_collection(
128+
async def list_documents_in_collection(
127129
self, collection_name: str, limit: int = 10, offset: int = 0
128130
) -> list[dict[str, Any]]:
129131
"""
@@ -141,12 +143,12 @@ def list_documents_in_collection(
141143
original_collection = self.collection_name
142144
self.collection_name = collection_name
143145
try:
144-
return self.list_documents(limit, offset)
146+
return await self.list_documents(limit, offset)
145147
finally:
146148
self.collection_name = original_collection
147149

148150
@abstractmethod
149-
def get_document(
151+
async def get_document(
150152
self, doc_name: str, collection_name: str = None
151153
) -> dict[str, Any]:
152154
"""
@@ -165,7 +167,7 @@ def get_document(
165167
pass
166168

167169
@abstractmethod
168-
def count_documents(self) -> int:
170+
async def count_documents(self) -> int:
169171
"""
170172
Get the current count of documents in the collection.
171173
@@ -174,7 +176,7 @@ def count_documents(self) -> int:
174176
"""
175177
pass
176178

177-
def count_documents_in_collection(self, collection_name: str) -> int:
179+
async def count_documents_in_collection(self, collection_name: str) -> int:
178180
"""
179181
Get the current count of documents in a specific collection.
180182
@@ -188,12 +190,12 @@ def count_documents_in_collection(self, collection_name: str) -> int:
188190
original_collection = self.collection_name
189191
self.collection_name = collection_name
190192
try:
191-
return self.count_documents()
193+
return await self.count_documents()
192194
finally:
193195
self.collection_name = original_collection
194196

195197
@abstractmethod
196-
def list_collections(self) -> list[str]:
198+
async def list_collections(self) -> list[str]:
197199
"""
198200
List all collections in the vector database.
199201
@@ -203,7 +205,7 @@ def list_collections(self) -> list[str]:
203205
pass
204206

205207
@abstractmethod
206-
def get_collection_info(self, collection_name: str = None) -> dict[str, Any]:
208+
async def get_collection_info(self, collection_name: str = None) -> dict[str, Any]:
207209
"""
208210
Get detailed information about a collection.
209211
@@ -221,7 +223,7 @@ def get_collection_info(self, collection_name: str = None) -> dict[str, Any]:
221223
pass
222224

223225
@abstractmethod
224-
def delete_documents(self, document_ids: list[str]) -> None:
226+
async def delete_documents(self, document_ids: list[str]) -> None:
225227
"""
226228
Delete documents from the vector database by their IDs.
227229
@@ -230,17 +232,17 @@ def delete_documents(self, document_ids: list[str]) -> None:
230232
"""
231233
pass
232234

233-
def delete_document(self, document_id: str) -> None:
235+
async def delete_document(self, document_id: str) -> None:
234236
"""
235237
Delete a single document from the vector database by its ID.
236238
237239
Args:
238240
document_id: Document ID to delete
239241
"""
240-
return self.delete_documents([document_id])
242+
return await self.delete_documents([document_id])
241243

242244
@abstractmethod
243-
def delete_collection(self, collection_name: str = None) -> None:
245+
async def delete_collection(self, collection_name: str = None) -> None:
244246
"""
245247
Delete an entire collection from the database.
246248
@@ -256,7 +258,9 @@ def create_query_agent(self) -> "VectorDatabase":
256258
pass
257259

258260
@abstractmethod
259-
def query(self, query: str, limit: int = 5, collection_name: str = None) -> str:
261+
async def query(
262+
self, query: str, limit: int = 5, collection_name: str = None
263+
) -> str:
260264
"""
261265
Query the vector database using the default query agent.
262266
@@ -271,7 +275,7 @@ def query(self, query: str, limit: int = 5, collection_name: str = None) -> str:
271275
pass
272276

273277
@abstractmethod
274-
def search(
278+
async def search(
275279
self, query: str, limit: int = 5, collection_name: str = None
276280
) -> list[dict]:
277281
"""
@@ -288,11 +292,11 @@ def search(
288292
pass
289293

290294
@abstractmethod
291-
def cleanup(self) -> None:
295+
async def cleanup(self) -> None:
292296
"""Clean up resources and close connections."""
293297
pass
294298

295-
def get_document_chunks(
299+
async def get_document_chunks(
296300
self, doc_id: str, collection_name: str = None
297301
) -> list[dict[str, Any]]:
298302
"""

0 commit comments

Comments
 (0)