2020import unittest
2121from dataclasses import dataclass
2222from dataclasses import field
23- from typing import Callable
2423from typing import Dict
25- from typing import List
26- from typing import cast
2724
2825import pytest
2926
3027import apache_beam as beam
3128from apache_beam .ml .rag .types import Chunk
3229from apache_beam .ml .rag .types import Content
3330from apache_beam .ml .rag .types import Embedding
34- from apache_beam .ml .rag .utils import retry_with_backoff
3531from apache_beam .testing .test_pipeline import TestPipeline
3632from apache_beam .testing .util import assert_that
3733
3834# pylint: disable=ungrouped-imports
3935try :
40- from pymilvus import CollectionSchema
4136 from pymilvus import DataType
4237 from pymilvus import FieldSchema
4338 from pymilvus import Function
4439 from pymilvus import FunctionType
45- from pymilvus import MilvusClient
4640 from pymilvus import RRFRanker
47- from pymilvus .exceptions import MilvusException
4841 from pymilvus .milvus_client import IndexParams
4942
50- from apache_beam .transforms .enrichment import Enrichment
51- from apache_beam .ml .rag .test_utils import MilvusTestHelpers
52- from apache_beam .ml .rag .test_utils import VectorDBContainerInfo
5343 from apache_beam .ml .rag .enrichment .milvus_search import HybridSearchParameters
5444 from apache_beam .ml .rag .enrichment .milvus_search import KeywordSearchMetrics
5545 from apache_beam .ml .rag .enrichment .milvus_search import KeywordSearchParameters
5949 from apache_beam .ml .rag .enrichment .milvus_search import MilvusSearchParameters
6050 from apache_beam .ml .rag .enrichment .milvus_search import VectorSearchMetrics
6151 from apache_beam .ml .rag .enrichment .milvus_search import VectorSearchParameters
52+ from apache_beam .ml .rag .test_utils import MilvusTestHelpers
53+ from apache_beam .ml .rag .test_utils import VectorDBContainerInfo
54+ from apache_beam .transforms .enrichment import Enrichment
6255except ImportError as e :
6356 raise unittest .SkipTest (f'Milvus dependencies not installed: { str (e )} ' )
6457
65- _LOGGER = logging .getLogger (__name__ )
66-
67-
6858def _construct_index_params ():
6959 index_params = IndexParams ()
7060
@@ -235,77 +225,6 @@ def __getitem__(self, key):
235225}
236226
237227
238- def initialize_db_with_data (connc_params : MilvusConnectionParameters ):
239- # Open the connection to the milvus db with retry.
240- def create_client ():
241- return MilvusClient (** connc_params .__dict__ )
242-
243- client = retry_with_backoff (
244- create_client ,
245- max_retries = 3 ,
246- retry_delay = 1.0 ,
247- operation_name = "Test Milvus client connection" ,
248- exception_types = (MilvusException , ))
249-
250- # Configure schema.
251- field_schemas : List [FieldSchema ] = cast (
252- List [FieldSchema ], MILVUS_IT_CONFIG ["fields" ])
253- schema = CollectionSchema (
254- fields = field_schemas , functions = MILVUS_IT_CONFIG ["functions" ])
255-
256- # Create collection with the schema.
257- collection_name = MILVUS_IT_CONFIG ["collection_name" ]
258- index_function : Callable [[], IndexParams ] = cast (
259- Callable [[], IndexParams ], MILVUS_IT_CONFIG ["index" ])
260- client .create_collection (
261- collection_name = collection_name ,
262- schema = schema ,
263- index_params = index_function ())
264-
265- # Assert that collection was created.
266- collection_error = f"Expected collection '{ collection_name } ' to be created."
267- assert client .has_collection (collection_name ), collection_error
268-
269- # Gather all fields we have excluding 'sparse_embedding_bm25' special field.
270- fields = list (map (lambda field : field .name , field_schemas ))
271-
272- # Prep data for indexing. Currently we can't insert sparse vectors for BM25
273- # sparse embedding field as it would be automatically generated by Milvus
274- # through the registered BM25 function.
275- data_ready_to_index = []
276- for doc in MILVUS_IT_CONFIG ["corpus" ]:
277- item = {}
278- for field in fields :
279- if field .startswith ("dense_embedding" ):
280- item [field ] = doc ["dense_embedding" ]
281- elif field == "sparse_embedding_inner_product" :
282- item [field ] = doc ["sparse_embedding" ]
283- elif field == "sparse_embedding_bm25" :
284- # It is automatically generated by Milvus from the content field.
285- continue
286- else :
287- item [field ] = doc [field ]
288- data_ready_to_index .append (item )
289-
290- # Index data.
291- result = client .insert (
292- collection_name = collection_name , data = data_ready_to_index )
293-
294- # Assert that the intended data has been properly indexed.
295- insertion_err = f'failed to insert the { result ["insert_count" ]} data points'
296- assert result ["insert_count" ] == len (data_ready_to_index ), insertion_err
297-
298- # Release the collection from memory. It will be loaded lazily when the
299- # enrichment handler is invoked.
300- client .release_collection (collection_name )
301-
302- # Close the connection to the Milvus database, as no further preparation
303- # operations are needed before executing the enrichment handler.
304- client .close ()
305-
306- return collection_name
307-
308-
309228@pytest .mark .require_docker_in_docker
310229@unittest .skipUnless (
311230 platform .system () == "Linux" ,
@@ -329,7 +248,8 @@ def setUpClass(cls):
329248 db_name = cls ._db .id ,
330249 token = cls ._db .token )
331250 cls ._collection_load_params = MilvusCollectionLoadParameters ()
332- cls ._collection_name = initialize_db_with_data (cls ._connection_params )
251+ cls ._collection_name = MilvusTestHelpers .initialize_db_with_data (
252+ cls ._connection_params , MILVUS_IT_CONFIG )
333253
334254 @classmethod
335255 def tearDownClass (cls ):
0 commit comments