|
25 | 25 | from typing import Optional |
26 | 26 | from typing import Tuple |
27 | 27 | from typing import Union |
| 28 | +import uuid |
28 | 29 |
|
29 | 30 | from google.protobuf.json_format import MessageToDict |
30 | 31 | from pymilvus import AnnSearchRequest |
|
35 | 36 |
|
36 | 37 | from apache_beam.ml.rag.types import Chunk |
37 | 38 | from apache_beam.ml.rag.types import Embedding |
| 39 | +from apache_beam.ml.rag.utils import MilvusHelpers, MilvusConnectionParameters |
38 | 40 | from apache_beam.transforms.enrichment import EnrichmentSourceHandler |
39 | 41 |
|
40 | 42 |
|
@@ -104,44 +106,6 @@ def __str__(self): |
104 | 106 | return self.dict().__str__() |
105 | 107 |
|
106 | 108 |
|
107 | | -@dataclass |
108 | | -class MilvusConnectionParameters: |
109 | | - """Parameters for establishing connections to Milvus servers. |
110 | | -
|
111 | | - Args: |
112 | | - uri: URI endpoint for connecting to Milvus server in the format |
113 | | - "http(s)://hostname:port". |
114 | | - user: Username for authentication. Required if authentication is enabled and |
115 | | - not using token authentication. |
116 | | - password: Password for authentication. Required if authentication is enabled |
117 | | - and not using token authentication. |
118 | | - db_id: Database ID to connect to. Specifies which Milvus database to use. |
119 | | - Defaults to 'default'. |
120 | | - token: Authentication token as an alternative to username/password. |
121 | | - timeout: Connection timeout in seconds. Uses client default if None. |
122 | | - max_retries: Maximum number of connection retry attempts. Defaults to 3. |
123 | | - retry_delay: Initial delay between retries in seconds. Defaults to 1.0. |
124 | | - retry_backoff_factor: Multiplier for retry delay after each attempt. |
125 | | - Defaults to 2.0 (exponential backoff). |
126 | | - kwargs: Optional keyword arguments for additional connection parameters. |
127 | | - Enables forward compatibility. |
128 | | - """ |
129 | | - uri: str |
130 | | - user: str = field(default_factory=str) |
131 | | - password: str = field(default_factory=str) |
132 | | - db_id: str = "default" |
133 | | - token: str = field(default_factory=str) |
134 | | - timeout: Optional[float] = None |
135 | | - max_retries: int = 3 |
136 | | - retry_delay: float = 1.0 |
137 | | - retry_backoff_factor: float = 2.0 |
138 | | - kwargs: Dict[str, Any] = field(default_factory=dict) |
139 | | - |
140 | | - def __post_init__(self): |
141 | | - if not self.uri: |
142 | | - raise ValueError("URI must be provided for Milvus connection") |
143 | | - |
144 | | - |
145 | 109 | @dataclass |
146 | 110 | class BaseSearchParameters: |
147 | 111 | """Base parameters for both vector and keyword search operations. |
@@ -361,15 +325,15 @@ def __init__( |
361 | 325 | **kwargs): |
362 | 326 | """ |
363 | 327 | Example Usage: |
364 | | - connection_paramters = MilvusConnectionParameters( |
| 328 | + connection_parameters = MilvusConnectionParameters( |
365 | 329 | uri="http://localhost:19530") |
366 | 330 | search_parameters = MilvusSearchParameters( |
367 | 331 | collection_name="my_collection", |
368 | 332 | search_strategy=VectorSearchParameters(anns_field="embedding")) |
369 | 333 | collection_load_parameters = MilvusCollectionLoadParameters( |
370 | 334 | load_fields=["embedding", "metadata"]), |
371 | 335 | milvus_handler = MilvusSearchEnrichmentHandler( |
372 | | - connection_paramters, |
| 336 | + connection_parameters, |
373 | 337 | search_parameters, |
374 | 338 | collection_load_parameters=collection_load_parameters, |
375 | 339 | min_batch_size=10, |
@@ -534,10 +498,7 @@ def _get_keyword_search_data(self, chunk: Chunk): |
534 | 498 | raise ValueError( |
535 | 499 | f"Chunk {chunk.id} missing both text content and sparse embedding " |
536 | 500 | "required for keyword search") |
537 | | - |
538 | | - sparse_embedding = self.convert_sparse_embedding_to_milvus_format( |
539 | | - chunk.sparse_embedding) |
540 | | - |
| 501 | + sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding) |
541 | 502 | return chunk.content.text or sparse_embedding |
542 | 503 |
|
543 | 504 | def _get_call_response( |
|
0 commit comments