diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index c8e988a52c5d..ed2b0c131e0c 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -52,13 +52,10 @@ ConnectionConfig, CloudSQLConnectionConfig, ExternalSQLDBConnectionConfig) - from apache_beam.ml.rag.enrichment.milvus_search import ( - MilvusConnectionParameters) - from apache_beam.ml.rag.enrichment.milvus_search_it_test import ( - MilvusEnrichmentTestHelper, - MilvusDBContainerInfo, - parse_chunk_strings, - assert_chunks_equivalent) + from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters + from apache_beam.ml.rag.test_utils import MilvusTestHelpers + from apache_beam.ml.rag.test_utils import VectorDBContainerInfo + from apache_beam.ml.rag.utils import parse_chunk_strings from apache_beam.io.requestresponse import RequestResponseIO except ImportError as e: raise unittest.SkipTest(f'Examples dependencies are not installed: {str(e)}') @@ -69,6 +67,11 @@ class TestContainerStartupError(Exception): pass +class TestContainerTeardownError(Exception): + """Raised when any test container fails to tear down.""" + pass + + def validate_enrichment_with_bigtable(): expected = '''[START enrichment_with_bigtable] Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) @@ -186,7 +189,7 @@ def test_enrichment_with_external_pg(self, mock_stdout): output = mock_stdout.getvalue().splitlines() expected = validate_enrichment_with_external_pg() self.assertEqual(output, expected) - except TestContainerStartupError as e: + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -199,7 +202,7 @@ def test_enrichment_with_external_mysql(self, mock_stdout): output = mock_stdout.getvalue().splitlines() expected = validate_enrichment_with_external_mysql() self.assertEqual(output, expected) - except TestContainerStartupError as e: + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -212,7 +215,7 @@ def test_enrichment_with_external_sqlserver(self, mock_stdout): output = mock_stdout.getvalue().splitlines() expected = validate_enrichment_with_external_sqlserver() self.assertEqual(output, expected) - except TestContainerStartupError as e: + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -226,8 +229,8 @@ def test_enrichment_with_milvus(self, mock_stdout): self.maxDiff = None output = parse_chunk_strings(output) expected = parse_chunk_strings(expected) - assert_chunks_equivalent(output, expected) - except TestContainerStartupError as e: + MilvusTestHelpers.assert_chunks_equivalent(output, expected) + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -257,7 +260,7 @@ def sql_test_context(is_cloudsql: bool, db_adapter: 
DatabaseTypeAdapter): @staticmethod @contextmanager def milvus_test_context(): - db: Optional[MilvusDBContainerInfo] = None + db: Optional[VectorDBContainerInfo] = None try: db = EnrichmentTestHelpers.pre_milvus_enrichment() yield @@ -370,23 +373,21 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct): os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None) @staticmethod - def pre_milvus_enrichment() -> MilvusDBContainerInfo: + def pre_milvus_enrichment() -> VectorDBContainerInfo: try: - db = MilvusEnrichmentTestHelper.start_db_container() + db = MilvusTestHelpers.start_db_container() + connection_params = MilvusConnectionParameters( + uri=db.uri, + user=db.user, + password=db.password, + db_name=db.id, + token=db.token) + collection_name = MilvusTestHelpers.initialize_db_with_data( + connection_params) except Exception as e: raise TestContainerStartupError( f"Milvus container failed to start: {str(e)}") - connection_params = MilvusConnectionParameters( - uri=db.uri, - user=db.user, - password=db.password, - db_id=db.id, - token=db.token) - - collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data( - connection_params) - # Setup environment variables for db and collection configuration. This will # be used downstream by the milvus enrichment handler. os.environ['MILVUS_VECTOR_DB_URI'] = db.uri @@ -399,8 +400,13 @@ def pre_milvus_enrichment() -> MilvusDBContainerInfo: return db @staticmethod - def post_milvus_enrichment(db: MilvusDBContainerInfo): - MilvusEnrichmentTestHelper.stop_db_container(db) + def post_milvus_enrichment(db: VectorDBContainerInfo): + try: + MilvusTestHelpers.stop_db_container(db) + except Exception as e: + raise TestContainerTeardownError( + f"Milvus container failed to tear down: {str(e)}") + os.environ.pop('MILVUS_VECTOR_DB_URI', None) os.environ.pop('MILVUS_VECTOR_DB_USER', None) os.environ.pop('MILVUS_VECTOR_DB_PASSWORD', None) diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py index 8f631746748b..41355e8c10aa 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py @@ -32,9 +32,14 @@ from pymilvus import Hits from pymilvus import MilvusClient from pymilvus import SearchResult +from pymilvus.exceptions import MilvusException from apache_beam.ml.rag.types import Chunk from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.utils import MilvusHelpers +from apache_beam.ml.rag.utils import retry_with_backoff +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs from apache_beam.transforms.enrichment import EnrichmentSourceHandler @@ -104,44 +109,6 @@ def __str__(self): return self.dict().__str__() -@dataclass -class MilvusConnectionParameters: - """Parameters for establishing connections to Milvus servers. - - Args: - uri: URI endpoint for connecting to Milvus server in the format - "http(s)://hostname:port". - user: Username for authentication. Required if authentication is enabled and - not using token authentication. - password: Password for authentication. Required if authentication is enabled - and not using token authentication. - db_id: Database ID to connect to. Specifies which Milvus database to use. - Defaults to 'default'. - token: Authentication token as an alternative to username/password. - timeout: Connection timeout in seconds. Uses client default if None. 
- max_retries: Maximum number of connection retry attempts. Defaults to 3. - retry_delay: Initial delay between retries in seconds. Defaults to 1.0. - retry_backoff_factor: Multiplier for retry delay after each attempt. - Defaults to 2.0 (exponential backoff). - kwargs: Optional keyword arguments for additional connection parameters. - Enables forward compatibility. - """ - uri: str - user: str = field(default_factory=str) - password: str = field(default_factory=str) - db_id: str = "default" - token: str = field(default_factory=str) - timeout: Optional[float] = None - max_retries: int = 3 - retry_delay: float = 1.0 - retry_backoff_factor: float = 2.0 - kwargs: Dict[str, Any] = field(default_factory=dict) - - def __post_init__(self): - if not self.uri: - raise ValueError("URI must be provided for Milvus connection") - - @dataclass class BaseSearchParameters: """Base parameters for both vector and keyword search operations. @@ -361,7 +328,7 @@ def __init__( **kwargs): """ Example Usage: - connection_paramters = MilvusConnectionParameters( + connection_parameters = MilvusConnectionParameters( uri="http://localhost:19530") search_parameters = MilvusSearchParameters( collection_name="my_collection", @@ -369,7 +336,7 @@ def __init__( collection_load_parameters = MilvusCollectionLoadParameters( load_fields=["embedding", "metadata"]), milvus_handler = MilvusSearchEnrichmentHandler( - connection_paramters, + connection_parameters, search_parameters, collection_load_parameters=collection_load_parameters, min_batch_size=10, @@ -407,52 +374,43 @@ def __init__( 'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size } self.kwargs = kwargs + self._client = None self.join_fn = join_fn self.use_custom_types = True def __enter__(self): - import logging - import time - - from pymilvus.exceptions import MilvusException - - connection_params = unpack_dataclass_with_kwargs( - self._connection_parameters) - collection_load_params = unpack_dataclass_with_kwargs( - self._collection_load_parameters) - - # Extract retry parameters from connection_params - max_retries = connection_params.pop('max_retries', 3) - retry_delay = connection_params.pop('retry_delay', 1.0) - retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0) - - # Retry logic for MilvusClient connection - last_exception = None - for attempt in range(max_retries + 1): - try: - self._client = MilvusClient(**connection_params) - self._client.load_collection( + """Enters the context manager and establishes Milvus connection. + + Returns: + Self, enabling use in 'with' statements. + """ + if not self._client: + connection_params = unpack_dataclass_with_kwargs( + self._connection_parameters) + collection_load_params = unpack_dataclass_with_kwargs( + self._collection_load_parameters) + + # Extract retry parameters from connection_params. + max_retries = connection_params.pop('max_retries', 3) + retry_delay = connection_params.pop('retry_delay', 1.0) + retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0) + + def connect_and_load(): + client = MilvusClient(**connection_params) + client.load_collection( collection_name=self.collection_name, partition_names=self.partition_names, **collection_load_params) - logging.info( - "Successfully connected to Milvus on attempt %d", attempt + 1) - return - except MilvusException as e: - last_exception = e - if attempt < max_retries: - delay = retry_delay * (retry_backoff_factor**attempt) - logging.warning( - "Milvus connection attempt %d failed: %s. 
" - "Retrying in %.2f seconds...", - attempt + 1, - e, - delay) - time.sleep(delay) - else: - logging.error( - "Failed to connect to Milvus after %d attempts", max_retries + 1) - raise last_exception + return client + + self._client = retry_with_backoff( + connect_and_load, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff_factor=retry_backoff_factor, + operation_name="Milvus connection and collection load", + exception_types=(MilvusException, )) + return self def __call__(self, request: Union[Chunk, List[Chunk]], *args, **kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]: @@ -535,10 +493,7 @@ def _get_keyword_search_data(self, chunk: Chunk): raise ValueError( f"Chunk {chunk.id} missing both text content and sparse embedding " "required for keyword search") - - sparse_embedding = self.convert_sparse_embedding_to_milvus_format( - chunk.sparse_embedding) - + sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding) return chunk.content.text or sparse_embedding def _get_call_response( @@ -628,15 +583,3 @@ def batch_elements_kwargs(self) -> Dict[str, int]: def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding: left.metadata['enrichment_data'] = right return left - - -def unpack_dataclass_with_kwargs(dataclass_instance): - # Create a copy of the dataclass's __dict__. - params_dict: dict = dataclass_instance.__dict__.copy() - - # Extract the nested kwargs dictionary. - nested_kwargs = params_dict.pop('kwargs', {}) - - # Merge the dictionaries, with nested_kwargs taking precedence - # in case of duplicate keys. - return {**params_dict, **nested_kwargs} diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py index b3a0dcd55722..34cb3f9050fc 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py @@ -15,25 +15,13 @@ # limitations under the License. 
# -import contextlib -import logging -import os import platform -import re -import socket -import tempfile import unittest -from collections import defaultdict from dataclasses import dataclass from dataclasses import field -from typing import Callable from typing import Dict -from typing import List -from typing import Optional -from typing import cast import pytest -import yaml import apache_beam as beam from apache_beam.ml.rag.types import Chunk @@ -44,18 +32,12 @@ # pylint: disable=ungrouped-imports try: - from pymilvus import CollectionSchema from pymilvus import DataType from pymilvus import FieldSchema from pymilvus import Function from pymilvus import FunctionType - from pymilvus import MilvusClient from pymilvus import RRFRanker from pymilvus.milvus_client import IndexParams - from testcontainers.core.config import MAX_TRIES as TC_MAX_TRIES - from testcontainers.core.config import testcontainers_config - from testcontainers.core.generic import DbContainer - from testcontainers.milvus import MilvusContainer from apache_beam.ml.rag.enrichment.milvus_search import HybridSearchParameters from apache_beam.ml.rag.enrichment.milvus_search import KeywordSearchMetrics @@ -66,12 +48,12 @@ from apache_beam.ml.rag.enrichment.milvus_search import MilvusSearchParameters from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchMetrics from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchParameters + from apache_beam.ml.rag.test_utils import MilvusTestHelpers + from apache_beam.ml.rag.test_utils import VectorDBContainerInfo from apache_beam.transforms.enrichment import Enrichment except ImportError as e: raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') -_LOGGER = logging.getLogger(__name__) - def _construct_index_params(): index_params = IndexParams() @@ -243,244 +225,6 @@ def __getitem__(self, key): } -@dataclass -class MilvusDBContainerInfo: - container: DbContainer - host: str - port: int - user: Optional[str] = "" - password: Optional[str] = "" - token: Optional[str] = "" - id: Optional[str] = "default" - - @property - def uri(self) -> str: - return f"http://{self.host}:{self.port}" - - -class CustomMilvusContainer(MilvusContainer): - def __init__( - self, - image: str, - service_container_port, - healthcheck_container_port, - **kwargs, - ) -> None: - # Skip the parent class's constructor and go straight to - # GenericContainer. - super(MilvusContainer, self).__init__(image=image, **kwargs) - self.port = service_container_port - self.healthcheck_port = healthcheck_container_port - self.with_exposed_ports(service_container_port, healthcheck_container_port) - - # Get free host ports. - service_host_port = MilvusEnrichmentTestHelper.find_free_port() - healthcheck_host_port = MilvusEnrichmentTestHelper.find_free_port() - - # Bind container and host ports. - self.with_bind_ports(service_container_port, service_host_port) - self.with_bind_ports(healthcheck_container_port, healthcheck_host_port) - self.cmd = "milvus run standalone" - - # Set environment variables needed for Milvus. - envs = { - "ETCD_USE_EMBED": "true", - "ETCD_DATA_DIR": "/var/lib/milvus/etcd", - "COMMON_STORAGETYPE": "local", - "METRICS_PORT": str(healthcheck_container_port) - } - for env, value in envs.items(): - self.with_env(env, value) - - -class MilvusEnrichmentTestHelper: - # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus - # Python SDK client in setup.py is updated to match. 
Referring to the Milvus - # release notes compatibility matrix at - # https://milvus.io/docs/release_notes.md or PyPI at - # https://pypi.org/project/pymilvus/ for version compatibility. - # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required). - @staticmethod - def start_db_container( - image="milvusdb/milvus:v2.5.10", - max_vec_fields=5, - vector_client_max_retries=3, - tc_max_retries=TC_MAX_TRIES) -> Optional[MilvusDBContainerInfo]: - service_container_port = MilvusEnrichmentTestHelper.find_free_port() - healthcheck_container_port = MilvusEnrichmentTestHelper.find_free_port() - user_yaml_creator = MilvusEnrichmentTestHelper.create_user_yaml - with user_yaml_creator(service_container_port, max_vec_fields) as cfg: - info = None - testcontainers_config.max_tries = tc_max_retries - for i in range(vector_client_max_retries): - try: - vector_db_container = CustomMilvusContainer( - image=image, - service_container_port=service_container_port, - healthcheck_container_port=healthcheck_container_port) - vector_db_container = vector_db_container.with_volume_mapping( - cfg, "/milvus/configs/user.yaml") - vector_db_container.start() - host = vector_db_container.get_container_host_ip() - port = vector_db_container.get_exposed_port(service_container_port) - info = MilvusDBContainerInfo(vector_db_container, host, port) - testcontainers_config.max_tries = TC_MAX_TRIES - _LOGGER.info( - "milvus db container started successfully on %s.", info.uri) - break - except Exception as e: - stdout_logs, stderr_logs = vector_db_container.get_logs() - stdout_logs = stdout_logs.decode("utf-8") - stderr_logs = stderr_logs.decode("utf-8") - _LOGGER.warning( - "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. " - "STDOUT logs:\n%s\nSTDERR logs:\n%s", - i + 1, - vector_client_max_retries, - e, - stdout_logs, - stderr_logs) - if i == vector_client_max_retries - 1: - _LOGGER.error( - "Unable to start milvus db container for I/O tests after %d " - "retries. Tests cannot proceed. STDOUT logs:\n%s\n" - "STDERR logs:\n%s", - vector_client_max_retries, - stdout_logs, - stderr_logs) - raise e - return info - - @staticmethod - def stop_db_container(db_info: MilvusDBContainerInfo): - if db_info is None: - _LOGGER.warning("Milvus db info is None. Skipping stop operation.") - return - try: - _LOGGER.debug("Stopping milvus db container.") - db_info.container.stop() - _LOGGER.info("milvus db container stopped successfully.") - except Exception as e: - _LOGGER.warning( - "Error encountered while stopping milvus db container: %s", e) - - @staticmethod - def initialize_db_with_data(connc_params: MilvusConnectionParameters): - # Open the connection to the milvus db. - client = MilvusClient(**connc_params.__dict__) - - # Configure schema. - field_schemas: List[FieldSchema] = cast( - List[FieldSchema], MILVUS_IT_CONFIG["fields"]) - schema = CollectionSchema( - fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"]) - - # Create collection with the schema. - collection_name = MILVUS_IT_CONFIG["collection_name"] - index_function: Callable[[], IndexParams] = cast( - Callable[[], IndexParams], MILVUS_IT_CONFIG["index"]) - client.create_collection( - collection_name=collection_name, - schema=schema, - index_params=index_function()) - - # Assert that collection was created. - collection_error = f"Expected collection '{collection_name}' to be created." - assert client.has_collection(collection_name), collection_error - - # Gather all fields we have excluding 'sparse_embedding_bm25' special field. 
- fields = list(map(lambda field: field.name, field_schemas)) - - # Prep data for indexing. Currently we can't insert sparse vectors for BM25 - # sparse embedding field as it would be automatically generated by Milvus - # through the registered BM25 function. - data_ready_to_index = [] - for doc in MILVUS_IT_CONFIG["corpus"]: - item = {} - for field in fields: - if field.startswith("dense_embedding"): - item[field] = doc["dense_embedding"] - elif field == "sparse_embedding_inner_product": - item[field] = doc["sparse_embedding"] - elif field == "sparse_embedding_bm25": - # It is automatically generated by Milvus from the content field. - continue - else: - item[field] = doc[field] - data_ready_to_index.append(item) - - # Index data. - result = client.insert( - collection_name=collection_name, data=data_ready_to_index) - - # Assert that the intended data has been properly indexed. - insertion_err = f'failed to insert the {result["insert_count"]} data points' - assert result["insert_count"] == len(data_ready_to_index), insertion_err - - # Release the collection from memory. It will be loaded lazily when the - # enrichment handler is invoked. - client.release_collection(collection_name) - - # Close the connection to the Milvus database, as no further preparation - # operations are needed before executing the enrichment handler. - client.close() - - return collection_name - - @staticmethod - def find_free_port(): - """Find a free port on the local machine.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - # Bind to port 0, which asks OS to assign a free port. - s.bind(('', 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # Return the port number assigned by OS. - return s.getsockname()[1] - - @staticmethod - @contextlib.contextmanager - def create_user_yaml(service_port: int, max_vector_field_num=5): - """Creates a temporary user.yaml file for Milvus configuration. - - This user yaml file overrides Milvus default configurations. It sets - the Milvus service port to the specified container service port. The - default for maxVectorFieldNum is 4, but we need 5 - (one unique field for each metric). - - Args: - service_port: Port number for the Milvus service. - max_vector_field_num: Max number of vec fields allowed per collection. - - Yields: - str: Path to the created temporary yaml file. - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', - delete=False) as temp_file: - # Define the content for user.yaml. - user_config = { - 'proxy': { - 'maxVectorFieldNum': max_vector_field_num, 'port': service_port - }, - 'etcd': { - 'use': { - 'embed': True - }, 'data': { - 'dir': '/var/lib/milvus/etcd' - } - } - } - - # Write the content to the file. 
- yaml.dump(user_config, temp_file, default_flow_style=False) - path = temp_file.name - - try: - yield path - finally: - if os.path.exists(path): - os.remove(path) - - @pytest.mark.require_docker_in_docker @unittest.skipUnless( platform.system() == "Linux", @@ -492,25 +236,24 @@ def create_user_yaml(service_port: int, max_vector_field_num=5): class TestMilvusSearchEnrichment(unittest.TestCase): """Tests for search functionality across all search strategies""" - _db: MilvusDBContainerInfo + _db: VectorDBContainerInfo @classmethod def setUpClass(cls): - cls._db = MilvusEnrichmentTestHelper.start_db_container() + cls._db = MilvusTestHelpers.start_db_container() cls._connection_params = MilvusConnectionParameters( uri=cls._db.uri, user=cls._db.user, password=cls._db.password, - db_id=cls._db.id, - token=cls._db.token, - timeout=60.0) # Increase timeout to 60s for container startup + db_name=cls._db.id, + token=cls._db.token) cls._collection_load_params = MilvusCollectionLoadParameters() - cls._collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data( - cls._connection_params) + cls._collection_name = MilvusTestHelpers.initialize_db_with_data( + cls._connection_params, MILVUS_IT_CONFIG) @classmethod def tearDownClass(cls): - MilvusEnrichmentTestHelper.stop_db_container(cls._db) + MilvusTestHelpers.stop_db_container(cls._db) cls._db = None def test_invalid_query_on_non_existent_collection(self): @@ -589,8 +332,8 @@ def test_empty_input_chunks(self): with TestPipeline() as p: result = (p | beam.Create(test_chunks) | Enrichment(handler)) assert_that( - result, - lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent( + actual, expected_chunks)) def test_filtered_search_with_cosine_similarity_and_batching(self): test_chunks = [ @@ -717,8 +460,8 @@ def test_filtered_search_with_cosine_similarity_and_batching(self): with TestPipeline() as p: result = (p | beam.Create(test_chunks) | Enrichment(handler)) assert_that( - result, - lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent( + actual, expected_chunks)) def test_filtered_search_with_bm25_full_text_and_batching(self): test_chunks = [ @@ -822,8 +565,8 @@ def test_filtered_search_with_bm25_full_text_and_batching(self): with TestPipeline() as p: result = (p | beam.Create(test_chunks) | Enrichment(handler)) assert_that( - result, - lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent( + actual, expected_chunks)) def test_vector_search_with_euclidean_distance(self): test_chunks = [ @@ -963,8 +706,8 @@ def test_vector_search_with_euclidean_distance(self): with TestPipeline() as p: result = (p | beam.Create(test_chunks) | Enrichment(handler)) assert_that( - result, - lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent( + actual, expected_chunks)) def test_vector_search_with_inner_product_similarity(self): test_chunks = [ @@ -1103,8 +846,8 @@ def test_vector_search_with_inner_product_similarity(self): with TestPipeline() as p: result = (p | beam.Create(test_chunks) | Enrichment(handler)) assert_that( - result, - lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent( + actual, expected_chunks)) def 
test_keyword_search_with_inner_product_sparse_embedding(self): test_chunks = [ @@ -1168,8 +911,8 @@ def test_keyword_search_with_inner_product_sparse_embedding(self): with TestPipeline() as p: result = (p | beam.Create(test_chunks) | Enrichment(handler)) assert_that( - result, - lambda actual: assert_chunks_equivalent(actual, expected_chunks)) + result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent( + actual, expected_chunks)) def test_hybrid_search(self): test_chunks = [ @@ -1241,134 +984,8 @@ def test_hybrid_search(self): with TestPipeline() as p: result = (p | beam.Create(test_chunks) | Enrichment(handler)) assert_that( - result, - lambda actual: assert_chunks_equivalent(actual, expected_chunks)) - - -def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]: - parsed_chunks = [] - - # Define safe globals and disable built-in functions for safety. - safe_globals = { - 'Chunk': Chunk, - 'Content': Content, - 'Embedding': Embedding, - 'defaultdict': defaultdict, - 'list': list, - '__builtins__': {} - } - - for raw_str in chunk_str_list: - try: - # replace "<class 'list'>" with actual list reference. - cleaned_str = re.sub( - r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str) - - # Evaluate string in restricted environment. - chunk = eval(cleaned_str, safe_globals) # pylint: disable=eval-used - if isinstance(chunk, Chunk): - parsed_chunks.append(chunk) - else: - raise ValueError("Parsed object is not a Chunk instance") - except Exception as e: - raise ValueError(f"Error parsing string:\n{raw_str}\n{e}") - - return parsed_chunks - - -def assert_chunks_equivalent( - actual_chunks: List[Chunk], expected_chunks: List[Chunk]): - """assert_chunks_equivalent checks for presence rather than exact match""" - # Sort both lists by ID to ensure consistent ordering. - actual_sorted = sorted(actual_chunks, key=lambda c: c.id) - expected_sorted = sorted(expected_chunks, key=lambda c: c.id) - - actual_len = len(actual_sorted) - expected_len = len(expected_sorted) - err_msg = ( - f"Different number of chunks, actual: {actual_len}, " - f"expected: {expected_len}") - assert actual_len == expected_len, err_msg - - for actual, expected in zip(actual_sorted, expected_sorted): - # Assert that IDs match. - assert actual.id == expected.id - - # Assert that dense embeddings match. - err_msg = f"Dense embedding mismatch for chunk {actual.id}" - assert actual.dense_embedding == expected.dense_embedding, err_msg - - # Assert that sparse embeddings match. - err_msg = f"Sparse embedding mismatch for chunk {actual.id}" - assert actual.sparse_embedding == expected.sparse_embedding, err_msg - - # Assert that text content match. - err_msg = f"Text Content mismatch for chunk {actual.id}" - assert actual.content.text == expected.content.text, err_msg - - # For enrichment_data, be more flexible. - # If "expected" has values for enrichment_data but actual doesn't, that's - # acceptable since vector search results can vary based on many factors - # including implementation details, vector database state, and slight - # variations in similarity calculations. - - # First ensure the enrichment data key exists. - err_msg = f"Missing enrichment_data key in chunk {actual.id}" - assert 'enrichment_data' in actual.metadata, err_msg - - # For enrichment_data, ensure consistent ordering of results. - actual_data = actual.metadata['enrichment_data'] - expected_data = expected.metadata['enrichment_data'] - - # If actual has enrichment data, then perform detailed validation. 
- if actual_data and actual_data.get('id'): - # Validate IDs have consistent ordering. - actual_ids = sorted(actual_data['id']) - expected_ids = sorted(expected_data['id']) - err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}" - assert actual_ids == expected_ids, err_msg - - # Ensure the distance key exist. - err_msg = f"Missing distance key in metadata {actual.id}" - assert 'distance' in actual_data, err_msg - - # Validate distances exist and have same length as IDs. - actual_distances = actual_data['distance'] - expected_distances = expected_data['distance'] - err_msg = ( - "Number of distances doesn't match number of IDs for " - f"chunk {actual.id}") - assert len(actual_distances) == len(expected_distances), err_msg - - # Ensure the fields key exist. - err_msg = f"Missing fields key in metadata {actual.id}" - assert 'fields' in actual_data, err_msg - - # Validate fields have consistent content. - # Sort fields by 'id' to ensure consistent ordering. - actual_fields_sorted = sorted( - actual_data['fields'], key=lambda f: f.get('id', 0)) - expected_fields_sorted = sorted( - expected_data['fields'], key=lambda f: f.get('id', 0)) - - # Compare field IDs. - actual_field_ids = [f.get('id') for f in actual_fields_sorted] - expected_field_ids = [f.get('id') for f in expected_fields_sorted] - err_msg = f"Field IDs don't match for chunk {actual.id}" - assert actual_field_ids == expected_field_ids, err_msg - - # Compare field content. - for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted): - # Ensure the id key exist. - err_msg = f"Missing id key in metadata.fields {actual.id}" - assert 'id' in a_f - - err_msg = f"Field ID mismatch chunk {actual.id}" - assert a_f['id'] == e_f['id'], err_msg - - # Validate field metadata. - err_msg = f"Field Metadata doesn't match for chunk {actual.id}" - assert a_f['metadata'] == e_f['metadata'], err_msg + result, lambda actual: MilvusTestHelpers.assert_chunks_equivalent( + actual, expected_chunks)) if __name__ == '__main__': diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py index eca740a4e9c3..68afa56e399e 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py @@ -30,16 +30,16 @@ def chunk_embedding_fn(chunk: Chunk) -> str: """Convert embedding to PostgreSQL array string. - + Formats dense embedding as a PostgreSQL-compatible array string. Example: [1.0, 2.0] -> '{1.0,2.0}' - + Args: chunk: Input Chunk object. - + Returns: str: PostgreSQL array string representation of the embedding. - + Raises: ValueError: If chunk has no dense embedding. """ @@ -51,7 +51,7 @@ def chunk_embedding_fn(chunk: Chunk) -> str: @dataclass class ColumnSpec: """Specification for mapping Chunk fields to SQL columns for insertion. - + Defines how to extract and format values from Chunks into database columns, handling the full pipeline from Python value to SQL insertion. @@ -71,7 +71,7 @@ class ColumnSpec: Common examples: - "::float[]" for vector arrays - "::jsonb" for JSON data - + Examples: Basic text column (uses standard JDBC type mapping): >>> ColumnSpec.text( @@ -83,7 +83,7 @@ class ColumnSpec: Vector column with explicit array casting: >>> ColumnSpec.vector( ... column_name="embedding", - ... value_fn=lambda chunk: '{' + + ... value_fn=lambda chunk: '{' + ... ','.join(map(str, chunk.embedding.dense_embedding)) + '}' ... 
) # Results in: INSERT INTO table (embedding) VALUES (?::float[]) @@ -168,17 +168,17 @@ def with_id_spec( convert_fn: Optional[Callable[[str], Any]] = None, sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': """Add ID :class:`.ColumnSpec` with optional type and conversion. - + Args: column_name: Name for the ID column (defaults to "id") python_type: Python type for the column (defaults to str) convert_fn: Optional function to convert the chunk ID If None, uses ID as-is sql_typecast: Optional SQL type cast - + Returns: Self for method chaining - + Example: >>> builder.with_id_spec( ... column_name="doc_id", @@ -205,17 +205,17 @@ def with_content_spec( convert_fn: Optional[Callable[[str], Any]] = None, sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': """Add content :class:`.ColumnSpec` with optional type and conversion. - + Args: column_name: Name for the content column (defaults to "content") python_type: Python type for the column (defaults to str) convert_fn: Optional function to convert the content text If None, uses content text as-is sql_typecast: Optional SQL type cast - + Returns: Self for method chaining - + Example: >>> builder.with_content_spec( ... column_name="content_length", @@ -244,17 +244,17 @@ def with_metadata_spec( convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None, sql_typecast: Optional[str] = "::jsonb") -> 'ColumnSpecsBuilder': """Add metadata :class:`.ColumnSpec` with optional type and conversion. - + Args: column_name: Name for the metadata column (defaults to "metadata") python_type: Python type for the column (defaults to str) convert_fn: Optional function to convert the metadata dictionary If None and python_type is str, converts to JSON string sql_typecast: Optional SQL type cast (defaults to "::jsonb") - + Returns: Self for method chaining - + Example: >>> builder.with_metadata_spec( ... column_name="meta_tags", @@ -283,19 +283,19 @@ def with_embedding_spec( convert_fn: Optional[Callable[[List[float]], Any]] = None ) -> 'ColumnSpecsBuilder': """Add embedding :class:`.ColumnSpec` with optional conversion. - + Args: column_name: Name for the embedding column (defaults to "embedding") convert_fn: Optional function to convert the dense embedding values If None, uses default PostgreSQL array format - + Returns: Self for method chaining - + Example: >>> builder.with_embedding_spec( ... column_name="embedding_vector", - ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}" + ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}" ... for x in values) + '}' ... ) """ @@ -330,7 +330,7 @@ def add_metadata_field( desired type. If None, value is used as-is default: Default value if field is missing from metadata sql_typecast: Optional SQL type cast (e.g. "::timestamp") - + Returns: Self for chaining @@ -385,17 +385,17 @@ def value_fn(chunk: Chunk) -> Any: def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder': """Add a custom :class:`.ColumnSpec` to the builder. - + Use this method when you need complete control over the :class:`.ColumnSpec` , including custom value extraction and type handling. - + Args: spec: A :class:`.ColumnSpec` instance defining the column name, type, value extraction, and optional SQL type casting. - + Returns: Self for method chaining - + Examples: Custom text column from chunk metadata: @@ -430,12 +430,12 @@ class ConflictResolution: IGNORE: Skips conflicting records. update_fields: Optional list of fields to update on conflict. If None, all non-conflict fields are updated. 
- + Examples: Simple primary key: >>> ConflictResolution("id") - + Composite key with specific update fields: >>> ConflictResolution( @@ -443,7 +443,7 @@ class ConflictResolution: ... action="UPDATE", ... update_fields=["embedding", "content"] ... ) - + Ignore conflicts: >>> ConflictResolution( diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py new file mode 100644 index 000000000000..f4acb105892c --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/test_utils.py @@ -0,0 +1,413 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import contextlib +import logging +import os +import socket +import tempfile +import unittest +from dataclasses import dataclass +from typing import Callable +from typing import List +from typing import Optional +from typing import cast + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.utils import retry_with_backoff + +# pylint: disable=ungrouped-imports +try: + import yaml + from pymilvus import CollectionSchema + from pymilvus import FieldSchema + from pymilvus import MilvusClient + from pymilvus.exceptions import MilvusException + from pymilvus.milvus_client import IndexParams + from testcontainers.core.config import testcontainers_config + from testcontainers.core.generic import DbContainer + from testcontainers.milvus import MilvusContainer + + from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters +except ImportError as e: + raise unittest.SkipTest(f'RAG test util dependencies not installed: {str(e)}') + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class VectorDBContainerInfo: + """Container information for vector database test instances. + + Holds connection details and container reference for testing with + vector databases like Milvus in containerized environments. + """ + container: DbContainer + host: str + port: int + user: str = "" + password: str = "" + token: str = "" + id: str = "default" + + @property + def uri(self) -> str: + return f"http://{self.host}:{self.port}" + + +class TestHelpers: + @staticmethod + def find_free_port(): + """Find a free port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # Bind to port 0, which asks OS to assign a free port. + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # Return the port number assigned by OS. + return s.getsockname()[1] + + +class CustomMilvusContainer(MilvusContainer): + """Custom Milvus container with configurable ports and environment setup. + + Extends MilvusContainer to provide custom port binding and environment + configuration for testing with standalone Milvus instances. 
+ """ + def __init__( + self, + image: str, + service_container_port, + healthcheck_container_port, + **kwargs, + ) -> None: + # Skip the parent class's constructor and go straight to + # GenericContainer. + super(MilvusContainer, self).__init__(image=image, **kwargs) + self.port = service_container_port + self.healthcheck_port = healthcheck_container_port + self.with_exposed_ports(service_container_port, healthcheck_container_port) + + # Get free host ports. + service_host_port = TestHelpers.find_free_port() + healthcheck_host_port = TestHelpers.find_free_port() + + # Bind container and host ports. + self.with_bind_ports(service_container_port, service_host_port) + self.with_bind_ports(healthcheck_container_port, healthcheck_host_port) + self.cmd = "milvus run standalone" + + # Set environment variables needed for Milvus. + envs = { + "ETCD_USE_EMBED": "true", + "ETCD_DATA_DIR": "/var/lib/milvus/etcd", + "COMMON_STORAGETYPE": "local", + "METRICS_PORT": str(healthcheck_container_port) + } + for env, value in envs.items(): + self.with_env(env, value) + + +class MilvusTestHelpers: + """Helper utilities for testing Milvus vector database operations. + + Provides static methods for managing test containers, configuration files, + and chunk comparison utilities for Milvus-based integration tests. + """ + # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus + # Python SDK client in setup.py is updated to match. Referring to the Milvus + # release notes compatibility matrix at + # https://milvus.io/docs/release_notes.md or PyPI at + # https://pypi.org/project/pymilvus/ for version compatibility. + # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required). + @staticmethod + def start_db_container( + image="milvusdb/milvus:v2.5.10", + max_vec_fields=5, + vector_client_max_retries=3, + tc_max_retries=None) -> Optional[VectorDBContainerInfo]: + service_container_port = TestHelpers.find_free_port() + healthcheck_container_port = TestHelpers.find_free_port() + user_yaml_creator = MilvusTestHelpers.create_user_yaml + with user_yaml_creator(service_container_port, max_vec_fields) as cfg: + info = None + original_tc_max_tries = testcontainers_config.max_tries + if tc_max_retries is not None: + testcontainers_config.max_tries = tc_max_retries + for i in range(vector_client_max_retries): + try: + vector_db_container = CustomMilvusContainer( + image=image, + service_container_port=service_container_port, + healthcheck_container_port=healthcheck_container_port) + vector_db_container = vector_db_container.with_volume_mapping( + cfg, "/milvus/configs/user.yaml") + vector_db_container.start() + host = vector_db_container.get_container_host_ip() + port = vector_db_container.get_exposed_port(service_container_port) + info = VectorDBContainerInfo(vector_db_container, host, port) + _LOGGER.info( + "milvus db container started successfully on %s.", info.uri) + except Exception as e: + stdout_logs, stderr_logs = vector_db_container.get_logs() + stdout_logs = stdout_logs.decode("utf-8") + stderr_logs = stderr_logs.decode("utf-8") + _LOGGER.warning( + "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. " + "STDOUT logs:\n%s\nSTDERR logs:\n%s", + i + 1, + vector_client_max_retries, + e, + stdout_logs, + stderr_logs) + if i == vector_client_max_retries - 1: + _LOGGER.error( + "Unable to start milvus db container for I/O tests after %d " + "retries. Tests cannot proceed. 
STDOUT logs:\n%s\n" + "STDERR logs:\n%s", + vector_client_max_retries, + stdout_logs, + stderr_logs) + raise e + finally: + testcontainers_config.max_tries = original_tc_max_tries + return info + + @staticmethod + def stop_db_container(db_info: VectorDBContainerInfo): + if db_info is None: + _LOGGER.warning("Milvus db info is None. Skipping stop operation.") + return + _LOGGER.debug("Stopping milvus db container.") + db_info.container.stop() + _LOGGER.info("milvus db container stopped successfully.") + + @staticmethod + def initialize_db_with_data( + connc_params: MilvusConnectionParameters, config: dict): + # Open the connection to the milvus db with retry. + def create_client(): + return MilvusClient(**connc_params.__dict__) + + client = retry_with_backoff( + create_client, + max_retries=3, + retry_delay=1.0, + operation_name="Test Milvus client connection", + exception_types=(MilvusException, )) + + # Configure schema. + field_schemas: List[FieldSchema] = cast(List[FieldSchema], config["fields"]) + schema = CollectionSchema( + fields=field_schemas, functions=config["functions"]) + + # Create collection with the schema. + collection_name = config["collection_name"] + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], config["index"]) + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Assert that collection was created. + collection_error = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), collection_error + + # Gather all fields we have excluding 'sparse_embedding_bm25' special field. + fields = list(map(lambda field: field.name, field_schemas)) + + # Prep data for indexing. Currently we can't insert sparse vectors for BM25 + # sparse embedding field as it would be automatically generated by Milvus + # through the registered BM25 function. + data_ready_to_index = [] + for doc in config["corpus"]: + item = {} + for field in fields: + if field.startswith("dense_embedding"): + item[field] = doc["dense_embedding"] + elif field == "sparse_embedding_inner_product": + item[field] = doc["sparse_embedding"] + elif field == "sparse_embedding_bm25": + # It is automatically generated by Milvus from the content field. + continue + else: + item[field] = doc[field] + data_ready_to_index.append(item) + + # Index data. + result = client.insert( + collection_name=collection_name, data=data_ready_to_index) + + # Assert that the intended data has been properly indexed. + insertion_err = f'failed to insert the {result["insert_count"]} data points' + assert result["insert_count"] == len(data_ready_to_index), insertion_err + + # Release the collection from memory. It will be loaded lazily when the + # enrichment handler is invoked. + client.release_collection(collection_name) + + # Close the connection to the Milvus database, as no further preparation + # operations are needed before executing the enrichment handler. + client.close() + + return collection_name + + @staticmethod + @contextlib.contextmanager + def create_user_yaml(service_port: int, max_vector_field_num=5): + """Creates a temporary user.yaml file for Milvus configuration. + + This user yaml file overrides Milvus default configurations. It sets + the Milvus service port to the specified container service port. The + default for maxVectorFieldNum is 4, but we need 5 + (one unique field for each metric). + + Args: + service_port: Port number for the Milvus service. 
+ max_vector_field_num: Max number of vec fields allowed per collection. + + Yields: + str: Path to the created temporary yaml file. + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', + delete=False) as temp_file: + # Define the content for user.yaml. + user_config = { + 'proxy': { + 'maxVectorFieldNum': max_vector_field_num, 'port': service_port + }, + 'etcd': { + 'use': { + 'embed': True + }, 'data': { + 'dir': '/var/lib/milvus/etcd' + } + } + } + + # Write the content to the file. + yaml.dump(user_config, temp_file, default_flow_style=False) + path = temp_file.name + + try: + yield path + finally: + if os.path.exists(path): + os.remove(path) + + @staticmethod + def assert_chunks_equivalent( + actual_chunks: List[Chunk], expected_chunks: List[Chunk]): + """assert_chunks_equivalent checks for presence rather than exact match""" + # Sort both lists by ID to ensure consistent ordering. + actual_sorted = sorted(actual_chunks, key=lambda c: c.id) + expected_sorted = sorted(expected_chunks, key=lambda c: c.id) + + actual_len = len(actual_sorted) + expected_len = len(expected_sorted) + err_msg = ( + f"Different number of chunks, actual: {actual_len}, " + f"expected: {expected_len}") + assert actual_len == expected_len, err_msg + + for actual, expected in zip(actual_sorted, expected_sorted): + # Assert that IDs match. + assert actual.id == expected.id + + # Assert that dense embeddings match. + err_msg = f"Dense embedding mismatch for chunk {actual.id}" + assert actual.dense_embedding == expected.dense_embedding, err_msg + + # Assert that sparse embeddings match. + err_msg = f"Sparse embedding mismatch for chunk {actual.id}" + assert actual.sparse_embedding == expected.sparse_embedding, err_msg + + # Assert that text content matches. + err_msg = f"Text Content mismatch for chunk {actual.id}" + assert actual.content.text == expected.content.text, err_msg + + # For enrichment_data, be more flexible. + # If "expected" has values for enrichment_data but actual doesn't, that's + # acceptable since vector search results can vary based on many factors + # including implementation details, vector database state, and slight + # variations in similarity calculations. + + # First ensure the enrichment data key exists. + err_msg = f"Missing enrichment_data key in chunk {actual.id}" + assert 'enrichment_data' in actual.metadata, err_msg + + # For enrichment_data, ensure consistent ordering of results. + actual_data = actual.metadata['enrichment_data'] + expected_data = expected.metadata['enrichment_data'] + + # If actual has enrichment data, then perform detailed validation. + if actual_data: + # Ensure the id key exists. + err_msg = f"Missing id key in metadata {actual.id}" + assert 'id' in actual_data, err_msg + + # Validate IDs have consistent ordering. + actual_ids = sorted(actual_data['id']) + expected_ids = sorted(expected_data['id']) + err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}" + assert actual_ids == expected_ids, err_msg + + # Ensure the distance key exists. + err_msg = f"Missing distance key in metadata {actual.id}" + assert 'distance' in actual_data, err_msg + + # Validate distances exist and have same length as IDs. + actual_distances = actual_data['distance'] + expected_distances = expected_data['distance'] + err_msg = ( + "Number of distances doesn't match number of IDs for " + f"chunk {actual.id}") + assert len(actual_distances) == len(expected_distances), err_msg + + # Ensure the fields key exists. 
+ err_msg = f"Missing fields key in metadata {actual.id}" + assert 'fields' in actual_data, err_msg + + # Validate fields have consistent content. + # Sort fields by 'id' to ensure consistent ordering. + actual_fields_sorted = sorted( + actual_data['fields'], key=lambda f: f.get('id', 0)) + expected_fields_sorted = sorted( + expected_data['fields'], key=lambda f: f.get('id', 0)) + + # Compare field IDs. + actual_field_ids = [f.get('id') for f in actual_fields_sorted] + expected_field_ids = [f.get('id') for f in expected_fields_sorted] + err_msg = f"Field IDs don't match for chunk {actual.id}" + assert actual_field_ids == expected_field_ids, err_msg + + # Compare field content. + for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted): + # Ensure the id key exist. + err_msg = f"Missing id key in metadata.fields {actual.id}" + assert 'id' in a_f, err_msg + + err_msg = f"Field ID mismatch chunk {actual.id}" + assert a_f['id'] == e_f['id'], err_msg + + # Validate field metadata. + err_msg = f"Field Metadata doesn't match for chunk {actual.id}" + assert a_f['metadata'] == e_f['metadata'], err_msg + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/utils.py b/sdks/python/apache_beam/ml/rag/utils.py new file mode 100644 index 000000000000..d45e99be0ecb --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/utils.py @@ -0,0 +1,224 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import re +import time +import uuid +from collections import defaultdict +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding + +_LOGGER = logging.getLogger(__name__) + +# Default batch size for writing data to Milvus, matching +# JdbcIO.DEFAULT_BATCH_SIZE. +DEFAULT_WRITE_BATCH_SIZE = 1000 + + +@dataclass +class MilvusConnectionParameters: + """Configurations for establishing connections to Milvus servers. + + Args: + uri: URI endpoint for connecting to Milvus server in the format + "http(s)://hostname:port". + user: Username for authentication. Required if authentication is enabled and + not using token authentication. + password: Password for authentication. Required if authentication is enabled + and not using token authentication. + db_name: Database Name to connect to. Specifies which Milvus database to + use. Defaults to 'default'. + token: Authentication token as an alternative to username/password. + timeout: Connection timeout in seconds. Uses client default if None. 
+ kwargs: Optional keyword arguments for additional connection parameters. + Enables forward compatibility. + """ + uri: str + user: str = field(default_factory=str) + password: str = field(default_factory=str) + db_name: str = "default" + token: str = field(default_factory=str) + timeout: Optional[float] = None + kwargs: Dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not self.uri: + raise ValueError("URI must be provided for Milvus connection") + + # Generate unique alias if not provided. One-to-one mapping between alias + # and connection - each alias represents exactly one Milvus connection. + if "alias" not in self.kwargs: + alias = f"milvus_conn_{uuid.uuid4().hex[:8]}" + self.kwargs["alias"] = alias + + +class MilvusHelpers: + """Utility class providing helper methods for Milvus vector db operations.""" + @staticmethod + def sparse_embedding( + sparse_vector: Tuple[List[int], + List[float]]) -> Optional[Dict[int, float]]: + if not sparse_vector: + return None + # Converts sparse embedding from (indices, values) tuple format to + # Milvus-compatible values dict format {dimension_index: value, ...}. + indices, values = sparse_vector + return {int(idx): float(val) for idx, val in zip(indices, values)} + + +def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]: + parsed_chunks = [] + + # Define safe globals and disable built-in functions for safety. + safe_globals = { + 'Chunk': Chunk, + 'Content': Content, + 'Embedding': Embedding, + 'defaultdict': defaultdict, + 'list': list, + '__builtins__': {} + } + + for raw_str in chunk_str_list: + try: + # Replace "<class 'list'>" with an actual list reference. + cleaned_str = re.sub( + r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str) + + # Evaluate string in restricted environment. + chunk = eval(cleaned_str, safe_globals) # pylint: disable=eval-used + if isinstance(chunk, Chunk): + parsed_chunks.append(chunk) + else: + raise ValueError("Parsed object is not a Chunk instance") + except Exception as e: + raise ValueError(f"Error parsing string:\n{raw_str}\n{e}") + + return parsed_chunks + + +def unpack_dataclass_with_kwargs(dataclass_instance): + """Unpacks dataclass fields into a flat dict, merging kwargs with precedence. + + Args: + dataclass_instance: Dataclass instance to unpack. + + Returns: + dict: Flattened dictionary with kwargs taking precedence over fields. + """ + # Create a copy of the dataclass's __dict__. + params_dict: dict = dataclass_instance.__dict__.copy() + + # Extract the nested kwargs dictionary. + nested_kwargs = params_dict.pop('kwargs', {}) + + # Merge the dictionaries, with nested_kwargs taking precedence + # in case of duplicate keys. + return {**params_dict, **nested_kwargs} + + +def retry_with_backoff( + operation: Callable[[], Any], + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff_factor: float = 2.0, + operation_name: str = "operation", + exception_types: Tuple[Type[BaseException], ...] = (Exception, ) +) -> Any: + """Executes an operation with retry logic and exponential backoff. + + This is a generic retry utility that can be used for any operation that may + fail transiently. It retries the operation with exponential backoff between + attempts. + + Note: + This utility is designed for one-time setup operations and complements + Apache Beam's RequestResponseIO pattern. 
Use retry_with_backoff() for: + + * Establishing client connections in __enter__() methods (e.g., creating + MilvusClient instances, database connections) before processing elements + * One-time setup/teardown operations in DoFn lifecycle methods + * Operations outside of per-element processing where retry is needed + + For per-element operations (e.g., API calls within Caller.__call__), + use RequestResponseIO which already provides automatic retry with + exponential backoff, failure handling, caching, and other features. + See: https://beam.apache.org/documentation/io/built-in/webapis/ + + Args: + operation: Callable that performs the operation to retry. Should return + the result of the operation. + max_retries: Maximum number of retry attempts. Default is 3. + retry_delay: Initial delay in seconds between retries. Default is 1.0. + retry_backoff_factor: Multiplier for the delay after each retry. Default + is 2.0 (exponential backoff). + operation_name: Name of the operation for logging purposes. Default is + "operation". + exception_types: Tuple of exception types to catch and retry. Default is + (Exception,) which catches all exceptions. + + Returns: + The result of the operation if successful. + + Raises: + The last exception encountered if all retry attempts fail. + + Example: + >>> def connect_to_service(): + ... return service.connect(host="localhost") + >>> client = retry_with_backoff( + ... connect_to_service, + ... max_retries=5, + ... retry_delay=2.0, + ... operation_name="service connection") + """ + last_exception = None + for attempt in range(max_retries + 1): + try: + result = operation() + _LOGGER.info( + "Successfully completed %s on attempt %d", + operation_name, + attempt + 1) + return result + except exception_types as e: + last_exception = e + if attempt < max_retries: + delay = retry_delay * (retry_backoff_factor**attempt) + _LOGGER.warning( + "%s attempt %d failed: %s. Retrying in %.2f seconds...", + operation_name, + attempt + 1, + e, + delay) + time.sleep(delay) + else: + _LOGGER.error( + "Failed %s after %d attempts", operation_name, max_retries + 1) + raise last_exception
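
As a usage sketch of how the relocated utilities compose (the URI below is an illustrative assumption, not part of this change), this mirrors the pattern the handler's __enter__ now uses: flatten the connection dataclass, then retry the client construction on MilvusException.

    # Minimal sketch: unpack_dataclass_with_kwargs flattens the dataclass
    # (entries in `kwargs` win on duplicate keys, so the timeout below
    # overrides the dataclass field), then retry_with_backoff retries the
    # connection with exponential backoff.
    from pymilvus import MilvusClient
    from pymilvus.exceptions import MilvusException

    from apache_beam.ml.rag.utils import MilvusConnectionParameters
    from apache_beam.ml.rag.utils import retry_with_backoff
    from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs

    params = MilvusConnectionParameters(
        uri="http://localhost:19530",  # assumed local standalone instance
        kwargs={"timeout": 30.0})

    client = retry_with_backoff(
        lambda: MilvusClient(**unpack_dataclass_with_kwargs(params)),
        max_retries=3,
        retry_delay=1.0,
        retry_backoff_factor=2.0,
        operation_name="Milvus connection",
        exception_types=(MilvusException, ))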