Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 92 additions & 36 deletions llama_stack/providers/remote/vector_io/milvus/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import os
from typing import Any

from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker

from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
Expand Down Expand Up @@ -48,28 +47,58 @@

class MilvusIndex(EmbeddingIndex):
def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
self,
client: AsyncMilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
parent_adapter=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

concerned about this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is part of the fix of the init order in commit

please let me know if there is a better approach :_

):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
self.consistency_level = consistency_level
self.kvstore = kvstore
self._parent_adapter = parent_adapter

async def initialize(self):
# MilvusIndex does not require explicit initialization
# TODO: could move collection creation into initialization but it is not really necessary
pass

async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you use two try and except blocks here?

squashing the list_collections and drop_collection into one isn't ideal.

collections = await self.client.list_collections()
if self.collection_name in collections:
await self.client.drop_collection(collection_name=self.collection_name)
except Exception as e:
logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")

async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)

if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
try:
collections = await self.client.list_collections()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given the usage here gain, it may make sense to just add a helper _has_collection() that has the try and except block and checks the collection_name

collection_exists = self.collection_name in collections
except Exception as e:
logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
# If it's an event loop issue, try to recreate the client
if "attached to a different loop" in str(e):
logger.warning("Recreating client due to event loop issue")

if hasattr(self, "_parent_adapter"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this _parent_adapter? did you have issues with the event loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, i had this issue

 Problem: AsyncMilvusClient connections were established in one event loop but used in different event loops during test execution.

  # BROKEN - Client created in setup loop, used in test loop
  RuntimeError: Task <Task pending> got Future <Future pending> attached to a different loop

  Symptoms:
  - RuntimeError: attached to a different loop errors
  - Tests failing during both execution and teardown
  - Milvus operations hanging or crashing

the cause was that MilvusIndex which reference MilvusClient was being initialized before MilvusClient

Therefore the failure, the fix is to swap the init order, which is being fixed separately in this commit

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is that log from? Is that from Claude or some other AI tool?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, Claude

await self._parent_adapter._recreate_client()
collections = await self.client.list_collections()
collection_exists = self.collection_name in collections
Comment on lines +90 to +93
Copy link

Copilot AI Sep 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After recreating the client through the parent adapter, the self.client reference may still point to the old client. The client reference should be updated to the new client instance after recreation.

Copilot uses AI. Check for mistakes.
else:
# Assume collection doesn't exist if we can't check
collection_exists = False
else:
# Assume collection doesn't exist if we can't check due to other issues
collection_exists = False

if not collection_exists:
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
Expand Down Expand Up @@ -123,13 +152,16 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
)
schema.add_function(bm25_function)

await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice thanks for adding this.

await self.client.create_collection(
self.collection_name,
schema=schema,
index_params=index_params,
consistency_level=self.consistency_level,
)
except Exception as e:
logger.error(f"Failed to create collection {self.collection_name}: {e}")
raise e

data = []
for chunk, embedding in zip(chunks, embeddings, strict=False):
Expand All @@ -143,8 +175,7 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
}
)
try:
await asyncio.to_thread(
self.client.insert,
await self.client.insert(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably we should improve the error in the raise

self.collection_name,
data=data,
)
Expand All @@ -153,8 +184,7 @@ async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
raise e

async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
Expand All @@ -177,8 +207,7 @@ async def query_keyword(
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
Expand Down Expand Up @@ -219,8 +248,7 @@ async def _fallback_keyword_search(
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
search_res = await self.client.query(
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
Expand Down Expand Up @@ -267,8 +295,7 @@ async def query_hybrid(
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)

search_res = await asyncio.to_thread(
self.client.hybrid_search,
search_res = await self.client.hybrid_search(
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
Expand All @@ -294,9 +321,7 @@ async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> No
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
except Exception as e:
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
Expand All @@ -321,6 +346,15 @@ def __init__(

async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it'd be good to also add exception handling here. i think someone had an issue with the db_path at some point and we probably could've caught that bug better since it can be configured in the run.yaml

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, could you propose the exception to add ?

also shall I add it here, or in a follow-up PR ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this PR it's fine to add!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm doing a little scope creep but while we're here it's a good thing to add a little more resilience as end users will be happy :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm 💯

will do tmrw, i am sleeping on my desk :D

if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = AsyncMilvusClient(uri=uri)

start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
Expand All @@ -334,23 +368,38 @@ async def initialize(self) -> None:
collection_name=vector_db.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
parent_adapter=self,
),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)

# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()

async def shutdown(self) -> None:
self.client.close()
if self.client:
await self.client.close()

async def _recreate_client(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not convinced we need this and it feels like we shouldn't need it but maybe i'm missing something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its part of the init order in commit

I really did not have other options, all these failures happened after the migration, which actually was supposed to be straight forward :)

"""Recreate the AsyncMilvusClient when event loop issues occur"""
try:
if self.client:
await self.client.close()
except Exception as e:
logger.warning(f"Error closing old client: {e}")

if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = AsyncMilvusClient(uri=uri)

for index_wrapper in self.cache.values():
if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
index_wrapper.index.client = self.client

async def register_vector_db(
self,
Expand All @@ -362,7 +411,12 @@ async def register_vector_db(
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
consistency_level=consistency_level,
parent_adapter=self,
),
inference_api=self.inference_api,
)

Expand All @@ -381,7 +435,9 @@ async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWit

index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
index=MilvusIndex(
client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/providers/vector_io/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
import pytest
from chromadb import PersistentClient
from pymilvus import MilvusClient, connections
from pymilvus import AsyncMilvusClient, connections

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
Expand Down Expand Up @@ -139,7 +139,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize()
index.db_path = db_path
yield index
index.delete()
await index.delete()


@pytest.fixture
Expand Down Expand Up @@ -176,13 +176,15 @@ def milvus_vec_db_path(tmp_path_factory):

@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
client = AsyncMilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index

await client.close()


@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
Expand Down
Loading
Loading