3535
3636from apache_beam .ml .rag .types import Chunk
3737from apache_beam .ml .rag .types import Embedding
38+ from apache_beam .ml .rag .utils import (
39+ MilvusHelpers ,
40+ MilvusConnectionConfig ,
41+ unpack_dataclass_with_kwargs )
3842from apache_beam .transforms .enrichment import EnrichmentSourceHandler
3943
4044
@@ -104,37 +108,6 @@ def __str__(self):
104108 return self .dict ().__str__ ()
105109
106110
107- @dataclass
108- class MilvusConnectionParameters :
109- """Parameters for establishing connections to Milvus servers.
110-
111- Args:
112- uri: URI endpoint for connecting to Milvus server in the format
113- "http(s)://hostname:port".
114- user: Username for authentication. Required if authentication is enabled and
115- not using token authentication.
116- password: Password for authentication. Required if authentication is enabled
117- and not using token authentication.
118- db_id: Database ID to connect to. Specifies which Milvus database to use.
119- Defaults to 'default'.
120- token: Authentication token as an alternative to username/password.
121- timeout: Connection timeout in seconds. Uses client default if None.
122- kwargs: Optional keyword arguments for additional connection parameters.
123- Enables forward compatibility.
124- """
125- uri : str
126- user : str = field (default_factory = str )
127- password : str = field (default_factory = str )
128- db_id : str = "default"
129- token : str = field (default_factory = str )
130- timeout : Optional [float ] = None
131- kwargs : Dict [str , Any ] = field (default_factory = dict )
132-
133- def __post_init__ (self ):
134- if not self .uri :
135- raise ValueError ("URI must be provided for Milvus connection" )
136-
137-
138111@dataclass
139112class BaseSearchParameters :
140113 """Base parameters for both vector and keyword search operations.
@@ -345,7 +318,7 @@ class MilvusSearchEnrichmentHandler(EnrichmentSourceHandler[InputT, OutputT]):
345318 """
346319 def __init__ (
347320 self ,
348- connection_parameters : MilvusConnectionParameters ,
321+ connection_parameters : MilvusConnectionConfig ,
349322 search_parameters : MilvusSearchParameters ,
350323 * ,
351324 collection_load_parameters : Optional [MilvusCollectionLoadParameters ],
@@ -354,7 +327,7 @@ def __init__(
354327 ** kwargs ):
355328 """
356329 Example Usage:
357- connection_paramters = MilvusConnectionParameters (
330+ connection_paramters = MilvusConnectionConfig (
358331 uri="http://localhost:19530")
359332 search_parameters = MilvusSearchParameters(
360333 collection_name="my_collection",
@@ -369,7 +342,7 @@ def __init__(
369342 max_batch_size=100)
370343
371344 Args:
372- connection_parameters (MilvusConnectionParameters ): Configuration for
345+ connection_parameters (MilvusConnectionConfig ): Configuration for
373346 connecting to the Milvus server, including URI, credentials, and
374347 connection options.
375348 search_parameters (MilvusSearchParameters): Configuration for search
@@ -493,8 +466,7 @@ def _get_keyword_search_data(self, chunk: Chunk):
493466 f"Chunk { chunk .id } missing both text content and sparse embedding "
494467 "required for keyword search" )
495468
496- sparse_embedding = self .convert_sparse_embedding_to_milvus_format (
497- chunk .sparse_embedding )
469+ sparse_embedding = MilvusHelpers .sparse_embedding (chunk .sparse_embedding )
498470
499471 return chunk .content .text or sparse_embedding
500472
@@ -533,15 +505,6 @@ def _normalize_milvus_value(self, value: Any):
533505 # Keep other types as they are.
534506 return value
535507
536- def convert_sparse_embedding_to_milvus_format (
537- self , sparse_vector : Tuple [List [int ], List [float ]]) -> Dict [int , float ]:
538- if not sparse_vector :
539- return None
540- # Converts sparse embedding from (indices, values) tuple format to
541- # Milvus-compatible values dict format {dimension_index: value, ...}.
542- indices , values = sparse_vector
543- return {int (idx ): float (val ) for idx , val in zip (indices , values )}
544-
545508 @property
546509 def collection_name (self ):
547510 """Getter method for collection_name property"""
@@ -585,15 +548,3 @@ def batch_elements_kwargs(self) -> Dict[str, int]:
585548def join_fn (left : Embedding , right : Dict [str , Any ]) -> Embedding :
586549 left .metadata ['enrichment_data' ] = right
587550 return left
588-
589-
590- def unpack_dataclass_with_kwargs (dataclass_instance ):
591- # Create a copy of the dataclass's __dict__.
592- params_dict : dict = dataclass_instance .__dict__ .copy ()
593-
594- # Extract the nested kwargs dictionary.
595- nested_kwargs = params_dict .pop ('kwargs' , {})
596-
597- # Merge the dictionaries, with nested_kwargs taking precedence
598- # in case of duplicate keys.
599- return {** params_dict , ** nested_kwargs }
0 commit comments