Skip to content

Commit 83ebe73

Browse files
[1/3] sdks/python: refactor Milvus-related utilities as preparation step for Milvus Sink I/O integration (#35708)
* sdks/python: replace the deprecated testcontainer max tries * sdks/python: handle transient testcontainer startup/teardown errors * sdks/python: bump `testcontainers` py pkg version * sdks/python: integrate milvus sink I/O * sdks/python: fix linting issues * sdks/python: add missing apache beam liscense header for `test_utils.py` * notebooks/beam-ml: use new refactored code in milvus enrichment handler * CHANGES.md: update release notes * sdks/python: mark milvus itests with `require_docker_in_docker` marker * sdks/python: override milvus db version with the default * sdsk/python: add missing import in rag utils * sdks/python: fix linting issue * rag/ingestion/milvus_search_itest.py: ensure flushing in-memory data before querying * sdks/python: fix linting issues * sdks/python: fix formatting issues * sdks/python: fix arising linting issue * rag: reuse `retry_with_backoff` for one-time setup operations * sdks/python: fix linting issues * sdks/python: fix py docs CI issue * sdks/python: fix linting issues * sdks/python: fix linting issues * sdks/python: isolate milvus sink integration to be in follow-up PR * CHANGES.md: remove milvus from release notes in the refactoring PR * sdks/python: remove `with_sparse_embedding_spec` column specs builder In this commit, we remove that builder method to remain functional and be used in the next Milvus sink integration PR * sdks/python: fix linting issues * Revert "notebooks/beam-ml: use new refactored code in milvus enrichment handler" This reverts commit 461c8fe. * sdks/python: fix linting issues * sdks/python: fix linting issues * sdks/python: fix linting issues * sdks/python: fix linting issues * CI: fix import errors in CI * sdks/python: fix linting issues * sdks/python: fix linting issues * sdks/python: fix linting issues * sdks/python: fix linting issues
1 parent 4a59cb7 commit 83ebe73

File tree

6 files changed

+759
-556
lines changed

6 files changed

+759
-556
lines changed

sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,11 @@
5252
ConnectionConfig,
5353
CloudSQLConnectionConfig,
5454
ExternalSQLDBConnectionConfig)
55-
from apache_beam.ml.rag.enrichment.milvus_search import (
56-
MilvusConnectionParameters)
57-
from apache_beam.ml.rag.enrichment.milvus_search_it_test import (
58-
MilvusEnrichmentTestHelper,
59-
MilvusDBContainerInfo,
60-
parse_chunk_strings,
61-
assert_chunks_equivalent)
55+
from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
56+
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
57+
from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
58+
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
59+
from apache_beam.ml.rag.utils import parse_chunk_strings
6260
from apache_beam.io.requestresponse import RequestResponseIO
6361
except ImportError as e:
6462
raise unittest.SkipTest(f'Examples dependencies are not installed: {str(e)}')
@@ -69,6 +67,11 @@ class TestContainerStartupError(Exception):
6967
pass
7068

7169

