[2/3] sdks/python: sink data with Milvus Search I/O connector #36729
Merged
Commits (7):

- f5788ab  sdks/python: add milvus sink integration (mohamedawnallah)
- 3992585  CHANGES.md: update release notes (mohamedawnallah)
- 798468a  Merge remote-tracking branch 'upstream/master' into sinkWilMilvusIO-2 (mohamedawnallah)
- 43c5da2  sdks/python: fix py docs formatting issues (mohamedawnallah)
- 46a03c8  sdks/python: fix linting issues (mohamedawnallah)
- 8ac2c42  sdks/python: delegate auto-flushing to milvus backend (mohamedawnallah)
- d0b68de  sdks/python: address gemini comments (mohamedawnallah)
359 changes: 359 additions & 0 deletions
359
sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,359 @@ | ||
```python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional

from pymilvus import MilvusClient
from pymilvus.exceptions import MilvusException

import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.utils import DEFAULT_WRITE_BATCH_SIZE
from apache_beam.ml.rag.utils import MilvusConnectionParameters
from apache_beam.ml.rag.utils import MilvusHelpers
from apache_beam.ml.rag.utils import retry_with_backoff
from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
from apache_beam.transforms import DoFn

_LOGGER = logging.getLogger(__name__)


@dataclass
class MilvusWriteConfig:
  """Configuration parameters for writing data to Milvus collections.

  This class defines the parameters needed to write data to a Milvus
  collection, including collection targeting, batching behavior, and
  operation timeouts.

  Args:
    collection_name: Name of the target Milvus collection to write data to.
      Must be a non-empty string.
    partition_name: Name of the specific partition within the collection to
      write to. If empty, writes to the default partition.
    timeout: Maximum time in seconds to wait for write operations to
      complete. If None, uses the client's default timeout.
    write_config: Configuration for write operations including batch size
      and other write-specific settings.
    kwargs: Additional keyword arguments for write operations. Enables
      forward compatibility with future Milvus client parameters.
  """
  collection_name: str
  partition_name: str = ""
  timeout: Optional[float] = None
  write_config: WriteConfig = field(default_factory=WriteConfig)
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    if not self.collection_name:
      raise ValueError("Collection name must be provided")

  @property
  def write_batch_size(self):
    """Returns the batch size for write operations.

    Returns:
      The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not
      specified.
    """
    return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE
```
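For orientation (not part of the diff), constructing this config might look like the sketch below. The collection name is the only required argument, and the batch size rides along on the `WriteConfig` reused from the JDBC ingestion module; the collection name and batch size here are made-up values.

```python
# Illustrative sketch only: "docs" and the 500-record batch are placeholders.
from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig

write_config = MilvusWriteConfig(
    collection_name="docs",
    timeout=30.0,
    write_config=WriteConfig(write_batch_size=500))

# The property falls back to DEFAULT_WRITE_BATCH_SIZE when unset.
assert write_config.write_batch_size == 500
```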
```python
@dataclass
class MilvusVectorWriterConfig(VectorDatabaseWriteConfig):
  """Configuration for writing vector data to Milvus collections.

  This class extends VectorDatabaseWriteConfig to provide Milvus-specific
  configuration for ingesting vector embeddings and associated metadata. It
  defines how Apache Beam chunks are converted to Milvus records and handles
  the write operation parameters.

  The configuration includes connection parameters, write settings, and
  column specifications that determine how chunk data is mapped to Milvus
  fields.

  Args:
    connection_params: Configuration for connecting to the Milvus server,
      including URI, credentials, and connection options.
    write_config: Configuration for write operations including collection
      name, partition, batch size, and timeouts.
    column_specs: List of column specifications defining how chunk fields
      are mapped to Milvus collection fields. Defaults to standard RAG
      fields (id, embedding, sparse_embedding, content, metadata).

  Example:
    config = MilvusVectorWriterConfig(
        connection_params=MilvusConnectionParameters(
            uri="http://localhost:19530"),
        write_config=MilvusWriteConfig(collection_name="my_collection"),
        column_specs=MilvusVectorWriterConfig.default_column_specs())
  """
  connection_params: MilvusConnectionParameters
  write_config: MilvusWriteConfig
  column_specs: List[ColumnSpec] = field(
      default_factory=lambda: MilvusVectorWriterConfig.default_column_specs())

  def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]:
    """Creates a function to convert Apache Beam Chunks to Milvus records.

    Returns:
      A function that takes a Chunk and returns a dictionary representing a
      Milvus record with fields mapped according to column_specs.
    """
    def convert(chunk: Chunk) -> Dict[str, Any]:
      result = {}
      for col in self.column_specs:
        result[col.column_name] = col.value_fn(chunk)
      return result

    return convert

  def create_write_transform(self) -> beam.PTransform:
    """Creates the Apache Beam transform for writing to Milvus.

    Returns:
      A PTransform that can be applied to a PCollection of Chunks to write
      them to the configured Milvus collection.
    """
    return _WriteToMilvusVectorDatabase(self)

  @staticmethod
  def default_column_specs() -> List[ColumnSpec]:
    """Returns default column specifications for RAG use cases.

    Creates column mappings for standard RAG fields: id, dense embedding,
    sparse embedding, content text, and metadata. These specifications
    define how Chunk fields are converted to Milvus-compatible formats.

    Returns:
      List of ColumnSpec objects defining the default field mappings.
    """
    column_specs = ColumnSpecsBuilder()
    return column_specs\
        .with_id_spec()\
        .with_embedding_spec(convert_fn=lambda values: list(values))\
        .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\
        .with_content_spec()\
        .with_metadata_spec(convert_fn=lambda values: dict(values))\
        .build()
```
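End to end, this config plugs into the generic RAG ingestion entry point from `apache_beam.ml.rag.ingestion.base`. The sketch below is illustrative rather than part of the diff: it assumes a local Milvus instance at the example URI and an existing collection whose schema matches the default column specs; the URI, collection name, and chunk values are placeholders.

```python
# Hypothetical usage sketch, assuming a reachable Milvus server and a
# pre-created collection compatible with the default column specs.
import apache_beam as beam
from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform
from apache_beam.ml.rag.types import Chunk, Content, Embedding

config = MilvusVectorWriterConfig(
    connection_params=MilvusConnectionParameters(
        uri="http://localhost:19530"),
    write_config=MilvusWriteConfig(collection_name="my_collection"))

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([
          Chunk(
              id="chunk-1",
              content=Content(text="hello milvus"),
              # Default specs map both dense and sparse embeddings.
              embedding=Embedding(
                  dense_embedding=[0.1, 0.2, 0.3],
                  sparse_embedding=([0, 2], [0.5, 0.25]))),
      ])
      | VectorDatabaseWriteTransform(config))
```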
```python
class _WriteToMilvusVectorDatabase(beam.PTransform):
  """Apache Beam PTransform for writing vector data to Milvus.

  This transform handles the conversion of Apache Beam Chunks to Milvus
  records and coordinates the write operations. It applies the configured
  converter function and uses a DoFn for batched writes to optimize
  performance.

  Args:
    config: MilvusVectorWriterConfig containing all necessary parameters
      for the write operation.
  """
  def __init__(self, config: MilvusVectorWriterConfig):
    self.config = config

  def expand(self, pcoll: beam.PCollection[Chunk]):
    """Expands the PTransform to convert chunks and write to Milvus.

    Args:
      pcoll: PCollection of Chunk objects to write to Milvus.

    Returns:
      PCollection of the converted record dictionaries, re-emitted after
      they are batched for writing to Milvus.
    """
    return (
        pcoll
        | "Convert to Records" >> beam.Map(self.config.create_converter())
        | beam.ParDo(
            _WriteMilvusFn(
                self.config.connection_params, self.config.write_config)))


class _WriteMilvusFn(DoFn):
  """DoFn that handles batched writes to Milvus.

  This DoFn accumulates records in batches and flushes them to Milvus when
  the batch size is reached or when the bundle finishes. This approach
  optimizes performance by reducing the number of individual write
  operations.

  Args:
    connection_params: Configuration for connecting to the Milvus server.
    write_config: Configuration for write operations including batch size
      and collection details.
  """
  def __init__(
      self,
      connection_params: MilvusConnectionParameters,
      write_config: MilvusWriteConfig):
    self._connection_params = connection_params
    self._write_config = write_config
    self.batch = []

  def process(self, element, *args, **kwargs):
    """Processes individual records, batching them for efficient writes.

    Args:
      element: A dictionary representing a Milvus record to write.
      *args: Additional positional arguments.
      **kwargs: Additional keyword arguments.

    Yields:
      The original element after adding it to the batch.
    """
    _ = args, kwargs  # Unused parameters
    self.batch.append(element)
    if len(self.batch) >= self._write_config.write_batch_size:
      self._flush()
    yield element

  def finish_bundle(self):
    """Called when a bundle finishes processing.

    Flushes any remaining records in the batch to ensure all data is
    written.
    """
    self._flush()

  def _flush(self):
    """Flushes the current batch of records to Milvus.

    Creates a MilvusSink connection and writes all batched records, then
    clears the batch for the next set of records.
    """
    if len(self.batch) == 0:
      return
    with _MilvusSink(self._connection_params, self._write_config) as sink:
      sink.write(self.batch)
    self.batch = []

  def display_data(self):
    """Returns display data for monitoring and debugging.

    Returns:
      Dictionary containing database, collection, and batch size
      information for display in the Apache Beam monitoring UI.
    """
    res = super().display_data()
    res["database"] = self._connection_params.db_name
    res["collection"] = self._write_config.collection_name
    res["batch_size"] = self._write_config.write_batch_size
    return res
```
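`_MilvusSink`, below, expands the connection dataclass with `unpack_dataclass_with_kwargs` before handing the result to `MilvusClient`. That helper lives in `apache_beam.ml.rag.utils` (added earlier in this PR series); judging from how `__enter__` pops retry keys out of the returned dict, its contract is roughly the following sketch, which may differ from the real implementation:

```python
# Hedged sketch: behavior inferred from the call sites below, not the
# actual apache_beam.ml.rag.utils implementation.
from dataclasses import asdict, is_dataclass

def unpack_dataclass_with_kwargs(dc) -> dict:
  assert is_dataclass(dc)
  params = asdict(dc)
  # Hoist the catch-all `kwargs` field so its entries become first-class
  # keyword arguments alongside the declared fields.
  extra = params.pop("kwargs", {}) or {}
  return {**params, **extra}
```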
```python
class _MilvusSink:
  """Low-level sink for writing data directly to Milvus.

  This class handles the direct interaction with the Milvus client for
  upsert operations. It manages the connection lifecycle and provides
  context manager support for proper resource cleanup.

  Args:
    connection_params: Configuration for connecting to the Milvus server.
    write_config: Configuration for write operations including collection
      and partition targeting.
  """
  def __init__(
      self,
      connection_params: MilvusConnectionParameters,
      write_config: MilvusWriteConfig):
    self._connection_params = connection_params
    self._write_config = write_config
    self._client = None

  def write(self, documents):
    """Writes a batch of documents to the Milvus collection.

    Performs an upsert operation to insert new documents or update existing
    ones based on primary key. After the upsert, flushes the collection to
    ensure data persistence.

    Args:
      documents: List of dictionaries representing Milvus records to write.
        Each dictionary should contain fields matching the collection
        schema.
    """
    if not self._client:
      self._client = MilvusClient(
          **unpack_dataclass_with_kwargs(self._connection_params))

    try:
      resp = self._client.upsert(
          collection_name=self._write_config.collection_name,
          partition_name=self._write_config.partition_name,
          data=documents,
          timeout=self._write_config.timeout,
          **self._write_config.kwargs)

      # Try to flush, but handle connection issues gracefully.
      try:
        self._client.flush(self._write_config.collection_name)
      except Exception as e:
        # If flush fails due to connection issues, log but don't fail the
        # write.
        _LOGGER.warning(
            "Flush operation failed, but upsert was successful: %s", e)

      _LOGGER.debug(
          "Upserted into Milvus: upsert_count=%d, cost=%d",
          resp.get("upsert_count", 0),
          resp.get("cost", 0))
    except Exception as e:
      _LOGGER.error("Failed to write to Milvus: %s", e)
      raise

  def __enter__(self):
    """Enters the context manager and establishes Milvus connection.

    Returns:
      Self, enabling use in 'with' statements.
    """
    if not self._client:
      connection_params = unpack_dataclass_with_kwargs(
          self._connection_params)

      # Extract retry parameters from connection_params.
      max_retries = connection_params.pop('max_retries', 3)
      retry_delay = connection_params.pop('retry_delay', 1.0)
      retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)

      def create_client():
        return MilvusClient(**connection_params)

      self._client = retry_with_backoff(
          create_client,
          max_retries=max_retries,
          retry_delay=retry_delay,
          retry_backoff_factor=retry_backoff_factor,
          operation_name="Milvus connection",
          exception_types=(MilvusException, ))
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Exits the context manager and closes the Milvus connection.

    Args:
      exc_type: Exception type if an exception was raised.
      exc_val: Exception value if an exception was raised.
      exc_tb: Exception traceback if an exception was raised.
    """
    _ = exc_type, exc_val, exc_tb  # Unused parameters
    if self._client:
      self._client.close()
```
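`__enter__` builds the client through `retry_with_backoff`, also imported from `apache_beam.ml.rag.utils`. A rough sketch of the contract implied by that call site (exponential backoff between attempts, re-raise after the final one) follows; the actual helper may differ:

```python
# Hedged sketch inferred from the call site above, not the real
# apache_beam.ml.rag.utils implementation.
import logging
import time

def retry_with_backoff(
    fn,
    max_retries=3,
    retry_delay=1.0,
    retry_backoff_factor=2.0,
    operation_name="operation",
    exception_types=(Exception, )):
  delay = retry_delay
  for attempt in range(1, max_retries + 1):
    try:
      return fn()
    except exception_types as e:
      if attempt == max_retries:
        raise  # Give up after the final attempt.
      logging.warning(
          "%s failed (attempt %d/%d): %s; retrying in %.1fs",
          operation_name, attempt, max_retries, e, delay)
      time.sleep(delay)
      delay *= retry_backoff_factor
```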