Skip to content

Commit 5f9b5b5

Browse files
chore: Moving Milvus client to PyMilvus (feast-dev#4907)
* chore: Moving Milvus client to PyMilvus Signed-off-by: Francisco Javier Arceo <[email protected]> * linted and switched implementation to pymilvus Signed-off-by: Francisco Javier Arceo <[email protected]> * adding updates for integration configuration Signed-off-by: Francisco Javier Arceo <[email protected]> * removing drop statement Signed-off-by: Francisco Javier Arceo <[email protected]> --------- Signed-off-by: Francisco Javier Arceo <[email protected]>
1 parent 76e1e21 commit 5f9b5b5

File tree

4 files changed

+118
-89
lines changed

4 files changed

+118
-89
lines changed

sdk/python/feast/feature_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1757,7 +1757,7 @@ def retrieve_online_documents(
17571757
query: Union[str, List[float]],
17581758
top_k: int,
17591759
features: Optional[List[str]] = None,
1760-
distance_metric: Optional[str] = None,
1760+
distance_metric: Optional[str] = "L2",
17611761
) -> OnlineResponse:
17621762
"""
17631763
Retrieves the top k closest document features. Note, embeddings are a subset of features.

sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py

Lines changed: 84 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
CollectionSchema,
88
DataType,
99
FieldSchema,
10-
connections,
10+
MilvusClient,
1111
)
12-
from pymilvus.orm.connections import Connections
1312

1413
from feast import Entity
1514
from feast.feature_view import FeatureView
@@ -85,14 +84,15 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
8584
"""
8685

8786
type: Literal["milvus"] = "milvus"
88-
8987
host: Optional[StrictStr] = "localhost"
9088
port: Optional[int] = 19530
9189
index_type: Optional[str] = "IVF_FLAT"
9290
metric_type: Optional[str] = "L2"
9391
embedding_dim: Optional[int] = 128
9492
vector_enabled: Optional[bool] = True
9593
nlist: Optional[int] = 128
94+
username: Optional[StrictStr] = ""
95+
password: Optional[StrictStr] = ""
9696

9797

9898
class MilvusOnlineStore(OnlineStore):
@@ -103,24 +103,23 @@ class MilvusOnlineStore(OnlineStore):
103103
_collections: Dictionary to cache Milvus collections.
104104
"""
105105

106-
_conn: Optional[Connections] = None
107-
_collections: Dict[str, Collection] = {}
106+
client: Optional[MilvusClient] = None
107+
_collections: Dict[str, Any] = {}
108108

109-
def _connect(self, config: RepoConfig) -> connections:
110-
if not self._conn:
111-
if not connections.has_connection("feast"):
112-
self._conn = connections.connect(
113-
alias="feast",
114-
host=config.online_store.host,
115-
port=str(config.online_store.port),
116-
)
117-
return self._conn
109+
def _connect(self, config: RepoConfig) -> MilvusClient:
110+
if not self.client:
111+
self.client = MilvusClient(
112+
url=f"{config.online_store.host}:{config.online_store.port}",
113+
token=f"{config.online_store.username}:{config.online_store.password}"
114+
if config.online_store.username and config.online_store.password
115+
else "",
116+
)
117+
return self.client
118118

119-
def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
119+
def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, Any]:
120+
self.client = self._connect(config)
120121
collection_name = _table_id(config.project, table)
121122
if collection_name not in self._collections:
122-
self._connect(config)
123-
124123
# Create a composite key by combining entity fields
125124
composite_key_name = (
126125
"_".join([field.name for field in table.entity_columns]) + "_pk"
@@ -166,23 +165,38 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
166165
schema = CollectionSchema(
167166
fields=fields, description="Feast feature view data"
168167
)
169-
collection = Collection(name=collection_name, schema=schema, using="feast")
170-
if not collection.has_index():
171-
index_params = {
172-
"index_type": config.online_store.index_type,
173-
"metric_type": config.online_store.metric_type,
174-
"params": {"nlist": config.online_store.nlist},
175-
}
176-
for vector_field in schema.fields:
177-
if vector_field.dtype in [
178-
DataType.FLOAT_VECTOR,
179-
DataType.BINARY_VECTOR,
180-
]:
181-
collection.create_index(
182-
field_name=vector_field.name, index_params=index_params
183-
)
184-
collection.load()
185-
self._collections[collection_name] = collection
168+
collection_exists = self.client.has_collection(
169+
collection_name=collection_name
170+
)
171+
if not collection_exists:
172+
self.client.create_collection(
173+
collection_name=collection_name,
174+
dimension=config.online_store.embedding_dim,
175+
schema=schema,
176+
)
177+
index_params = self.client.prepare_index_params()
178+
for vector_field in schema.fields:
179+
if vector_field.dtype in [
180+
DataType.FLOAT_VECTOR,
181+
DataType.BINARY_VECTOR,
182+
]:
183+
index_params.add_index(
184+
collection_name=collection_name,
185+
field_name=vector_field.name,
186+
metric_type=config.online_store.metric_type,
187+
index_type=config.online_store.index_type,
188+
index_name=f"vector_index_{vector_field.name}",
189+
params={"nlist": config.online_store.nlist},
190+
)
191+
self.client.create_index(
192+
collection_name=collection_name,
193+
index_params=index_params,
194+
)
195+
else:
196+
self.client.load_collection(collection_name)
197+
self._collections[collection_name] = self.client.describe_collection(
198+
collection_name
199+
)
186200
return self._collections[collection_name]
187201

188202
def online_write_batch(
@@ -199,6 +213,7 @@ def online_write_batch(
199213
],
200214
progress: Optional[Callable[[int], Any]],
201215
) -> None:
216+
self.client = self._connect(config)
202217
collection = self._get_collection(config, table)
203218
entity_batch_to_insert = []
204219
for entity_key, values_dict, timestamp, created_ts in data:
@@ -231,8 +246,9 @@ def online_write_batch(
231246
if progress:
232247
progress(1)
233248

234-
collection.insert(entity_batch_to_insert)
235-
collection.flush()
249+
self.client.insert(
250+
collection_name=collection["collection_name"], data=entity_batch_to_insert
251+
)
236252

237253
def online_read(
238254
self,
@@ -252,14 +268,14 @@ def update(
252268
entities_to_keep: Sequence[Entity],
253269
partial: bool,
254270
):
255-
self._connect(config)
271+
self.client = self._connect(config)
256272
for table in tables_to_keep:
257-
self._get_collection(config, table)
273+
self._collections = self._get_collection(config, table)
274+
258275
for table in tables_to_delete:
259276
collection_name = _table_id(config.project, table)
260-
collection = Collection(name=collection_name)
261-
if collection.exists():
262-
collection.drop()
277+
if self._collections.get(collection_name, None):
278+
self.client.drop_collection(collection_name)
263279
self._collections.pop(collection_name, None)
264280

265281
def plan(
@@ -273,12 +289,12 @@ def teardown(
273289
tables: Sequence[FeatureView],
274290
entities: Sequence[Entity],
275291
):
276-
self._connect(config)
292+
self.client = self._connect(config)
277293
for table in tables:
278-
collection = self._get_collection(config, table)
279-
if collection:
280-
collection.drop()
281-
self._collections.pop(collection.name, None)
294+
collection_name = _table_id(config.project, table)
295+
if self._collections.get(collection_name, None):
296+
self.client.drop_collection(collection_name)
297+
self._collections.pop(collection_name, None)
282298

283299
def retrieve_online_documents(
284300
self,
@@ -298,6 +314,8 @@ def retrieve_online_documents(
298314
Optional[ValueProto],
299315
]
300316
]:
317+
self.client = self._connect(config)
318+
collection_name = _table_id(config.project, table)
301319
collection = self._get_collection(config, table)
302320
if not config.online_store.vector_enabled:
303321
raise ValueError("Vector search is not enabled in the online store config")
@@ -321,42 +339,45 @@ def retrieve_online_documents(
321339
+ ["created_ts", "event_ts"]
322340
)
323341
assert all(
324-
field
342+
field in [f["name"] for f in collection["fields"]]
325343
for field in output_fields
326-
if field in [f.name for f in collection.schema.fields]
327-
), f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema"
328-
344+
), f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema"
329345
# Note we choose the first vector field as the field to search on. Not ideal but it's something.
330346
ann_search_field = None
331-
for field in collection.schema.fields:
347+
for field in collection["fields"]:
332348
if (
333-
field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
334-
and field.name in output_fields
349+
field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
350+
and field["name"] in output_fields
335351
):
336-
ann_search_field = field.name
352+
ann_search_field = field["name"]
337353
break
338354

339-
results = collection.search(
355+
self.client.load_collection(collection_name)
356+
results = self.client.search(
357+
collection_name=collection_name,
340358
data=[embedding],
341359
anns_field=ann_search_field,
342-
param=search_params,
360+
search_params=search_params,
343361
limit=top_k,
344362
output_fields=output_fields,
345-
consistency_level="Strong",
346363
)
347364

348365
result_list = []
349366
for hits in results:
350367
for hit in hits:
351368
single_record = {}
352369
for field in output_fields:
353-
single_record[field] = hit.entity.get(field)
370+
single_record[field] = hit.get("entity", {}).get(field, None)
354371

355-
entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name))
356-
embedding = hit.entity.get(ann_search_field)
372+
entity_key_bytes = bytes.fromhex(
373+
hit.get("entity", {}).get(composite_key_name, None)
374+
)
375+
embedding = hit.get("entity", {}).get(ann_search_field)
357376
serialized_embedding = _serialize_vector_to_float_list(embedding)
358-
distance = hit.distance
359-
event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6)
377+
distance = hit.get("distance", None)
378+
event_ts = datetime.fromtimestamp(
379+
hit.get("entity", {}).get("event_ts") / 1e6
380+
)
360381
prepared_result = _build_retrieve_online_document_record(
361382
entity_key_bytes,
362383
# This may have a bug
@@ -412,7 +433,7 @@ def __init__(self, host: str, port: int, name: str):
412433
self._connect()
413434

414435
def _connect(self):
415-
return connections.connect(alias="default", host=self.host, port=str(self.port))
436+
raise NotImplementedError
416437

417438
def to_infra_object_proto(self) -> InfraObjectProto:
418439
# Implement serialization if needed

sdk/python/tests/integration/feature_repos/universal/online_store/milvus.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Any, Dict
22

3-
from testcontainers.milvus import MilvusContainer
3+
import docker
4+
from testcontainers.core.container import DockerContainer
5+
from testcontainers.core.waiting_utils import wait_for_logs
46

57
from tests.integration.feature_repos.universal.online_store_creator import (
68
OnlineStoreCreator,
@@ -11,13 +13,19 @@ class MilvusOnlineStoreCreator(OnlineStoreCreator):
1113
def __init__(self, project_name: str, **kwargs):
1214
super().__init__(project_name)
1315
self.fixed_port = 19530
14-
self.container = MilvusContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(
16+
self.container = DockerContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(
1517
self.fixed_port
1618
)
19+
self.client = docker.from_env()
1720

1821
def create_online_store(self) -> Dict[str, Any]:
1922
self.container.start()
2023
# Wait for Milvus server to be ready
24+
# log_string_to_wait_for = "Ready to accept connections"
25+
log_string_to_wait_for = ""
26+
wait_for_logs(
27+
container=self.container, predicate=log_string_to_wait_for, timeout=30
28+
)
2129
host = "localhost"
2230
port = self.container.get_exposed_port(self.fixed_port)
2331
return {

sdk/python/tests/integration/online_store/test_universal_online.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -897,26 +897,26 @@ def test_retrieve_online_documents(environment, fake_document_data):
897897
).to_dict()
898898

899899

900-
# @pytest.mark.integration
901-
# @pytest.mark.universal_online_stores(only=["milvus"])
902-
# def test_retrieve_online_milvus_documents(environment, fake_document_data):
903-
# fs = environment.feature_store
904-
# df, data_source = fake_document_data
905-
# item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
906-
# fs.apply([item_embeddings_feature_view, item()])
907-
# fs.write_to_online_store("item_embeddings", df)
908-
# documents = fs.retrieve_online_documents(
909-
# feature=None,
910-
# features=[
911-
# "item_embeddings:embedding_float",
912-
# "item_embeddings:item_id",
913-
# "item_embeddings:string_feature",
914-
# ],
915-
# query=[1.0, 2.0],
916-
# top_k=2,
917-
# distance_metric="L2",
918-
# ).to_dict()
919-
# assert len(documents["embedding_float"]) == 2
920-
#
921-
# assert len(documents["item_id"]) == 2
922-
# assert documents["item_id"] == [2, 3]
900+
@pytest.mark.integration
901+
@pytest.mark.universal_online_stores(only=["milvus"])
902+
def test_retrieve_online_milvus_documents(environment, fake_document_data):
903+
fs = environment.feature_store
904+
df, data_source = fake_document_data
905+
item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
906+
fs.apply([item_embeddings_feature_view, item()])
907+
fs.write_to_online_store("item_embeddings", df)
908+
documents = fs.retrieve_online_documents(
909+
feature=None,
910+
features=[
911+
"item_embeddings:embedding_float",
912+
"item_embeddings:item_id",
913+
"item_embeddings:string_feature",
914+
],
915+
query=[1.0, 2.0],
916+
top_k=2,
917+
distance_metric="L2",
918+
).to_dict()
919+
assert len(documents["embedding_float"]) == 2
920+
921+
assert len(documents["item_id"]) == 2
922+
assert documents["item_id"] == [2, 3]

0 commit comments

Comments
 (0)