4747 from pymilvus .exceptions import MilvusException
4848 from pymilvus .milvus_client import IndexParams
4949
50- from apache_beam .transforms .enrichment import Enrichment
51- from apache_beam .ml .rag .test_utils import MilvusTestHelpers
52- from apache_beam .ml .rag .test_utils import VectorDBContainerInfo
5350 from apache_beam .ml .rag .enrichment .milvus_search import HybridSearchParameters
5451 from apache_beam .ml .rag .enrichment .milvus_search import KeywordSearchMetrics
5552 from apache_beam .ml .rag .enrichment .milvus_search import KeywordSearchParameters
5956 from apache_beam .ml .rag .enrichment .milvus_search import MilvusSearchParameters
6057 from apache_beam .ml .rag .enrichment .milvus_search import VectorSearchMetrics
6158 from apache_beam .ml .rag .enrichment .milvus_search import VectorSearchParameters
59+ from apache_beam .ml .rag .test_utils import MilvusTestHelpers
60+ from apache_beam .ml .rag .test_utils import VectorDBContainerInfo
61+ from apache_beam .transforms .enrichment import Enrichment
6262except ImportError as e :
6363 raise unittest .SkipTest (f'Milvus dependencies not installed: { str (e )} ' )
6464
@@ -235,77 +235,6 @@ def __getitem__(self, key):
235235}
236236
237237
238- def initialize_db_with_data (connc_params : MilvusConnectionParameters ):
239- # Open the connection to the milvus db with retry.
240- def create_client ():
241- return MilvusClient (** connc_params .__dict__ )
242-
243- client = retry_with_backoff (
244- create_client ,
245- max_retries = 3 ,
246- retry_delay = 1.0 ,
247- operation_name = "Test Milvus client connection" ,
248- exception_types = (MilvusException , ))
249-
250- # Configure schema.
251- field_schemas : List [FieldSchema ] = cast (
252- List [FieldSchema ], MILVUS_IT_CONFIG ["fields" ])
253- schema = CollectionSchema (
254- fields = field_schemas , functions = MILVUS_IT_CONFIG ["functions" ])
255-
256- # Create collection with the schema.
257- collection_name = MILVUS_IT_CONFIG ["collection_name" ]
258- index_function : Callable [[], IndexParams ] = cast (
259- Callable [[], IndexParams ], MILVUS_IT_CONFIG ["index" ])
260- client .create_collection (
261- collection_name = collection_name ,
262- schema = schema ,
263- index_params = index_function ())
264-
265- # Assert that collection was created.
266- collection_error = f"Expected collection '{ collection_name } ' to be created."
267- assert client .has_collection (collection_name ), collection_error
268-
269- # Gather all fields we have excluding 'sparse_embedding_bm25' special field.
270- fields = list (map (lambda field : field .name , field_schemas ))
271-
272- # Prep data for indexing. Currently we can't insert sparse vectors for BM25
273- # sparse embedding field as it would be automatically generated by Milvus
274- # through the registered BM25 function.
275- data_ready_to_index = []
276- for doc in MILVUS_IT_CONFIG ["corpus" ]:
277- item = {}
278- for field in fields :
279- if field .startswith ("dense_embedding" ):
280- item [field ] = doc ["dense_embedding" ]
281- elif field == "sparse_embedding_inner_product" :
282- item [field ] = doc ["sparse_embedding" ]
283- elif field == "sparse_embedding_bm25" :
284- # It is automatically generated by Milvus from the content field.
285- continue
286- else :
287- item [field ] = doc [field ]
288- data_ready_to_index .append (item )
289-
290- # Index data.
291- result = client .insert (
292- collection_name = collection_name , data = data_ready_to_index )
293-
294- # Assert that the intended data has been properly indexed.
295- insertion_err = f'failed to insert the { result ["insert_count" ]} data points'
296- assert result ["insert_count" ] == len (data_ready_to_index ), insertion_err
297-
298- # Release the collection from memory. It will be loaded lazily when the
299- # enrichment handler is invoked.
300- client .release_collection (collection_name )
301-
302- # Close the connection to the Milvus database, as no further preparation
303- # operations are needed before executing the enrichment handler.
304- client .close ()
305-
306- return collection_name
307-
308-
309238@pytest .mark .require_docker_in_docker
310239@unittest .skipUnless (
311240 platform .system () == "Linux" ,
@@ -329,7 +258,8 @@ def setUpClass(cls):
329258 db_name = cls ._db .id ,
330259 token = cls ._db .token )
331260 cls ._collection_load_params = MilvusCollectionLoadParameters ()
332- cls ._collection_name = initialize_db_with_data (cls ._connection_params )
261+ cls ._collection_name = initialize_db_with_data (
262+ cls ._connection_params , MILVUS_IT_CONFIG )
333263
334264 @classmethod
335265 def tearDownClass (cls ):
0 commit comments