4242 enrichment_with_google_cloudsql_pg ,
4343 enrichment_with_external_pg ,
4444 enrichment_with_external_mysql ,
45- enrichment_with_external_sqlserver )
45+ enrichment_with_external_sqlserver ,
46+ enrichment_with_milvus )
4647 from apache_beam .transforms .enrichment_handlers .cloudsql import (
4748 DatabaseTypeAdapter )
4849 from apache_beam .transforms .enrichment_handlers .cloudsql_it_test import (
5152 ConnectionConfig ,
5253 CloudSQLConnectionConfig ,
5354 ExternalSQLDBConnectionConfig )
55+ from apache_beam .ml .rag .enrichment .milvus_search import (
56+ MilvusConnectionParameters )
57+ from apache_beam .ml .rag .enrichment .milvus_search_it_test import (
58+ MilvusEnrichmentTestHelper ,
59+ MilvusDBContainerInfo ,
60+ parse_chunk_strings ,
61+ assert_chunks_equivalent )
5462 from apache_beam .io .requestresponse import RequestResponseIO
5563except ImportError as e :
56- raise unittest .SkipTest (f'RequestResponseIO dependencies not installed: { e } ' )
64+ raise unittest .SkipTest (f'Examples dependencies are not installed: { str (e )} ' )
65+
66+
67+ class TestContainerStartupError (Exception ):
68+ """Raised when any test container fails to start."""
69+ pass
5770
5871
5972def validate_enrichment_with_bigtable ():
@@ -119,6 +132,13 @@ def validate_enrichment_with_external_sqlserver():
119132 return expected
120133
121134
135+ def validate_enrichment_with_milvus ():
136+ expected = '''[START enrichment_with_milvus]
137+ Chunk(content=Content(text=None), id='query1', index=0, metadata={'enrichment_data': defaultdict(<class 'list'>, {'id': [1], 'distance': [1.0], 'fields': [{'content': 'This is a test document', 'cost': 49, 'domain': 'medical', 'id': 1, 'metadata': {'language': 'en'}}]})}, embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3], sparse_embedding=None))
138+ [END enrichment_with_milvus]''' .splitlines ()[1 :- 1 ]
139+ return expected
140+
141+
122142@mock .patch ('sys.stdout' , new_callable = StringIO )
123143@pytest .mark .uses_testcontainer
124144class EnrichmentTest (unittest .TestCase ):
@@ -148,48 +168,69 @@ def test_enrichment_with_vertex_ai_legacy(self, mock_stdout):
148168 os .environ .get ('ALLOYDB_PASSWORD' ),
149169 "ALLOYDB_PASSWORD environment var is not provided" )
150170 def test_enrichment_with_google_cloudsql_pg (self , mock_stdout ):
151- db_adapter = DatabaseTypeAdapter . POSTGRESQL
152- with EnrichmentTestHelpers . sql_test_context ( True , db_adapter ):
153- try :
171+ try :
172+ db_adapter = DatabaseTypeAdapter . POSTGRESQL
173+ with EnrichmentTestHelpers . sql_test_context ( True , db_adapter ) :
154174 enrichment_with_google_cloudsql_pg ()
155175 output = mock_stdout .getvalue ().splitlines ()
156176 expected = validate_enrichment_with_google_cloudsql_pg ()
157177 self .assertEqual (output , expected )
158- except Exception as e :
159- self .fail (f"Test failed with unexpected error: { e } " )
178+ except Exception as e :
179+ self .fail (f"Test failed with unexpected error: { e } " )
160180
161181 def test_enrichment_with_external_pg (self , mock_stdout ):
162- db_adapter = DatabaseTypeAdapter . POSTGRESQL
163- with EnrichmentTestHelpers . sql_test_context ( False , db_adapter ):
164- try :
182+ try :
183+ db_adapter = DatabaseTypeAdapter . POSTGRESQL
184+ with EnrichmentTestHelpers . sql_test_context ( False , db_adapter ) :
165185 enrichment_with_external_pg ()
166186 output = mock_stdout .getvalue ().splitlines ()
167187 expected = validate_enrichment_with_external_pg ()
168188 self .assertEqual (output , expected )
169- except Exception as e :
170- self .fail (f"Test failed with unexpected error: { e } " )
189+ except TestContainerStartupError as e :
190+ raise unittest .SkipTest (str (e ))
191+ except Exception as e :
192+ self .fail (f"Test failed with unexpected error: { e } " )
171193
172194 def test_enrichment_with_external_mysql (self , mock_stdout ):
173- db_adapter = DatabaseTypeAdapter . MYSQL
174- with EnrichmentTestHelpers . sql_test_context ( False , db_adapter ):
175- try :
195+ try :
196+ db_adapter = DatabaseTypeAdapter . MYSQL
197+ with EnrichmentTestHelpers . sql_test_context ( False , db_adapter ) :
176198 enrichment_with_external_mysql ()
177199 output = mock_stdout .getvalue ().splitlines ()
178200 expected = validate_enrichment_with_external_mysql ()
179201 self .assertEqual (output , expected )
180- except Exception as e :
181- self .fail (f"Test failed with unexpected error: { e } " )
202+ except TestContainerStartupError as e :
203+ raise unittest .SkipTest (str (e ))
204+ except Exception as e :
205+ self .fail (f"Test failed with unexpected error: { e } " )
182206
183207 def test_enrichment_with_external_sqlserver (self , mock_stdout ):
184- db_adapter = DatabaseTypeAdapter . SQLSERVER
185- with EnrichmentTestHelpers . sql_test_context ( False , db_adapter ):
186- try :
208+ try :
209+ db_adapter = DatabaseTypeAdapter . SQLSERVER
210+ with EnrichmentTestHelpers . sql_test_context ( False , db_adapter ) :
187211 enrichment_with_external_sqlserver ()
188212 output = mock_stdout .getvalue ().splitlines ()
189213 expected = validate_enrichment_with_external_sqlserver ()
190214 self .assertEqual (output , expected )
191- except Exception as e :
192- self .fail (f"Test failed with unexpected error: { e } " )
215+ except TestContainerStartupError as e :
216+ raise unittest .SkipTest (str (e ))
217+ except Exception as e :
218+ self .fail (f"Test failed with unexpected error: { e } " )
219+
220+ def test_enrichment_with_milvus (self , mock_stdout ):
221+ try :
222+ with EnrichmentTestHelpers .milvus_test_context ():
223+ enrichment_with_milvus ()
224+ output = mock_stdout .getvalue ().splitlines ()
225+ expected = validate_enrichment_with_milvus ()
226+ self .maxDiff = None
227+ output = parse_chunk_strings (output )
228+ expected = parse_chunk_strings (expected )
229+ assert_chunks_equivalent (output , expected )
230+ except TestContainerStartupError as e :
231+ raise unittest .SkipTest (str (e ))
232+ except Exception as e :
233+ self .fail (f"Test failed with unexpected error: { e } " )
193234
194235
195236@dataclass
@@ -201,6 +242,7 @@ class CloudSQLEnrichmentTestDataConstruct:
201242
202243
203244class EnrichmentTestHelpers :
245+ @staticmethod
204246 @contextmanager
205247 def sql_test_context (is_cloudsql : bool , db_adapter : DatabaseTypeAdapter ):
206248 result : Optional [CloudSQLEnrichmentTestDataConstruct ] = None
@@ -212,6 +254,17 @@ def sql_test_context(is_cloudsql: bool, db_adapter: DatabaseTypeAdapter):
212254 if result :
213255 EnrichmentTestHelpers .post_sql_enrichment_test (result )
214256
257+ @staticmethod
258+ @contextmanager
259+ def milvus_test_context ():
260+ db : Optional [MilvusDBContainerInfo ] = None
261+ try :
262+ db = EnrichmentTestHelpers .pre_milvus_enrichment ()
263+ yield
264+ finally :
265+ if db :
266+ EnrichmentTestHelpers .post_milvus_enrichment (db )
267+
215268 @staticmethod
216269 def pre_sql_enrichment_test (
217270 is_cloudsql : bool ,
@@ -259,20 +312,25 @@ def pre_sql_enrichment_test(
259312 password = password ,
260313 db_id = db_id )
261314 else :
262- db = SQLEnrichmentTestHelper .start_sql_db_container (db_adapter )
263- os .environ ['EXTERNAL_SQL_DB_HOST' ] = db .host
264- os .environ ['EXTERNAL_SQL_DB_PORT' ] = str (db .port )
265- os .environ ['EXTERNAL_SQL_DB_ID' ] = db .id
266- os .environ ['EXTERNAL_SQL_DB_USER' ] = db .user
267- os .environ ['EXTERNAL_SQL_DB_PASSWORD' ] = db .password
268- os .environ ['EXTERNAL_SQL_DB_TABLE_ID' ] = table_id
269- connection_config = ExternalSQLDBConnectionConfig (
270- db_adapter = db_adapter ,
271- host = db .host ,
272- port = db .port ,
273- user = db .user ,
274- password = db .password ,
275- db_id = db .id )
315+ try :
316+ db = SQLEnrichmentTestHelper .start_sql_db_container (db_adapter )
317+ os .environ ['EXTERNAL_SQL_DB_HOST' ] = db .host
318+ os .environ ['EXTERNAL_SQL_DB_PORT' ] = str (db .port )
319+ os .environ ['EXTERNAL_SQL_DB_ID' ] = db .id
320+ os .environ ['EXTERNAL_SQL_DB_USER' ] = db .user
321+ os .environ ['EXTERNAL_SQL_DB_PASSWORD' ] = db .password
322+ os .environ ['EXTERNAL_SQL_DB_TABLE_ID' ] = table_id
323+ connection_config = ExternalSQLDBConnectionConfig (
324+ db_adapter = db_adapter ,
325+ host = db .host ,
326+ port = db .port ,
327+ user = db .user ,
328+ password = db .password ,
329+ db_id = db .id )
330+ except Exception as e :
331+ db_name = db_adapter .value .lower ()
332+ raise TestContainerStartupError (
333+ f"{ db_name } container failed to start: { str (e )} " )
276334
277335 conenctor = connection_config .get_connector_handler ()
278336 engine = create_engine (
@@ -311,6 +369,45 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct):
311369 os .environ .pop ('GOOGLE_CLOUD_SQL_DB_PASSWORD' , None )
312370 os .environ .pop ('GOOGLE_CLOUD_SQL_DB_TABLE_ID' , None )
313371
372+ @staticmethod
373+ def pre_milvus_enrichment () -> MilvusDBContainerInfo :
374+ try :
375+ db = MilvusEnrichmentTestHelper .start_db_container ()
376+ except Exception as e :
377+ raise TestContainerStartupError (
378+ f"Milvus container failed to start: { str (e )} " )
379+
380+ connection_params = MilvusConnectionParameters (
381+ uri = db .uri ,
382+ user = db .user ,
383+ password = db .password ,
384+ db_id = db .id ,
385+ token = db .token )
386+
387+ collection_name = MilvusEnrichmentTestHelper .initialize_db_with_data (
388+ connection_params )
389+
390+ # Setup environment variables for db and collection configuration. This will
391+ # be used downstream by the milvus enrichment handler.
392+ os .environ ['MILVUS_VECTOR_DB_URI' ] = db .uri
393+ os .environ ['MILVUS_VECTOR_DB_USER' ] = db .user
394+ os .environ ['MILVUS_VECTOR_DB_PASSWORD' ] = db .password
395+ os .environ ['MILVUS_VECTOR_DB_ID' ] = db .id
396+ os .environ ['MILVUS_VECTOR_DB_TOKEN' ] = db .token
397+ os .environ ['MILVUS_VECTOR_DB_COLLECTION_NAME' ] = collection_name
398+
399+ return db
400+
401+ @staticmethod
402+ def post_milvus_enrichment (db : MilvusDBContainerInfo ):
403+ MilvusEnrichmentTestHelper .stop_db_container (db )
404+ os .environ .pop ('MILVUS_VECTOR_DB_URI' , None )
405+ os .environ .pop ('MILVUS_VECTOR_DB_USER' , None )
406+ os .environ .pop ('MILVUS_VECTOR_DB_PASSWORD' , None )
407+ os .environ .pop ('MILVUS_VECTOR_DB_ID' , None )
408+ os .environ .pop ('MILVUS_VECTOR_DB_TOKEN' , None )
409+ os .environ .pop ('MILVUS_VECTOR_DB_COLLECTION_NAME' , None )
410+
314411
315412if __name__ == '__main__' :
316413 unittest .main ()
0 commit comments