Skip to content

Commit 2da62c8

Browse files
Wang-Daojiyuan.wang
andauthored
milvus implement (#354)
* milvus implement * milvus implement * milvus implement --------- Co-authored-by: yuan.wang <[email protected]>
1 parent d01c8cf commit 2da62c8

File tree

2 files changed

+378
-0
lines changed

2 files changed

+378
-0
lines changed

src/memos/configs/vec_db.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ def set_default_path(self):
3939
return self
4040

4141

42+
class MilvusVecDBConfig(BaseVecDBConfig):
43+
"""Configuration for Milvus vector database."""
44+
45+
uri: str = Field(..., description="URI for Milvus connection")
46+
collection_name: list[str] = Field(..., description="Name(s) of the collection(s)")
47+
max_length: int = Field(
48+
default=65535, description="Maximum length for string fields (varChar type)"
49+
)
50+
user_name: str = Field(default="", description="User name for Milvus connection")
51+
password: str = Field(default="", description="Password for Milvus connection")
52+
53+
4254
class VectorDBConfigFactory(BaseConfig):
4355
"""Factory class for creating vector database configurations."""
4456

@@ -47,6 +59,7 @@ class VectorDBConfigFactory(BaseConfig):
4759

4860
backend_to_class: ClassVar[dict[str, Any]] = {
4961
"qdrant": QdrantVecDBConfig,
62+
"milvus": MilvusVecDBConfig,
5063
}
5164

5265
@field_validator("backend")

src/memos/vec_dbs/milvus.py

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
from typing import Any
2+
3+
from memos.configs.vec_db import MilvusVecDBConfig
4+
from memos.dependency import require_python_package
5+
from memos.log import get_logger
6+
from memos.vec_dbs.base import BaseVecDB
7+
from memos.vec_dbs.item import VecDBItem
8+
9+
10+
logger = get_logger(__name__)
11+
12+
13+
class MilvusVecDB(BaseVecDB):
14+
"""Milvus vector database implementation."""
15+
16+
@require_python_package(
17+
import_name="pymilvus",
18+
install_command="pip install -U pymilvus",
19+
install_link="https://milvus.io/docs/install-pymilvus.md",
20+
)
21+
def __init__(self, config: MilvusVecDBConfig):
22+
"""Initialize the Milvus vector database and the collection."""
23+
from pymilvus import MilvusClient
24+
self.config = config
25+
26+
# Create Milvus client
27+
self.client = MilvusClient(
28+
uri=self.config.uri, user=self.config.user_name, password=self.config.password
29+
)
30+
self.schema = self.create_schema()
31+
self.index_params = self.create_index()
32+
self.create_collection()
33+
34+
def create_schema(self):
35+
"""Create schema for the milvus collection."""
36+
from pymilvus import DataType
37+
schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True)
38+
schema.add_field(
39+
field_name="id", datatype=DataType.VARCHAR, max_length=65535, is_primary=True
40+
)
41+
schema.add_field(
42+
field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimension
43+
)
44+
schema.add_field(field_name="payload", datatype=DataType.JSON)
45+
46+
return schema
47+
48+
def create_index(self):
49+
"""Create index for the milvus collection."""
50+
index_params = self.client.prepare_index_params()
51+
index_params.add_index(
52+
field_name="vector", index_type="FLAT", metric_type=self._get_metric_type()
53+
)
54+
55+
return index_params
56+
57+
def create_collection(self) -> None:
58+
"""Create a new collection with specified parameters."""
59+
for collection_name in self.config.collection_name:
60+
if self.collection_exists(collection_name):
61+
logger.warning(f"Collection '{collection_name}' already exists. Skipping creation.")
62+
continue
63+
64+
self.client.create_collection(
65+
collection_name=collection_name,
66+
dimension=self.config.vector_dimension,
67+
metric_type=self._get_metric_type(),
68+
schema=self.schema,
69+
index_params=self.index_params,
70+
)
71+
72+
logger.info(
73+
f"Collection '{collection_name}' created with {self.config.vector_dimension} dimensions."
74+
)
75+
76+
def create_collection_by_name(self, collection_name: str) -> None:
77+
"""Create a new collection with specified parameters."""
78+
if self.collection_exists(collection_name):
79+
logger.warning(f"Collection '{collection_name}' already exists. Skipping creation.")
80+
return
81+
82+
self.client.create_collection(
83+
collection_name=collection_name,
84+
dimension=self.config.vector_dimension,
85+
metric_type=self._get_metric_type(),
86+
schema=self.schema,
87+
index_params=self.index_params,
88+
)
89+
90+
def list_collections(self) -> list[str]:
91+
"""List all collections."""
92+
return self.client.list_collections()
93+
94+
def delete_collection(self, name: str) -> None:
95+
"""Delete a collection."""
96+
self.client.drop_collection(name)
97+
98+
def collection_exists(self, name: str) -> bool:
99+
"""Check if a collection exists."""
100+
return self.client.has_collection(collection_name=name)
101+
102+
def search(
103+
self,
104+
query_vector: list[float],
105+
collection_name: str,
106+
top_k: int,
107+
filter: dict[str, Any] | None = None,
108+
) -> list[VecDBItem]:
109+
"""
110+
Search for similar items in the database.
111+
112+
Args:
113+
query_vector: Single vector to search
114+
collection_name: Name of the collection to search
115+
top_k: Number of results to return
116+
filter: Payload filters
117+
118+
Returns:
119+
List of search results with distance scores and payloads.
120+
"""
121+
# Convert filter to Milvus expression
122+
expr = self._dict_to_expr(filter) if filter else ""
123+
124+
results = self.client.search(
125+
collection_name=collection_name,
126+
data=[query_vector],
127+
limit=top_k,
128+
filter=expr,
129+
output_fields=["*"], # Return all fields
130+
)
131+
132+
items = []
133+
for hit in results[0]:
134+
entity = hit.get("entity", {})
135+
136+
items.append(
137+
VecDBItem(
138+
id=str(hit["id"]),
139+
vector=entity.get("vector"),
140+
payload=entity.get("payload", {}),
141+
score=1 - float(hit["distance"]),
142+
)
143+
)
144+
145+
logger.info(f"Milvus search completed with {len(items)} results.")
146+
return items
147+
148+
def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str:
149+
"""Convert a dictionary filter to a Milvus expression string."""
150+
if not filter_dict:
151+
return ""
152+
153+
conditions = []
154+
for field, value in filter_dict.items():
155+
# Skip None values as they cause Milvus query syntax errors
156+
if value is None:
157+
continue
158+
# For JSON fields, we need to use payload["field"] syntax
159+
elif isinstance(value, str):
160+
conditions.append(f"payload['{field}'] == '{value}'")
161+
elif isinstance(value, list) and len(value) == 0:
162+
# Skip empty lists as they cause Milvus query syntax errors
163+
continue
164+
elif isinstance(value, list) and len(value) > 0:
165+
conditions.append(f"payload['{field}'] in {value}")
166+
else:
167+
conditions.append(f"payload['{field}'] == '{value}'")
168+
return " and ".join(conditions)
169+
170+
def _get_metric_type(self) -> str:
171+
"""Get the metric type for search."""
172+
metric_map = {
173+
"cosine": "COSINE",
174+
"euclidean": "L2",
175+
"dot": "IP",
176+
}
177+
return metric_map.get(self.config.distance_metric, "L2")
178+
179+
def get_by_id(self, collection_name: str, id: str) -> VecDBItem | None:
180+
"""Get a single item by ID."""
181+
results = self.client.get(
182+
collection_name=collection_name,
183+
ids=[id],
184+
)
185+
186+
if not results:
187+
return None
188+
189+
entity = results[0]
190+
payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]}
191+
192+
return VecDBItem(
193+
id=entity["id"],
194+
vector=entity.get("vector"),
195+
payload=payload,
196+
)
197+
198+
def get_by_ids(self, collection_name: str, ids: list[str]) -> list[VecDBItem]:
199+
"""Get multiple items by their IDs."""
200+
results = self.client.get(
201+
collection_name=collection_name,
202+
ids=ids,
203+
)
204+
205+
if not results:
206+
return []
207+
208+
items = []
209+
for entity in results:
210+
payload = {k: v for k, v in entity.items() if k not in ["id", "vector", "score"]}
211+
items.append(
212+
VecDBItem(
213+
id=entity["id"],
214+
vector=entity.get("vector"),
215+
payload=payload,
216+
)
217+
)
218+
219+
return items
220+
221+
def get_by_filter(
222+
self, collection_name: str, filter: dict[str, Any], scroll_limit: int = 100
223+
) -> list[VecDBItem]:
224+
"""
225+
Retrieve all items that match the given filter criteria using query_iterator.
226+
227+
Args:
228+
filter: Payload filters to match against stored items
229+
scroll_limit: Maximum number of items to retrieve per batch (batch_size)
230+
231+
Returns:
232+
List of items including vectors and payload that match the filter
233+
"""
234+
expr = self._dict_to_expr(filter) if filter else ""
235+
all_items = []
236+
237+
# Use query_iterator for efficient pagination
238+
iterator = self.client.query_iterator(
239+
collection_name=collection_name,
240+
filter=expr,
241+
batch_size=scroll_limit,
242+
output_fields=["*"], # Include all fields including payload
243+
)
244+
245+
# Iterate through all batches
246+
try:
247+
while True:
248+
batch_results = iterator.next()
249+
250+
if not batch_results:
251+
break
252+
253+
# Convert batch results to VecDBItem objects
254+
for entity in batch_results:
255+
# Extract the actual payload from Milvus entity
256+
payload = entity.get("payload", {})
257+
all_items.append(
258+
VecDBItem(
259+
id=entity["id"],
260+
vector=entity.get("vector"),
261+
payload=payload,
262+
)
263+
)
264+
except Exception as e:
265+
logger.warning(
266+
f"Error during Milvus query iteration: {e}. Returning {len(all_items)} items found so far."
267+
)
268+
finally:
269+
# Close the iterator
270+
iterator.close()
271+
272+
logger.info(f"Milvus retrieve by filter completed with {len(all_items)} results.")
273+
return all_items
274+
275+
def get_all(self, collection_name: str, scroll_limit=100) -> list[VecDBItem]:
276+
"""Retrieve all items in the vector database."""
277+
return self.get_by_filter(collection_name, {}, scroll_limit=scroll_limit)
278+
279+
def count(self, collection_name: str, filter: dict[str, Any] | None = None) -> int:
280+
"""Count items in the database, optionally with filter."""
281+
if filter:
282+
# If there's a filter, use query method
283+
expr = self._dict_to_expr(filter) if filter else ""
284+
results = self.client.query(
285+
collection_name=collection_name,
286+
filter=expr,
287+
output_fields=["id"],
288+
)
289+
return len(results)
290+
else:
291+
# For counting all items, use get_collection_stats for accurate count
292+
stats = self.client.get_collection_stats(collection_name)
293+
# Extract row count from stats - stats is a dict, not a list
294+
return int(stats.get("row_count", 0))
295+
296+
def add(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None:
297+
"""
298+
Add data to the vector database.
299+
300+
Args:
301+
data: List of VecDBItem objects or dictionaries containing:
302+
- 'id': unique identifier
303+
- 'vector': embedding vector
304+
- 'payload': additional fields for filtering/retrieval
305+
"""
306+
entities = []
307+
for item in data:
308+
if isinstance(item, dict):
309+
item = item.copy()
310+
item = VecDBItem.from_dict(item)
311+
312+
# Prepare entity data
313+
entity = {
314+
"id": item.id,
315+
"vector": item.vector,
316+
"payload": item.payload if item.payload else {},
317+
}
318+
319+
entities.append(entity)
320+
321+
# Use upsert to be safe (insert or update)
322+
self.client.upsert(
323+
collection_name=collection_name,
324+
data=entities,
325+
)
326+
327+
def update(self, collection_name: str, id: str, data: VecDBItem | dict[str, Any]) -> None:
328+
"""Update an item in the vector database."""
329+
if isinstance(data, dict):
330+
data = data.copy()
331+
data = VecDBItem.from_dict(data)
332+
333+
# Use upsert for updates
334+
self.upsert(collection_name, [data])
335+
336+
def ensure_payload_indexes(self, fields: list[str]) -> None:
337+
"""
338+
Create payload indexes for specified fields in the collection.
339+
This is idempotent: it will skip if index already exists.
340+
341+
Args:
342+
fields (list[str]): List of field names to index (as keyword).
343+
"""
344+
# Note: Milvus doesn't have the same concept of payload indexes as Qdrant
345+
# Field indexes are created automatically for scalar fields
346+
logger.info(f"Milvus automatically indexes scalar fields: {fields}")
347+
348+
def upsert(self, collection_name: str, data: list[VecDBItem | dict[str, Any]]) -> None:
349+
"""
350+
Add or update data in the vector database.
351+
352+
If an item with the same ID exists, it will be updated.
353+
Otherwise, it will be added as a new item.
354+
"""
355+
# Reuse add method since it already uses upsert
356+
self.add(collection_name, data)
357+
358+
def delete(self, collection_name: str, ids: list[str]) -> None:
359+
"""Delete items from the vector database."""
360+
if not ids:
361+
return
362+
self.client.delete(
363+
collection_name=collection_name,
364+
ids=ids,
365+
)

0 commit comments

Comments
 (0)