Skip to content

Commit ee48e71

Browse files
[3/3] sdks/python: enrich data with Milvus Search [Vector, Keyword, Hybrid] (#35467)
* examples+website+sdks/python: update docs and examples for milvus transform * examples: update jupyter notebook example * CHANGES.md: add release note * sdks/python: update import err exception * sdks/python: experiment with setting milvus as extra dependency this way * sdks/python: revert pytest marker to use test containers * .github: trigger postcommit python * sdks/python: undo `require_docker_in_docker` pytest marker * sdks/python: fix formatting issues * python: mark `test_enrichment_with_milvus` with require_docker_in_docker * sdks/python: test milvus example * sdks/python: update jupyter notebook example * CHANGES.md: update release notes * sdks/python: fix linting issues * sdks/python: properly skip milvus test on any container startup failures * sdks/python: properly skip sql tests on any container startup failure * sdks/python: fix linting issues * examples: address comments on milvus jupyter notebook * ml/rag: enforce running etcd in milvus itests in standalone mode * examples: update jupyter notebook mainly to pin milvus db version * website: remove `Related transforms` section * sdks/python: pin milvus db version in py examples * sdks/python: skip validation if there's no enrichment data * sdks/python: pin milvus db version `v2.5.10` I have tested milvus db versions `v2.6.X` and it's not working given we need to update pymilvus client to match that change as well. Updating pymilvus to `v2.6.X` would cause Beam compatibility issues with existing grpc-related packages so it may not be the most feasible upgrade to do in the meantime * milvus: add descriptive comments about updating db version in tests
1 parent e081879 commit ee48e71

File tree

9 files changed

+2951
-51
lines changed

9 files changed

+2951
-51
lines changed

CHANGES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
## New Features / Improvements
7474

7575
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
76+
* Python examples added for Milvus search enrichment handler on [Beam Website](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment-milvus/)
77+
including a Jupyter notebook example (Python) ([#36176](https://github.com/apache/beam/issues/36176)).
7678

7779
## Breaking Changes
7880

examples/notebooks/beam-ml/milvus_enrichment_transform.ipynb

Lines changed: 2657 additions & 0 deletions
Large diffs are not rendered by default.

sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,15 @@ def enrichment_with_google_cloudsql_pg():
156156
where_clause_template=where_clause_template,
157157
where_clause_fields=where_clause_fields)
158158

159-
cloudsql_handler = CloudSQLEnrichmentHandler(
159+
handler = CloudSQLEnrichmentHandler(
160160
connection_config=connection_config,
161161
table_id=table_id,
162162
query_config=query_config)
163163
with beam.Pipeline() as p:
164164
_ = (
165165
p
166166
| "Create" >> beam.Create(data)
167-
|
168-
"Enrich W/ Google CloudSQL PostgreSQL" >> Enrichment(cloudsql_handler)
167+
| "Enrich W/ Google CloudSQL PostgreSQL" >> Enrichment(handler)
169168
| "Print" >> beam.Map(print))
170169
# [END enrichment_with_google_cloudsql_pg]
171170

@@ -327,3 +326,75 @@ def enrichment_with_external_sqlserver():
327326
| "Enrich W/ Unmanaged SQL Server" >> Enrichment(cloudsql_handler)
328327
| "Print" >> beam.Map(print))
329328
# [END enrichment_with_external_sqlserver]
329+
330+
331+
def enrichment_with_milvus():
332+
# [START enrichment_with_milvus]
333+
import os
334+
import apache_beam as beam
335+
from apache_beam.ml.rag.types import Content
336+
from apache_beam.ml.rag.types import Chunk
337+
from apache_beam.ml.rag.types import Embedding
338+
from apache_beam.transforms.enrichment import Enrichment
339+
from apache_beam.ml.rag.enrichment.milvus_search import (
340+
MilvusSearchEnrichmentHandler,
341+
MilvusConnectionParameters,
342+
MilvusSearchParameters,
343+
MilvusCollectionLoadParameters,
344+
VectorSearchParameters,
345+
VectorSearchMetrics)
346+
347+
uri = os.environ.get("MILVUS_VECTOR_DB_URI")
348+
user = os.environ.get("MILVUS_VECTOR_DB_USER")
349+
password = os.environ.get("MILVUS_VECTOR_DB_PASSWORD")
350+
db_id = os.environ.get("MILVUS_VECTOR_DB_ID")
351+
token = os.environ.get("MILVUS_VECTOR_DB_TOKEN")
352+
collection_name = os.environ.get("MILVUS_VECTOR_DB_COLLECTION_NAME")
353+
354+
data = [
355+
Chunk(
356+
id="query1",
357+
embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]),
358+
content=Content())
359+
]
360+
361+
connection_parameters = MilvusConnectionParameters(
362+
uri, user, password, db_id, token)
363+
364+
# The first condition (language == "en") excludes documents in other
365+
# languages. Initially, this gives us two documents. After applying the second
366+
# condition (cost < 50), only the first document returns in search results.
367+
filter_expr = 'metadata["language"] == "en" AND cost < 50'
368+
369+
search_params = {"metric_type": VectorSearchMetrics.COSINE.value, "nprobe": 1}
370+
371+
vector_search_params = VectorSearchParameters(
372+
anns_field="dense_embedding_cosine",
373+
limit=3,
374+
filter=filter_expr,
375+
search_params=search_params)
376+
377+
search_parameters = MilvusSearchParameters(
378+
collection_name=collection_name,
379+
search_strategy=vector_search_params,
380+
output_fields=["id", "content", "domain", "cost", "metadata"],
381+
round_decimal=2)
382+
383+
# The collection load parameters are optional. They provide fine-grained
384+
# control over how collections are loaded into memory. For simple use cases or
385+
# when getting started, this parameter can be omitted to use default loading
386+
# behavior. Consider using it in resource-constrained environments to optimize
387+
# memory usage and query performance.
388+
collection_load_parameters = MilvusCollectionLoadParameters()
389+
390+
milvus_search_handler = MilvusSearchEnrichmentHandler(
391+
connection_parameters=connection_parameters,
392+
search_parameters=search_parameters,
393+
collection_load_parameters=collection_load_parameters)
394+
with beam.Pipeline() as p:
395+
_ = (
396+
p
397+
| "Create" >> beam.Create(data)
398+
| "Enrich W/ Milvus" >> Enrichment(milvus_search_handler)
399+
| "Print" >> beam.Map(print))
400+
# [END enrichment_with_milvus]

sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py

Lines changed: 133 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
enrichment_with_google_cloudsql_pg,
4343
enrichment_with_external_pg,
4444
enrichment_with_external_mysql,
45-
enrichment_with_external_sqlserver)
45+
enrichment_with_external_sqlserver,
46+
enrichment_with_milvus)
4647
from apache_beam.transforms.enrichment_handlers.cloudsql import (
4748
DatabaseTypeAdapter)
4849
from apache_beam.transforms.enrichment_handlers.cloudsql_it_test import (
@@ -51,9 +52,21 @@
5152
ConnectionConfig,
5253
CloudSQLConnectionConfig,
5354
ExternalSQLDBConnectionConfig)
55+
from apache_beam.ml.rag.enrichment.milvus_search import (
56+
MilvusConnectionParameters)
57+
from apache_beam.ml.rag.enrichment.milvus_search_it_test import (
58+
MilvusEnrichmentTestHelper,
59+
MilvusDBContainerInfo,
60+
parse_chunk_strings,
61+
assert_chunks_equivalent)
5462
from apache_beam.io.requestresponse import RequestResponseIO
5563
except ImportError as e:
56-
raise unittest.SkipTest(f'RequestResponseIO dependencies not installed: {e}')
64+
raise unittest.SkipTest(f'Examples dependencies are not installed: {str(e)}')
65+
66+
67+
class TestContainerStartupError(Exception):
68+
"""Raised when any test container fails to start."""
69+
pass
5770

5871

5972
def validate_enrichment_with_bigtable():
@@ -119,6 +132,13 @@ def validate_enrichment_with_external_sqlserver():
119132
return expected
120133

121134

135+
def validate_enrichment_with_milvus():
136+
expected = '''[START enrichment_with_milvus]
137+
Chunk(content=Content(text=None), id='query1', index=0, metadata={'enrichment_data': defaultdict(<class 'list'>, {'id': [1], 'distance': [1.0], 'fields': [{'content': 'This is a test document', 'cost': 49, 'domain': 'medical', 'id': 1, 'metadata': {'language': 'en'}}]})}, embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3], sparse_embedding=None))
138+
[END enrichment_with_milvus]'''.splitlines()[1:-1]
139+
return expected
140+
141+
122142
@mock.patch('sys.stdout', new_callable=StringIO)
123143
@pytest.mark.uses_testcontainer
124144
class EnrichmentTest(unittest.TestCase):
@@ -148,48 +168,69 @@ def test_enrichment_with_vertex_ai_legacy(self, mock_stdout):
148168
os.environ.get('ALLOYDB_PASSWORD'),
149169
"ALLOYDB_PASSWORD environment var is not provided")
150170
def test_enrichment_with_google_cloudsql_pg(self, mock_stdout):
151-
db_adapter = DatabaseTypeAdapter.POSTGRESQL
152-
with EnrichmentTestHelpers.sql_test_context(True, db_adapter):
153-
try:
171+
try:
172+
db_adapter = DatabaseTypeAdapter.POSTGRESQL
173+
with EnrichmentTestHelpers.sql_test_context(True, db_adapter):
154174
enrichment_with_google_cloudsql_pg()
155175
output = mock_stdout.getvalue().splitlines()
156176
expected = validate_enrichment_with_google_cloudsql_pg()
157177
self.assertEqual(output, expected)
158-
except Exception as e:
159-
self.fail(f"Test failed with unexpected error: {e}")
178+
except Exception as e:
179+
self.fail(f"Test failed with unexpected error: {e}")
160180

