Skip to content

Commit 07f6162

Browse files
sdks/python: finalize milvus sink i/o connector
1 parent 0f949cf commit 07f6162

File tree

7 files changed

+1043
-512
lines changed

7 files changed

+1043
-512
lines changed

sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py

Lines changed: 8 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535

3636
from apache_beam.ml.rag.types import Chunk
3737
from apache_beam.ml.rag.types import Embedding
38+
from apache_beam.ml.rag.utils import (
39+
MilvusHelpers,
40+
MilvusConnectionConfig,
41+
unpack_dataclass_with_kwargs)
3842
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
3943

4044

@@ -104,37 +108,6 @@ def __str__(self):
104108
return self.dict().__str__()
105109

106110

107-
@dataclass
108-
class MilvusConnectionParameters:
109-
"""Parameters for establishing connections to Milvus servers.
110-
111-
Args:
112-
uri: URI endpoint for connecting to Milvus server in the format
113-
"http(s)://hostname:port".
114-
user: Username for authentication. Required if authentication is enabled and
115-
not using token authentication.
116-
password: Password for authentication. Required if authentication is enabled
117-
and not using token authentication.
118-
db_id: Database ID to connect to. Specifies which Milvus database to use.
119-
Defaults to 'default'.
120-
token: Authentication token as an alternative to username/password.
121-
timeout: Connection timeout in seconds. Uses client default if None.
122-
kwargs: Optional keyword arguments for additional connection parameters.
123-
Enables forward compatibility.
124-
"""
125-
uri: str
126-
user: str = field(default_factory=str)
127-
password: str = field(default_factory=str)
128-
db_id: str = "default"
129-
token: str = field(default_factory=str)
130-
timeout: Optional[float] = None
131-
kwargs: Dict[str, Any] = field(default_factory=dict)
132-
133-
def __post_init__(self):
134-
if not self.uri:
135-
raise ValueError("URI must be provided for Milvus connection")
136-
137-
138111
@dataclass
139112
class BaseSearchParameters:
140113
"""Base parameters for both vector and keyword search operations.
@@ -345,7 +318,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
345318
"""
346319
def __init__(
347320
self,
348-
connection_parameters: MilvusConnectionParameters,
321+
connection_parameters: MilvusConnectionConfig,
349322
search_parameters: MilvusSearchParameters,
350323
*,
351324
collection_load_parameters: Optional[MilvusCollectionLoadParameters],
@@ -354,7 +327,7 @@ def __init__(
354327
**kwargs):
355328
"""
356329
Example Usage:
357-
connection_parameters = MilvusConnectionParameters(
330+
connection_parameters = MilvusConnectionConfig(
358331
uri="http://localhost:19530")
359332
search_parameters = MilvusSearchParameters(
360333
collection_name="my_collection",
@@ -369,7 +342,7 @@ def __init__(
369342
max_batch_size=100)
370343
371344
Args:
372-
connection_parameters (MilvusConnectionParameters): Configuration for
345+
connection_parameters (MilvusConnectionConfig): Configuration for
373346
connecting to the Milvus server, including URI, credentials, and
374347
connection options.
375348
search_parameters (MilvusSearchParameters): Configuration for search
@@ -493,8 +466,7 @@ def _get_keyword_search_data(self, chunk: Chunk):
493466
f"Chunk {chunk.id} missing both text content and sparse embedding "
494467
"required for keyword search")
495468

496-
sparse_embedding = self.convert_sparse_embedding_to_milvus_format(
497-
chunk.sparse_embedding)
469+
sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding)
498470

499471
return chunk.content.text or sparse_embedding
500472

@@ -533,15 +505,6 @@ def _normalize_milvus_value(self, value: Any):
533505
# Keep other types as they are.
534506
return value
535507

536-
def convert_sparse_embedding_to_milvus_format(
537-
self, sparse_vector: Tuple[List[int], List[float]]) -> Dict[int, float]:
538-
if not sparse_vector:
539-
return None
540-
# Converts sparse embedding from (indices, values) tuple format to
541-
# Milvus-compatible values dict format {dimension_index: value, ...}.
542-
indices, values = sparse_vector
543-
return {int(idx): float(val) for idx, val in zip(indices, values)}
544-
545508
@property
546509
def collection_name(self):
547510
"""Getter method for collection_name property"""
@@ -585,15 +548,3 @@ def batch_elements_kwargs(self) -> Dict[str, int]:
585548
def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding:
586549
left.metadata['enrichment_data'] = right
587550
return left
588-
589-
590-
def unpack_dataclass_with_kwargs(dataclass_instance):
591-
# Create a copy of the dataclass's __dict__.
592-
params_dict: dict = dataclass_instance.__dict__.copy()
593-
594-
# Extract the nested kwargs dictionary.
595-
nested_kwargs = params_dict.pop('kwargs', {})
596-
597-
# Merge the dictionaries, with nested_kwargs taking precedence
598-
# in case of duplicate keys.
599-
return {**params_dict, **nested_kwargs}

0 commit comments

Comments
 (0)