Commit e5527ad
chore: Updating tests to allow for the CLIRunner to use Milvus, also have to handle special case of not running apply and teardown (feast-dev#4915)

* chore: Updating tests to allow for the CLIRunner to use Milvus, also have to handle special case of not running apply and teardown
* Adding cleanup
* adding example repo
* changing default to FLAT for local implementation

Signed-off-by: Francisco Javier Arceo <[email protected]>

1 parent a8aeb79 · commit e5527ad

File tree: 4 files changed (+293, -33 lines)

sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py

Lines changed: 18 additions & 3 deletions

@@ -1,4 +1,5 @@
 from datetime import datetime
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
 
 from pydantic import StrictStr
@@ -84,9 +85,10 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
     """
 
     type: Literal["milvus"] = "milvus"
+    path: Optional[StrictStr] = "data/online_store.db"
     host: Optional[StrictStr] = "localhost"
     port: Optional[int] = 19530
-    index_type: Optional[str] = "IVF_FLAT"
+    index_type: Optional[str] = "FLAT"
     metric_type: Optional[str] = "L2"
     embedding_dim: Optional[int] = 128
     vector_enabled: Optional[bool] = True
@@ -106,11 +108,24 @@ class MilvusOnlineStore(OnlineStore):
     client: Optional[MilvusClient] = None
     _collections: Dict[str, Any] = {}
 
+    def _get_db_path(self, config: RepoConfig) -> str:
+        assert (
+            config.online_store.type == "milvus"
+            or config.online_store.type.endswith("MilvusOnlineStore")
+        )
+
+        if config.repo_path and not Path(config.online_store.path).is_absolute():
+            db_path = str(config.repo_path / config.online_store.path)
+        else:
+            db_path = config.online_store.path
+        return db_path
+
     def _connect(self, config: RepoConfig) -> MilvusClient:
         if not self.client:
             if config.provider == "local":
-                print("Connecting to Milvus in local mode using ./milvus_demo.db")
-                self.client = MilvusClient("./milvus_demo.db")
+                db_path = self._get_db_path(config)
+                print(f"Connecting to Milvus in local mode using {db_path}")
+                self.client = MilvusClient(db_path)
             else:
                 self.client = MilvusClient(
                     url=f"{config.online_store.host}:{config.online_store.port}",
sdk/python/tests/example_repos/example_rag_feature_repo.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+from datetime import timedelta
+
+from feast import Entity, FeatureView, Field, FileSource
+from feast.types import Array, Float32, Int64, UnixTimestamp
+
+# This is for Milvus
+# Note that file source paths are not validated, so there doesn't actually need to be any data
+# at the paths for these file sources. Since these paths are effectively fake, this example
+# feature repo should not be used for historical retrieval.
+
+rag_documents_source = FileSource(
+    path="data/embedded_documents.parquet",
+    timestamp_field="event_timestamp",
+    created_timestamp_column="created_timestamp",
+)
+
+item = Entity(
+    name="item_id",  # The name is derived from this argument, not object name.
+    join_keys=["item_id"],
+)
+
+document_embeddings = FeatureView(
+    name="embedded_documents",
+    entities=[item],
+    schema=[
+        Field(
+            name="vector",
+            dtype=Array(Float32),
+            vector_index=True,
+            vector_search_metric="L2",
+        ),
+        Field(name="item_id", dtype=Int64),
+        Field(name="created_timestamp", dtype=UnixTimestamp),
+        Field(name="event_timestamp", dtype=UnixTimestamp),
+    ],
+    source=rag_documents_source,
+    ttl=timedelta(hours=24),
+)

sdk/python/tests/unit/online_store/test_online_retrieval.py

Lines changed: 180 additions & 0 deletions

@@ -1,5 +1,6 @@
 import os
 import platform
+import random
 import sqlite3
 import sys
 import time
@@ -561,3 +562,182 @@ def test_sqlite_vec_import() -> None:
     """).fetchall()
     result = [(rowid, round(distance, 2)) for rowid, distance in result]
     assert result == [(2, 2.39), (1, 2.39)]
+
+
+def test_local_milvus() -> None:
+    import random
+
+    from pymilvus import MilvusClient
+
+    random.seed(42)
+    VECTOR_LENGTH: int = 768
+    COLLECTION_NAME: str = "test_demo_collection"
+
+    client = MilvusClient("./milvus_demo.db")
+
+    for collection in client.list_collections():
+        client.drop_collection(collection_name=collection)
+    client.create_collection(
+        collection_name=COLLECTION_NAME,
+        dimension=VECTOR_LENGTH,
+    )
+    assert client.list_collections() == [COLLECTION_NAME]
+
+    docs = [
+        "Artificial intelligence was founded as an academic discipline in 1956.",
+        "Alan Turing was the first person to conduct substantial research in AI.",
+        "Born in Maida Vale, London, Turing was raised in southern England.",
+    ]
+    # Use fake representation with random vectors (vector_length dimension).
+    vectors = [[random.uniform(-1, 1) for _ in range(VECTOR_LENGTH)] for _ in docs]
+    data = [
+        {"id": i, "vector": vectors[i], "text": docs[i], "subject": "history"}
+        for i in range(len(vectors))
+    ]
+
+    print("Data has", len(data), "entities, each with fields: ", data[0].keys())
+    print("Vector dim:", len(data[0]["vector"]))
+
+    insert_res = client.insert(collection_name=COLLECTION_NAME, data=data)
+    assert insert_res == {"insert_count": 3, "ids": [0, 1, 2], "cost": 0}
+
+    query_vectors = [[random.uniform(-1, 1) for _ in range(VECTOR_LENGTH)]]
+
+    search_res = client.search(
+        collection_name=COLLECTION_NAME,  # target collection
+        data=query_vectors,  # query vectors
+        limit=2,  # number of returned entities
+        output_fields=["text", "subject"],  # specifies fields to be returned
+    )
+    assert [j["id"] for j in search_res[0]] == [0, 1]
+    query_result = client.query(
+        collection_name=COLLECTION_NAME,
+        filter="id == 0",
+    )
+    assert list(query_result[0].keys()) == ["id", "text", "subject", "vector"]
+
+    client.drop_collection(collection_name=COLLECTION_NAME)
+
+
+def test_milvus_lite_get_online_documents() -> None:
+    """
+    Test retrieving documents from the online store in local mode.
+    """
+
+    random.seed(42)
+    n = 10  # number of samples - note: we'll actually double it
+    vector_length = 10
+    runner = CliRunner()
+    with runner.local_repo(
+        example_repo_py=get_example_repo("example_rag_feature_repo.py"),
+        offline_store="file",
+        online_store="milvus",
+        apply=False,
+        teardown=False,
+    ) as store:
+        from datetime import timedelta
+
+        from feast import Entity, FeatureView, Field, FileSource
+        from feast.types import Array, Float32, Int64, UnixTimestamp
+
+        # This is for Milvus
+        # Note that file source paths are not validated, so there doesn't actually need to be any data
+        # at the paths for these file sources. Since these paths are effectively fake, this example
+        # feature repo should not be used for historical retrieval.
+
+        rag_documents_source = FileSource(
+            path="data/embedded_documents.parquet",
+            timestamp_field="event_timestamp",
+            created_timestamp_column="created_timestamp",
+        )
+
+        item = Entity(
+            name="item_id",  # The name is derived from this argument, not object name.
+            join_keys=["item_id"],
+        )
+
+        document_embeddings = FeatureView(
+            name="embedded_documents",
+            entities=[item],
+            schema=[
+                Field(
+                    name="vector",
+                    dtype=Array(Float32),
+                    vector_index=True,
+                    vector_search_metric="L2",
+                ),
+                Field(name="item_id", dtype=Int64),
+                Field(name="created_timestamp", dtype=UnixTimestamp),
+                Field(name="event_timestamp", dtype=UnixTimestamp),
+            ],
+            source=rag_documents_source,
+            ttl=timedelta(hours=24),
+        )
+
+        store.apply([rag_documents_source, item, document_embeddings])
+
+        # Write some data to two tables
+        document_embeddings_fv = store.get_feature_view(name="embedded_documents")
+
+        provider = store._get_provider()
+
+        item_keys = [
+            EntityKeyProto(
+                join_keys=["item_id"], entity_values=[ValueProto(int64_val=i)]
+            )
+            for i in range(n)
+        ]
+        data = []
+        for item_key in item_keys:
+            data.append(
+                (
+                    item_key,
+                    {
+                        "vector": ValueProto(
+                            float_list_val=FloatListProto(
+                                val=np.random.random(
+                                    vector_length,
+                                )
+                            )
+                        )
+                    },
+                    _utc_now(),
+                    _utc_now(),
+                )
+            )
+
+        provider.online_write_batch(
+            config=store.config,
+            table=document_embeddings_fv,
+            data=data,
+            progress=None,
+        )
+        documents_df = pd.DataFrame(
+            {
+                "item_id": [str(i) for i in range(n)],
+                "vector": [
+                    np.random.random(
+                        vector_length,
+                    )
+                    for i in range(n)
+                ],
+                "event_timestamp": [_utc_now() for _ in range(n)],
+                "created_timestamp": [_utc_now() for _ in range(n)],
+            }
+        )
+
+        store.write_to_online_store(
+            feature_view_name="embedded_documents",
+            df=documents_df,
+        )
+
+        query_embedding = np.random.random(
+            vector_length,
+        )
+        result = store.retrieve_online_documents(
+            feature="embedded_documents:vector", query=query_embedding, top_k=3
+        ).to_dict()
+
+        assert "vector" in result
+        assert "distance" in result
+        assert len(result["distance"]) == 3
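
An aside on the flow above (not part of the commit): the test writes n entities twice, once through the provider's low-level online_write_batch and once through store.write_to_online_store, which is why the comment notes the sample count doubles. Based on the final assertions, the retrieval result has roughly this shape (all names taken from the test itself):

result = store.retrieve_online_documents(
    feature="embedded_documents:vector",
    query=query_embedding,  # length-10 vector, matching embedding_dim: 10
    top_k=3,
).to_dict()

# Implied by the assertions:
# result["vector"]   -> the 3 stored embeddings nearest to query_embedding
# result["distance"] -> their L2 distances; len(result["distance"]) == 3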

sdk/python/tests/utils/cli_repo_creator.py

Lines changed: 57 additions & 30 deletions

@@ -51,7 +51,14 @@ def run_with_output(self, args: List[str], cwd: Path) -> Tuple[int, bytes]:
         return e.returncode, e.output
 
     @contextmanager
-    def local_repo(self, example_repo_py: str, offline_store: str):
+    def local_repo(
+        self,
+        example_repo_py: str,
+        offline_store: str,
+        online_store: str = "sqlite",
+        apply=True,
+        teardown=True,
+    ):
         """
         Convenience method to set up all the boilerplate for a local feature repo.
         """
@@ -67,41 +74,61 @@ def local_repo(self, example_repo_py: str, offline_store: str):
         data_path = Path(data_dir_name)
 
         repo_config = repo_path / "feature_store.yaml"
-
-        repo_config.write_text(
-            dedent(
+        if online_store == "sqlite":
+            yaml_config = dedent(
                 f"""
-            project: {project_id}
-            registry: {data_path / "registry.db"}
-            provider: local
-            online_store:
-                path: {data_path / "online_store.db"}
-            offline_store:
-                type: {offline_store}
-            entity_key_serialization_version: 2
-            """
+                project: {project_id}
+                registry: {data_path / "registry.db"}
+                provider: local
+                online_store:
+                    path: {data_path / "online_store.db"}
+                offline_store:
+                    type: {offline_store}
+                entity_key_serialization_version: 2
+                """
             )
-        )
+        elif online_store == "milvus":
+            yaml_config = dedent(
+                f"""
+                project: {project_id}
+                registry: {data_path / "registry.db"}
+                provider: local
+                online_store:
+                    path: {data_path / "online_store.db"}
+                    type: milvus
+                    vector_enabled: true
+                    embedding_dim: 10
+                offline_store:
+                    type: {offline_store}
+                entity_key_serialization_version: 3
+                """
+            )
+        else:
+            pass
+
+        repo_config.write_text(yaml_config)
 
         repo_example = repo_path / "example.py"
         repo_example.write_text(example_repo_py)
 
-        result = self.run(["apply"], cwd=repo_path)
-        stdout = result.stdout.decode("utf-8")
-        stderr = result.stderr.decode("utf-8")
-        print(f"Apply stdout:\n{stdout}")
-        print(f"Apply stderr:\n{stderr}")
-        assert (
-            result.returncode == 0
-        ), f"stdout: {result.stdout}\nstderr: {result.stderr}"
+        if apply:
+            result = self.run(["apply"], cwd=repo_path)
+            stdout = result.stdout.decode("utf-8")
+            stderr = result.stderr.decode("utf-8")
+            print(f"Apply stdout:\n{stdout}")
+            print(f"Apply stderr:\n{stderr}")
+            assert (
+                result.returncode == 0
+            ), f"stdout: {result.stdout}\nstderr: {result.stderr}"
 
         yield FeatureStore(repo_path=str(repo_path), config=None)
 
-        result = self.run(["teardown"], cwd=repo_path)
-        stdout = result.stdout.decode("utf-8")
-        stderr = result.stderr.decode("utf-8")
-        print(f"Apply stdout:\n{stdout}")
-        print(f"Apply stderr:\n{stderr}")
-        assert (
-            result.returncode == 0
-        ), f"stdout: {result.stdout}\nstderr: {result.stderr}"
+        if teardown:
+            result = self.run(["teardown"], cwd=repo_path)
+            stdout = result.stdout.decode("utf-8")
+            stderr = result.stderr.decode("utf-8")
+            print(f"Apply stdout:\n{stdout}")
+            print(f"Apply stderr:\n{stderr}")
+            assert (
+                result.returncode == 0
+            ), f"stdout: {result.stdout}\nstderr: {result.stderr}"
