-
Couldn't load subscription status.
- Fork 1.2k
feat(client): migrate MilvusClient to AsyncMilvusClient #3376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
142bd24
5a6c203
e7444c1
733d0c7
5482396
295d8b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,12 +4,11 @@ | |
| # This source code is licensed under the terms described in the LICENSE file in | ||
| # the root directory of this source tree. | ||
|
|
||
| import asyncio | ||
| import os | ||
| from typing import Any | ||
|
|
||
| from numpy.typing import NDArray | ||
| from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker | ||
| from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker | ||
|
|
||
| from llama_stack.apis.common.errors import VectorStoreNotFoundError | ||
| from llama_stack.apis.files.files import Files | ||
|
|
@@ -48,28 +47,58 @@ | |
|
|
||
| class MilvusIndex(EmbeddingIndex): | ||
| def __init__( | ||
| self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None | ||
| self, | ||
| client: AsyncMilvusClient, | ||
| collection_name: str, | ||
| consistency_level="Strong", | ||
| kvstore: KVStore | None = None, | ||
| parent_adapter=None, | ||
| ): | ||
| self.client = client | ||
| self.collection_name = sanitize_collection_name(collection_name) | ||
| self.consistency_level = consistency_level | ||
| self.kvstore = kvstore | ||
| self._parent_adapter = parent_adapter | ||
|
|
||
| async def initialize(self): | ||
| # MilvusIndex does not require explicit initialization | ||
| # TODO: could move collection creation into initialization but it is not really necessary | ||
| pass | ||
|
|
||
| async def delete(self): | ||
| if await asyncio.to_thread(self.client.has_collection, self.collection_name): | ||
| await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) | ||
| try: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can you use two try and except blocks here? squashing the |
||
| collections = await self.client.list_collections() | ||
| if self.collection_name in collections: | ||
| await self.client.drop_collection(collection_name=self.collection_name) | ||
| except Exception as e: | ||
| logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}") | ||
|
|
||
| async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): | ||
| assert len(chunks) == len(embeddings), ( | ||
| f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" | ||
| ) | ||
|
|
||
| if not await asyncio.to_thread(self.client.has_collection, self.collection_name): | ||
| try: | ||
| collections = await self.client.list_collections() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. given the usage here again, it may make sense to just add a helper |
||
| collection_exists = self.collection_name in collections | ||
| except Exception as e: | ||
| logger.error(f"Failed to check collection existence: {self.collection_name} ({e})") | ||
| # If it's an event loop issue, try to recreate the client | ||
| if "attached to a different loop" in str(e): | ||
| logger.warning("Recreating client due to event loop issue") | ||
|
|
||
| if hasattr(self, "_parent_adapter"): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. why do we need this _parent_adapter? did you have issues with the event loop? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes, i had this issue the cause was that Therefore the failure, the fix is to swap the init order, which is being fixed separately in this commit There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. where is that log from? Is that from Claude or some other AI tool? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. yes, Claude |
||
| await self._parent_adapter._recreate_client() | ||
| collections = await self.client.list_collections() | ||
| collection_exists = self.collection_name in collections | ||
|
Comment on lines
+90
to
+93
|
||
| else: | ||
| # Assume collection doesn't exist if we can't check | ||
| collection_exists = False | ||
| else: | ||
| # Assume collection doesn't exist if we can't check due to other issues | ||
| collection_exists = False | ||
|
|
||
| if not collection_exists: | ||
| logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") | ||
| # Create schema for vector search | ||
| schema = self.client.create_schema() | ||
|
|
@@ -123,13 +152,16 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): | |
| ) | ||
| schema.add_function(bm25_function) | ||
|
|
||
| await asyncio.to_thread( | ||
| self.client.create_collection, | ||
| self.collection_name, | ||
| schema=schema, | ||
| index_params=index_params, | ||
| consistency_level=self.consistency_level, | ||
| ) | ||
| try: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nice thanks for adding this. |
||
| await self.client.create_collection( | ||
| self.collection_name, | ||
| schema=schema, | ||
| index_params=index_params, | ||
| consistency_level=self.consistency_level, | ||
| ) | ||
| except Exception as e: | ||
| logger.error(f"Failed to create collection {self.collection_name}: {e}") | ||
| raise e | ||
|
|
||
| data = [] | ||
| for chunk, embedding in zip(chunks, embeddings, strict=False): | ||
|
|
@@ -143,8 +175,7 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): | |
| } | ||
| ) | ||
| try: | ||
| await asyncio.to_thread( | ||
| self.client.insert, | ||
| await self.client.insert( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. probably we should improve the error in the raise |
||
| self.collection_name, | ||
| data=data, | ||
| ) | ||
|
|
@@ -153,8 +184,7 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): | |
| raise e | ||
|
|
||
| async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: | ||
| search_res = await asyncio.to_thread( | ||
| self.client.search, | ||
| search_res = await self.client.search( | ||
| collection_name=self.collection_name, | ||
| data=[embedding], | ||
| anns_field="vector", | ||
|
|
@@ -177,8 +207,7 @@ async def query_keyword( | |
| """ | ||
| try: | ||
| # Use Milvus's built-in BM25 search | ||
| search_res = await asyncio.to_thread( | ||
| self.client.search, | ||
| search_res = await self.client.search( | ||
| collection_name=self.collection_name, | ||
| data=[query_string], # Raw text query | ||
| anns_field="sparse", # Use sparse field for BM25 | ||
|
|
@@ -219,8 +248,7 @@ async def _fallback_keyword_search( | |
| Fallback to simple text search when BM25 search is not available. | ||
| """ | ||
| # Simple text search using content field | ||
| search_res = await asyncio.to_thread( | ||
| self.client.query, | ||
| search_res = await self.client.query( | ||
| collection_name=self.collection_name, | ||
| filter='content like "%{content}%"', | ||
| filter_params={"content": query_string}, | ||
mattf marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
@@ -267,8 +295,7 @@ async def query_hybrid( | |
| impact_factor = (reranker_params or {}).get("impact_factor", 60.0) | ||
| rerank = RRFRanker(impact_factor) | ||
|
|
||
| search_res = await asyncio.to_thread( | ||
| self.client.hybrid_search, | ||
| search_res = await self.client.hybrid_search( | ||
| collection_name=self.collection_name, | ||
| reqs=search_requests, | ||
| ranker=rerank, | ||
|
|
@@ -294,9 +321,7 @@ async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> No | |
| try: | ||
| # Use IN clause with square brackets and single quotes for VARCHAR field | ||
| chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) | ||
| await asyncio.to_thread( | ||
| self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]" | ||
| ) | ||
| await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]") | ||
| except Exception as e: | ||
| logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") | ||
| raise | ||
|
|
@@ -321,6 +346,15 @@ def __init__( | |
|
|
||
| async def initialize(self) -> None: | ||
| self.kvstore = await kvstore_impl(self.config.kvstore) | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. it'd be good to also add exception handling here. i think someone had an issue with the db_path at some point and we probably could've caught that bug better since it can be configured in the run.yaml There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok, could you propose the exception to add? also shall I add it here, or in a follow-up PR? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. in this PR it's fine to add! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. i'm doing a little scope creep but while we're here it's a good thing to add a little more resilience as end users will be happy :D There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. lgtm 💯 will do tmrw, i am sleeping on my desk :D |
||
| if isinstance(self.config, RemoteMilvusVectorIOConfig): | ||
| logger.info(f"Connecting to Milvus server at {self.config.uri}") | ||
| self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) | ||
| else: | ||
| logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") | ||
| uri = os.path.expanduser(self.config.db_path) | ||
| self.client = AsyncMilvusClient(uri=uri) | ||
|
|
||
| start_key = VECTOR_DBS_PREFIX | ||
| end_key = f"{VECTOR_DBS_PREFIX}\xff" | ||
| stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) | ||
|
|
@@ -334,23 +368,38 @@ async def initialize(self) -> None: | |
| collection_name=vector_db.identifier, | ||
| consistency_level=self.config.consistency_level, | ||
| kvstore=self.kvstore, | ||
| parent_adapter=self, | ||
| ), | ||
| inference_api=self.inference_api, | ||
| ) | ||
| self.cache[vector_db.identifier] = index | ||
| if isinstance(self.config, RemoteMilvusVectorIOConfig): | ||
| logger.info(f"Connecting to Milvus server at {self.config.uri}") | ||
| self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) | ||
| else: | ||
| logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") | ||
| uri = os.path.expanduser(self.config.db_path) | ||
| self.client = MilvusClient(uri=uri) | ||
|
|
||
| # Load existing OpenAI vector stores into the in-memory cache | ||
| await self.initialize_openai_vector_stores() | ||
|
|
||
| async def shutdown(self) -> None: | ||
| self.client.close() | ||
| if self.client: | ||
| await self.client.close() | ||
|
|
||
| async def _recreate_client(self) -> None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. i'm not convinced we need this and it feels like we shouldn't need it but maybe i'm missing something. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. it's part of the init order in commit I really did not have other options, all these failures happened after the migration, which actually was supposed to be straightforward :) |
||
| """Recreate the AsyncMilvusClient when event loop issues occur""" | ||
| try: | ||
| if self.client: | ||
| await self.client.close() | ||
| except Exception as e: | ||
| logger.warning(f"Error closing old client: {e}") | ||
|
|
||
| if isinstance(self.config, RemoteMilvusVectorIOConfig): | ||
| logger.info(f"Recreating connection to Milvus server at {self.config.uri}") | ||
| self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) | ||
| else: | ||
| logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}") | ||
| uri = os.path.expanduser(self.config.db_path) | ||
| self.client = AsyncMilvusClient(uri=uri) | ||
|
|
||
| for index_wrapper in self.cache.values(): | ||
| if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"): | ||
| index_wrapper.index.client = self.client | ||
|
|
||
| async def register_vector_db( | ||
| self, | ||
|
|
@@ -362,7 +411,12 @@ async def register_vector_db( | |
| consistency_level = "Strong" | ||
| index = VectorDBWithIndex( | ||
| vector_db=vector_db, | ||
| index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level), | ||
| index=MilvusIndex( | ||
| client=self.client, | ||
| collection_name=vector_db.identifier, | ||
| consistency_level=consistency_level, | ||
| parent_adapter=self, | ||
| ), | ||
| inference_api=self.inference_api, | ||
| ) | ||
|
|
||
|
|
@@ -381,7 +435,9 @@ async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWit | |
|
|
||
| index = VectorDBWithIndex( | ||
| vector_db=vector_db, | ||
| index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore), | ||
| index=MilvusIndex( | ||
| client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self | ||
| ), | ||
| inference_api=self.inference_api, | ||
| ) | ||
| self.cache[vector_db_id] = index | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
concerned about this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is part of the fix of the init order in commit
please let me know if there is a better approach :_