161181
def test_enrichment_with_external_pg(self, mock_stdout):
162-
db_adapter = DatabaseTypeAdapter.POSTGRESQL
163-
with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
164-
try:
182+
try:
183+
db_adapter = DatabaseTypeAdapter.POSTGRESQL
184+
with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
165185
enrichment_with_external_pg()
166186
output = mock_stdout.getvalue().splitlines()
167187
expected = validate_enrichment_with_external_pg()
168188
self.assertEqual(output, expected)
169-
except Exception as e:
170-
self.fail(f"Test failed with unexpected error: {e}")
189+
except TestContainerStartupError as e:
190+
raise unittest.SkipTest(str(e))
191+
except Exception as e:
192+
self.fail(f"Test failed with unexpected error: {e}")
171193

172194
def test_enrichment_with_external_mysql(self, mock_stdout):
173-
db_adapter = DatabaseTypeAdapter.MYSQL
174-
with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
175-
try:
195+
try:
196+
db_adapter = DatabaseTypeAdapter.MYSQL
197+
with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
176198
enrichment_with_external_mysql()
177199
output = mock_stdout.getvalue().splitlines()
178200
expected = validate_enrichment_with_external_mysql()
179201
self.assertEqual(output, expected)
180-
except Exception as e:
181-
self.fail(f"Test failed with unexpected error: {e}")
202+
except TestContainerStartupError as e:
203+
raise unittest.SkipTest(str(e))
204+
except Exception as e:
205+
self.fail(f"Test failed with unexpected error: {e}")
182206

183207
def test_enrichment_with_external_sqlserver(self, mock_stdout):
184-
db_adapter = DatabaseTypeAdapter.SQLSERVER
185-
with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
186-
try:
208+
try:
209+
db_adapter = DatabaseTypeAdapter.SQLSERVER
210+
with EnrichmentTestHelpers.sql_test_context(False, db_adapter):
187211
enrichment_with_external_sqlserver()
188212
output = mock_stdout.getvalue().splitlines()
189213
expected = validate_enrichment_with_external_sqlserver()
190214
self.assertEqual(output, expected)
191-
except Exception as e:
192-
self.fail(f"Test failed with unexpected error: {e}")
215+
except TestContainerStartupError as e:
216+
raise unittest.SkipTest(str(e))
217+
except Exception as e:
218+
self.fail(f"Test failed with unexpected error: {e}")
219+
220+
def test_enrichment_with_milvus(self, mock_stdout):
221+
try:
222+
with EnrichmentTestHelpers.milvus_test_context():
223+
enrichment_with_milvus()
224+
output = mock_stdout.getvalue().splitlines()
225+
expected = validate_enrichment_with_milvus()
226+
self.maxDiff = None
227+
output = parse_chunk_strings(output)
228+
expected = parse_chunk_strings(expected)
229+
assert_chunks_equivalent(output, expected)
230+
except TestContainerStartupError as e:
231+
raise unittest.SkipTest(str(e))
232+
except Exception as e:
233+
self.fail(f"Test failed with unexpected error: {e}")
193234

194235

195236
@dataclass
@@ -201,6 +242,7 @@ class CloudSQLEnrichmentTestDataConstruct:
201242

202243

203244
class EnrichmentTestHelpers:
245+
@staticmethod
204246
@contextmanager
205247
def sql_test_context(is_cloudsql: bool, db_adapter: DatabaseTypeAdapter):
206248
result: Optional[CloudSQLEnrichmentTestDataConstruct] = None
@@ -212,6 +254,17 @@ def sql_test_context(is_cloudsql: bool, db_adapter: DatabaseTypeAdapter):
212254
if result:
213255
EnrichmentTestHelpers.post_sql_enrichment_test(result)
214256

