7
7
CollectionSchema ,
8
8
DataType ,
9
9
FieldSchema ,
10
- connections ,
10
+ MilvusClient ,
11
11
)
12
- from pymilvus .orm .connections import Connections
13
12
14
13
from feast import Entity
15
14
from feast .feature_view import FeatureView
@@ -85,14 +84,15 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
85
84
"""
86
85
87
86
type : Literal ["milvus" ] = "milvus"
88
-
89
87
host : Optional [StrictStr ] = "localhost"
90
88
port : Optional [int ] = 19530
91
89
index_type : Optional [str ] = "IVF_FLAT"
92
90
metric_type : Optional [str ] = "L2"
93
91
embedding_dim : Optional [int ] = 128
94
92
vector_enabled : Optional [bool ] = True
95
93
nlist : Optional [int ] = 128
94
+ username : Optional [StrictStr ] = ""
95
+ password : Optional [StrictStr ] = ""
96
96
97
97
98
98
class MilvusOnlineStore (OnlineStore ):
@@ -103,24 +103,23 @@ class MilvusOnlineStore(OnlineStore):
103
103
_collections: Dictionary to cache Milvus collections.
104
104
"""
105
105
106
- _conn : Optional [Connections ] = None
107
- _collections : Dict [str , Collection ] = {}
106
+ client : Optional [MilvusClient ] = None
107
+ _collections : Dict [str , Any ] = {}
108
108
109
- def _connect (self , config : RepoConfig ) -> connections :
110
- if not self ._conn :
111
- if not connections . has_connection ( "feast" ):
112
- self . _conn = connections . connect (
113
- alias = "feast" ,
114
- host = config .online_store .host ,
115
- port = str ( config . online_store . port ) ,
116
- )
117
- return self ._conn
109
+ def _connect (self , config : RepoConfig ) -> MilvusClient :
110
+ if not self .client :
111
+ self . client = MilvusClient (
112
+ url = f" { config . online_store . host } : { config . online_store . port } " ,
113
+ token = f" { config . online_store . username } : { config . online_store . password } "
114
+ if config . online_store . username and config .online_store .password
115
+ else "" ,
116
+ )
117
+ return self .client
118
118
119
- def _get_collection (self , config : RepoConfig , table : FeatureView ) -> Collection :
119
+ def _get_collection (self , config : RepoConfig , table : FeatureView ) -> Dict [str , Any ]:
120
+ self .client = self ._connect (config )
120
121
collection_name = _table_id (config .project , table )
121
122
if collection_name not in self ._collections :
122
- self ._connect (config )
123
-
124
123
# Create a composite key by combining entity fields
125
124
composite_key_name = (
126
125
"_" .join ([field .name for field in table .entity_columns ]) + "_pk"
@@ -166,23 +165,38 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
166
165
schema = CollectionSchema (
167
166
fields = fields , description = "Feast feature view data"
168
167
)
169
- collection = Collection (name = collection_name , schema = schema , using = "feast" )
170
- if not collection .has_index ():
171
- index_params = {
172
- "index_type" : config .online_store .index_type ,
173
- "metric_type" : config .online_store .metric_type ,
174
- "params" : {"nlist" : config .online_store .nlist },
175
- }
176
- for vector_field in schema .fields :
177
- if vector_field .dtype in [
178
- DataType .FLOAT_VECTOR ,
179
- DataType .BINARY_VECTOR ,
180
- ]:
181
- collection .create_index (
182
- field_name = vector_field .name , index_params = index_params
183
- )
184
- collection .load ()
185
- self ._collections [collection_name ] = collection
168
+ collection_exists = self .client .has_collection (
169
+ collection_name = collection_name
170
+ )
171
+ if not collection_exists :
172
+ self .client .create_collection (
173
+ collection_name = collection_name ,
174
+ dimension = config .online_store .embedding_dim ,
175
+ schema = schema ,
176
+ )
177
+ index_params = self .client .prepare_index_params ()
178
+ for vector_field in schema .fields :
179
+ if vector_field .dtype in [
180
+ DataType .FLOAT_VECTOR ,
181
+ DataType .BINARY_VECTOR ,
182
+ ]:
183
+ index_params .add_index (
184
+ collection_name = collection_name ,
185
+ field_name = vector_field .name ,
186
+ metric_type = config .online_store .metric_type ,
187
+ index_type = config .online_store .index_type ,
188
+ index_name = f"vector_index_{ vector_field .name } " ,
189
+ params = {"nlist" : config .online_store .nlist },
190
+ )
191
+ self .client .create_index (
192
+ collection_name = collection_name ,
193
+ index_params = index_params ,
194
+ )
195
+ else :
196
+ self .client .load_collection (collection_name )
197
+ self ._collections [collection_name ] = self .client .describe_collection (
198
+ collection_name
199
+ )
186
200
return self ._collections [collection_name ]
187
201
188
202
def online_write_batch (
@@ -199,6 +213,7 @@ def online_write_batch(
199
213
],
200
214
progress : Optional [Callable [[int ], Any ]],
201
215
) -> None :
216
+ self .client = self ._connect (config )
202
217
collection = self ._get_collection (config , table )
203
218
entity_batch_to_insert = []
204
219
for entity_key , values_dict , timestamp , created_ts in data :
@@ -231,8 +246,9 @@ def online_write_batch(
231
246
if progress :
232
247
progress (1 )
233
248
234
- collection .insert (entity_batch_to_insert )
235
- collection .flush ()
249
+ self .client .insert (
250
+ collection_name = collection ["collection_name" ], data = entity_batch_to_insert
251
+ )
236
252
237
253
def online_read (
238
254
self ,
@@ -252,14 +268,14 @@ def update(
252
268
entities_to_keep : Sequence [Entity ],
253
269
partial : bool ,
254
270
):
255
- self ._connect (config )
271
+ self .client = self . _connect (config )
256
272
for table in tables_to_keep :
257
- self ._get_collection (config , table )
273
+ self ._collections = self ._get_collection (config , table )
274
+
258
275
for table in tables_to_delete :
259
276
collection_name = _table_id (config .project , table )
260
- collection = Collection (name = collection_name )
261
- if collection .exists ():
262
- collection .drop ()
277
+ if self ._collections .get (collection_name , None ):
278
+ self .client .drop_collection (collection_name )
263
279
self ._collections .pop (collection_name , None )
264
280
265
281
def plan (
@@ -273,12 +289,12 @@ def teardown(
273
289
tables : Sequence [FeatureView ],
274
290
entities : Sequence [Entity ],
275
291
):
276
- self ._connect (config )
292
+ self .client = self . _connect (config )
277
293
for table in tables :
278
- collection = self . _get_collection (config , table )
279
- if collection :
280
- collection . drop ( )
281
- self ._collections .pop (collection . name , None )
294
+ collection_name = _table_id (config . project , table )
295
+ if self . _collections . get ( collection_name , None ) :
296
+ self . client . drop_collection ( collection_name )
297
+ self ._collections .pop (collection_name , None )
282
298
283
299
def retrieve_online_documents (
284
300
self ,
@@ -298,6 +314,8 @@ def retrieve_online_documents(
298
314
Optional [ValueProto ],
299
315
]
300
316
]:
317
+ self .client = self ._connect (config )
318
+ collection_name = _table_id (config .project , table )
301
319
collection = self ._get_collection (config , table )
302
320
if not config .online_store .vector_enabled :
303
321
raise ValueError ("Vector search is not enabled in the online store config" )
@@ -321,42 +339,45 @@ def retrieve_online_documents(
321
339
+ ["created_ts" , "event_ts" ]
322
340
)
323
341
assert all (
324
- field
342
+ field in [ f [ "name" ] for f in collection [ "fields" ]]
325
343
for field in output_fields
326
- if field in [f .name for f in collection .schema .fields ]
327
- ), f"field(s) [{ [field for field in output_fields if field not in [f .name for f in collection .schema .fields ]]} '] not found in collection schema"
328
-
344
+ ), f"field(s) [{ [field for field in output_fields if field not in [f ['name' ] for f in collection ['fields' ]]]} ] not found in collection schema"
329
345
# Note we choose the first vector field as the field to search on. Not ideal but it's something.
330
346
ann_search_field = None
331
- for field in collection . schema . fields :
347
+ for field in collection [ " fields" ] :
332
348
if (
333
- field . dtype in [DataType .FLOAT_VECTOR , DataType .BINARY_VECTOR ]
334
- and field . name in output_fields
349
+ field [ "type" ] in [DataType .FLOAT_VECTOR , DataType .BINARY_VECTOR ]
350
+ and field [ " name" ] in output_fields
335
351
):
336
- ann_search_field = field . name
352
+ ann_search_field = field [ " name" ]
337
353
break
338
354
339
- results = collection .search (
355
+ self .client .load_collection (collection_name )
356
+ results = self .client .search (
357
+ collection_name = collection_name ,
340
358
data = [embedding ],
341
359
anns_field = ann_search_field ,
342
- param = search_params ,
360
+ search_params = search_params ,
343
361
limit = top_k ,
344
362
output_fields = output_fields ,
345
- consistency_level = "Strong" ,
346
363
)
347
364
348
365
result_list = []
349
366
for hits in results :
350
367
for hit in hits :
351
368
single_record = {}
352
369
for field in output_fields :
353
- single_record [field ] = hit .entity .get (field )
370
+ single_record [field ] = hit .get ( " entity" , {}) .get (field , None )
354
371
355
- entity_key_bytes = bytes .fromhex (hit .entity .get (composite_key_name ))
356
- embedding = hit .entity .get (ann_search_field )
372
+ entity_key_bytes = bytes .fromhex (
373
+ hit .get ("entity" , {}).get (composite_key_name , None )
374
+ )
375
+ embedding = hit .get ("entity" , {}).get (ann_search_field )
357
376
serialized_embedding = _serialize_vector_to_float_list (embedding )
358
- distance = hit .distance
359
- event_ts = datetime .fromtimestamp (hit .entity .get ("event_ts" ) / 1e6 )
377
+ distance = hit .get ("distance" , None )
378
+ event_ts = datetime .fromtimestamp (
379
+ hit .get ("entity" , {}).get ("event_ts" ) / 1e6
380
+ )
360
381
prepared_result = _build_retrieve_online_document_record (
361
382
entity_key_bytes ,
362
383
# This may have a bug
@@ -412,7 +433,7 @@ def __init__(self, host: str, port: int, name: str):
412
433
self ._connect ()
413
434
414
435
def _connect (self ):
415
- return connections . connect ( alias = "default" , host = self . host , port = str ( self . port ))
436
+ raise NotImplementedError
416
437
417
438
def to_infra_object_proto (self ) -> InfraObjectProto :
418
439
# Implement serialization if needed
0 commit comments