70+
class TestContainerTeardownError(Exception):
71+
"""Raised when any test container fails to teardown."""
72+
pass
73+
74+
7275
def validate_enrichment_with_bigtable():
7376
expected = '''[START enrichment_with_bigtable]
7477
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
@@ -186,7 +189,7 @@ def test_enrichment_with_external_pg(self, mock_stdout):
186189
output = mock_stdout.getvalue().splitlines()
187190
expected = validate_enrichment_with_external_pg()
188191
self.assertEqual(output, expected)
189-
except TestContainerStartupError as e:
192+
except (TestContainerStartupError, TestContainerTeardownError) as e:
190193
raise unittest.SkipTest(str(e))
191194
except Exception as e:
192195
self.fail(f"Test failed with unexpected error: {e}")
@@ -199,7 +202,7 @@ def test_enrichment_with_external_mysql(self, mock_stdout):
199202
output = mock_stdout.getvalue().splitlines()
200203
expected = validate_enrichment_with_external_mysql()
201204
self.assertEqual(output, expected)
202-
except TestContainerStartupError as e:
205+
except (TestContainerStartupError, TestContainerTeardownError) as e:
203206
raise unittest.SkipTest(str(e))
204207
except Exception as e:
205208
self.fail(f"Test failed with unexpected error: {e}")
@@ -212,7 +215,7 @@ def test_enrichment_with_external_sqlserver(self, mock_stdout):
212215
output = mock_stdout.getvalue().splitlines()
213216
expected = validate_enrichment_with_external_sqlserver()
214217
self.assertEqual(output, expected)
215-
except TestContainerStartupError as e:
218+
except (TestContainerStartupError, TestContainerTeardownError) as e:
216219
raise unittest.SkipTest(str(e))
217220
except Exception as e:
218221
self.fail(f"Test failed with unexpected error: {e}")
@@ -226,8 +229,8 @@ def test_enrichment_with_milvus(self, mock_stdout):
226229
self.maxDiff = None
227230
output = parse_chunk_strings(output)
228231
expected = parse_chunk_strings(expected)
229-
assert_chunks_equivalent(output, expected)
230-
except TestContainerStartupError as e:
232+
MilvusTestHelpers.assert_chunks_equivalent(output, expected)
233+
except (TestContainerStartupError, TestContainerTeardownError) as e:
231234
raise unittest.SkipTest(str(e))
232235
except Exception as e:
233236
self.fail(f"Test failed with unexpected error: {e}")
@@ -257,7 +260,7 @@ def sql_test_context(is_cloudsql: bool, db_adapter: DatabaseTypeAdapter):
257260
@staticmethod
258261
@contextmanager
259262
def milvus_test_context():
260-
db: Optional[MilvusDBContainerInfo] = None
263+
db: Optional[VectorDBContainerInfo] = None
261264
try:
262265
db = EnrichmentTestHelpers.pre_milvus_enrichment()
263266
yield
@@ -370,23 +373,21 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct):
370373
os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None)
371374

372375
@staticmethod
373-
def pre_milvus_enrichment() -> MilvusDBContainerInfo:
376+
def pre_milvus_enrichment() -> VectorDBContainerInfo:
374377
try:
375-
db = MilvusEnrichmentTestHelper.start_db_container()
378+
db = MilvusTestHelpers.start_db_container()
379+
connection_params = MilvusConnectionParameters(
380+
uri=db.uri,
381+
user=db.user,
382+
password=db.password,
383+
db_id=db.id,
384+
token=db.token)
385+
collection_name = MilvusTestHelpers.initialize_db_with_data(
386+
connection_params)
376387
except Exception as e:
377388
raise TestContainerStartupError(
378389
f"Milvus container failed to start: {str(e)}")
379390

380-
connection_params = MilvusConnectionParameters(
381-
uri=db.uri,
382-
user=db.user,
383-
password=db.password,
384-
db_id=db.id,
385-
token=db.token)
386-
387-
collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
388-
connection_params)
389-
390391
# Setup environment variables for db and collection configuration. This will
391392
# be used downstream by the milvus enrichment handler.
392393
os.environ['MILVUS_VECTOR_DB_URI'] = db.uri
@@ -399,8 +400,13 @@ def pre_milvus_enrichment() -> MilvusDBContainerInfo:
399400
return db
400401

401402
@staticmethod
402-
def post_milvus_enrichment(db: MilvusDBContainerInfo):
403-
MilvusEnrichmentTestHelper.stop_db_container(db)
403+
def post_milvus_enrichment(db: VectorDBContainerInfo):
404+
try:
405+
MilvusTestHelpers.stop_db_container(db)
406+
except Exception as e:
407+
raise TestContainerTeardownError(
408+
f"Milvus container failed to tear down: {str(e)}")
409+
404410
os.environ.pop('MILVUS_VECTOR_DB_URI', None)
405411
os.environ.pop('MILVUS_VECTOR_DB_USER', None)
406412
os.environ.pop('MILVUS_VECTOR_DB_PASSWORD', None)

sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py

Lines changed: 38 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,14 @@
3232
from pymilvus import Hits
3333
from pymilvus import MilvusClient
3434
from pymilvus import SearchResult
35+
from pymilvus.exceptions import MilvusException
3536

3637
from apache_beam.ml.rag.types import Chunk
3738
from apache_beam.ml.rag.types import Embedding
39+
from apache_beam.ml.rag.utils import MilvusConnectionParameters
40+
from apache_beam.ml.rag.utils import MilvusHelpers
41+
from apache_beam.ml.rag.utils import retry_with_backoff
42+
from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
3843
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
3944

4045

@@ -104,44 +109,6 @@ def __str__(self):
104109
return self.dict().__str__()
105110

106111

107-
@dataclass
108-
class MilvusConnectionParameters:
109-
"""Parameters for establishing connections to Milvus servers.
110-
111-
Args:
112-
uri: URI endpoint for connecting to Milvus server in the format
113-
"http(s)://hostname:port".
114-
user: Username for authentication. Required if authentication is enabled and
115-
not using token authentication.
116-
password: Password for authentication. Required if authentication is enabled
117-
and not using token authentication.
118-
db_id: Database ID to connect to. Specifies which Milvus database to use.
119-
Defaults to 'default'.
120-
token: Authentication token as an alternative to username/password.
121-
timeout: Connection timeout in seconds. Uses client default if None.
122-
max_retries: Maximum number of connection retry attempts. Defaults to 3.
123-
retry_delay: Initial delay between retries in seconds. Defaults to 1.0.
124-
retry_backoff_factor: Multiplier for retry delay after each attempt.
125-
Defaults to 2.0 (exponential backoff).
126-
kwargs: Optional keyword arguments for additional connection parameters.
127-
Enables forward compatibility.
128-
"""
129-
uri: str
130-
user: str = field(default_factory=str)
131-
password: str = field(default_factory=str)
132-
db_id: str = "default"
133-
token: str = field(default_factory=str)
134-
timeout: Optional[float] = None
135-
max_retries: int = 3
136-
retry_delay: float = 1.0
137-
retry_backoff_factor: float = 2.0
138-
kwargs: Dict[str, Any] = field(default_factory=dict)
139-
140-
def __post_init__(self):
141-
if not self.uri:
142-
raise ValueError("URI must be provided for Milvus connection")
143-
144-
145112
@dataclass
146113
class BaseSearchParameters:
147114
"""Base parameters for both vector and keyword search operations.
@@ -361,15 +328,15 @@ def __init__(
361328
**kwargs):
362329
"""
363330
Example Usage:
364-
connection_paramters = MilvusConnectionParameters(
331+
connection_parameters = MilvusConnectionParameters(
365332
uri="http://localhost:19530")
366333
search_parameters = MilvusSearchParameters(
367334
collection_name="my_collection",
368335
search_strategy=VectorSearchParameters(anns_field="embedding"))
369336
collection_load_parameters = MilvusCollectionLoadParameters(
370337
load_fields=["embedding", "metadata"]),
371338
milvus_handler = MilvusSearchEnrichmentHandler(
372-
connection_paramters,
339+
connection_parameters,
373340
search_parameters,
374341
collection_load_parameters=collection_load_parameters,
375342
min_batch_size=10,
@@ -407,52 +374,43 @@ def __init__(
407374
'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size
408375
}
409376
self.kwargs = kwargs
377+
self._client = None
410378
self.join_fn = join_fn
411379
self.use_custom_types = True
412380

413381
def __enter__(self):
414-
import logging
415-
import time
416-
417-
from pymilvus.exceptions import MilvusException
418-
419-
connection_params = unpack_dataclass_with_kwargs(
420-
self._connection_parameters)
421-
collection_load_params = unpack_dataclass_with_kwargs(
422-
self._collection_load_parameters)
423-
424-
# Extract retry parameters from connection_params
425-
max_retries = connection_params.pop('max_retries', 3)
426-
retry_delay = connection_params.pop('retry_delay', 1.0)
427-
retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)
428-
429-
# Retry logic for MilvusClient connection
430-
last_exception = None
431-
for attempt in range(max_retries + 1):
432-
try:
433-
self._client = MilvusClient(**connection_params)
434-
self._client.load_collection(
382+
"""Enters the context manager and establishes Milvus connection.
383+
384+
Returns:
385+
Self, enabling use in 'with' statements.
386+
"""
387+
if not self._client:
388+
connection_params = unpack_dataclass_with_kwargs(
389+
self._connection_parameters)
390+
collection_load_params = unpack_dataclass_with_kwargs(
391+
self._collection_load_parameters)
392+
393+
# Extract retry parameters from connection_params.
394+
max_retries = connection_params.pop('max_retries', 3)
395+
retry_delay = connection_params.pop('retry_delay', 1.0)
396+
retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)
397+
398+
def connect_and_load():
399+
client = MilvusClient(**connection_params)
400+
client.load_collection(
435401
collection_name=self.collection_name,
436402
partition_names=self.partition_names,
437403
**collection_load_params)
438-
logging.info(
439-
"Successfully connected to Milvus on attempt %d", attempt + 1)
440-
return
441-
except MilvusException as e:
442-
last_exception = e
443-
if attempt < max_retries:
444-
delay = retry_delay * (retry_backoff_factor**attempt)
445-
logging.warning(
446-
"Milvus connection attempt %d failed: %s. "
447-
"Retrying in %.2f seconds...",
448-
attempt + 1,
449-
e,
450-
delay)
451-
time.sleep(delay)
452-
else:
453-
logging.error(
454-
"Failed to connect to Milvus after %d attempts", max_retries + 1)
455-
raise last_exception
404+
return client
405+
406+
self._client = retry_with_backoff(
407+
connect_and_load,
408+
max_retries=max_retries,
409+
retry_delay=retry_delay,
410+
retry_backoff_factor=retry_backoff_factor,
411+
operation_name="Milvus connection and collection load",
412+
exception_types=(MilvusException, ))
413+
return self
456414

457415
def __call__(self, request: Union[Chunk, List[Chunk]], *args,
458416
**kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]:
@@ -535,10 +493,7 @@ def _get_keyword_search_data(self, chunk: Chunk):
535493
raise ValueError(
536494
f"Chunk {chunk.id} missing both text content and sparse embedding "
537495
"required for keyword search")
538-
539-
sparse_embedding = self.convert_sparse_embedding_to_milvus_format(
540-
chunk.sparse_embedding)
541-
496+
sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding)
542497
return chunk.content.text or sparse_embedding
543498

544499
def _get_call_response(
@@ -628,15 +583,3 @@ def batch_elements_kwargs(self) -> Dict[str, int]:
628583
def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding:
629584
left.metadata['enrichment_data'] = right
630585
return left
631-
632-
633-
def unpack_dataclass_with_kwargs(dataclass_instance):
634-
# Create a copy of the dataclass's __dict__.
635-
params_dict: dict = dataclass_instance.__dict__.copy()
636-
637-
# Extract the nested kwargs dictionary.
638-
nested_kwargs = params_dict.pop('kwargs', {})
639-
640-
# Merge the dictionaries, with nested_kwargs taking precedence
641-
# in case of duplicate keys.
642-
return {**params_dict, **nested_kwargs}

0 commit comments

Comments
 (0)