257+
@staticmethod
258+
@contextmanager
259+
def milvus_test_context():
260+
db: Optional[MilvusDBContainerInfo] = None
261+
try:
262+
db = EnrichmentTestHelpers.pre_milvus_enrichment()
263+
yield
264+
finally:
265+
if db:
266+
EnrichmentTestHelpers.post_milvus_enrichment(db)
267+
215268
@staticmethod
216269
def pre_sql_enrichment_test(
217270
is_cloudsql: bool,
@@ -259,20 +312,25 @@ def pre_sql_enrichment_test(
259312
password=password,
260313
db_id=db_id)
261314
else:
262-
db = SQLEnrichmentTestHelper.start_sql_db_container(db_adapter)
263-
os.environ['EXTERNAL_SQL_DB_HOST'] = db.host
264-
os.environ['EXTERNAL_SQL_DB_PORT'] = str(db.port)
265-
os.environ['EXTERNAL_SQL_DB_ID'] = db.id
266-
os.environ['EXTERNAL_SQL_DB_USER'] = db.user
267-
os.environ['EXTERNAL_SQL_DB_PASSWORD'] = db.password
268-
os.environ['EXTERNAL_SQL_DB_TABLE_ID'] = table_id
269-
connection_config = ExternalSQLDBConnectionConfig(
270-
db_adapter=db_adapter,
271-
host=db.host,
272-
port=db.port,
273-
user=db.user,
274-
password=db.password,
275-
db_id=db.id)
315+
try:
316+
db = SQLEnrichmentTestHelper.start_sql_db_container(db_adapter)
317+
os.environ['EXTERNAL_SQL_DB_HOST'] = db.host
318+
os.environ['EXTERNAL_SQL_DB_PORT'] = str(db.port)
319+
os.environ['EXTERNAL_SQL_DB_ID'] = db.id
320+
os.environ['EXTERNAL_SQL_DB_USER'] = db.user
321+
os.environ['EXTERNAL_SQL_DB_PASSWORD'] = db.password
322+
os.environ['EXTERNAL_SQL_DB_TABLE_ID'] = table_id
323+
connection_config = ExternalSQLDBConnectionConfig(
324+
db_adapter=db_adapter,
325+
host=db.host,
326+
port=db.port,
327+
user=db.user,
328+
password=db.password,
329+
db_id=db.id)
330+
except Exception as e:
331+
db_name = db_adapter.value.lower()
332+
raise TestContainerStartupError(
333+
f"{db_name} container failed to start: {str(e)}")
276334

277335
conenctor = connection_config.get_connector_handler()
278336
engine = create_engine(
@@ -311,6 +369,45 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct):
311369
os.environ.pop('GOOGLE_CLOUD_SQL_DB_PASSWORD', None)
312370
os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None)
313371

372+
@staticmethod
373+
def pre_milvus_enrichment() -> MilvusDBContainerInfo:
374+
try:
375+
db = MilvusEnrichmentTestHelper.start_db_container()
376+
except Exception as e:
377+
raise TestContainerStartupError(
378+
f"Milvus container failed to start: {str(e)}")
379+
380+
connection_params = MilvusConnectionParameters(
381+
uri=db.uri,
382+
user=db.user,
383+
password=db.password,
384+
db_id=db.id,
385+
token=db.token)
386+
387+
collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
388+
connection_params)
389+
390+
# Setup environment variables for db and collection configuration. This will
391+
# be used downstream by the milvus enrichment handler.
392+
os.environ['MILVUS_VECTOR_DB_URI'] = db.uri
393+
os.environ['MILVUS_VECTOR_DB_USER'] = db.user
394+
os.environ['MILVUS_VECTOR_DB_PASSWORD'] = db.password
395+
os.environ['MILVUS_VECTOR_DB_ID'] = db.id
396+
os.environ['MILVUS_VECTOR_DB_TOKEN'] = db.token
397+
os.environ['MILVUS_VECTOR_DB_COLLECTION_NAME'] = collection_name
398+
399+
return db
400+
401+
@staticmethod
402+
def post_milvus_enrichment(db: MilvusDBContainerInfo):
403+
MilvusEnrichmentTestHelper.stop_db_container(db)
404+
os.environ.pop('MILVUS_VECTOR_DB_URI', None)
405+
os.environ.pop('MILVUS_VECTOR_DB_USER', None)
406+
os.environ.pop('MILVUS_VECTOR_DB_PASSWORD', None)
407+
os.environ.pop('MILVUS_VECTOR_DB_ID', None)
408+
os.environ.pop('MILVUS_VECTOR_DB_TOKEN', None)
409+
os.environ.pop('MILVUS_VECTOR_DB_COLLECTION_NAME', None)
410+
314411

315412
if __name__ == '__main__':
316413
unittest.main()

0 commit comments

Comments
 (0)