Skip to content

Commit 0ee038c

Browse files
sdks/python: fix linting issues
1 parent 732ae31 commit 0ee038c

File tree

3 files changed

+93
-85
lines changed

3 files changed

+93
-85
lines changed

sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,9 @@
5252
ConnectionConfig,
5353
CloudSQLConnectionConfig,
5454
ExternalSQLDBConnectionConfig)
55-
from apache_beam.ml.rag.enrichment.milvus_search import (
56-
MilvusConnectionParameters)
57-
from apache_beam.ml.rag.enrichment.milvus_search_it_test import (
58-
MilvusEnrichmentTestHelper, MilvusDBContainerInfo)
55+
from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
56+
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
57+
from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
5958
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
6059
from apache_beam.ml.rag.utils import parse_chunk_strings
6160
from apache_beam.io.requestresponse import RequestResponseIO
@@ -261,7 +260,7 @@ def sql_test_context(is_cloudsql: bool, db_adapter: DatabaseTypeAdapter):
261260
@staticmethod
262261
@contextmanager
263262
def milvus_test_context():
264-
db: Optional[MilvusDBContainerInfo] = None
263+
db: Optional[VectorDBContainerInfo] = None
265264
try:
266265
db = EnrichmentTestHelpers.pre_milvus_enrichment()
267266
yield
@@ -374,16 +373,16 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct):
374373
os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None)
375374

376375
@staticmethod
377-
def pre_milvus_enrichment() -> MilvusDBContainerInfo:
376+
def pre_milvus_enrichment() -> VectorDBContainerInfo:
378377
try:
379-
db = MilvusEnrichmentTestHelper.start_db_container()
378+
db = MilvusTestHelpers.start_db_container()
380379
connection_params = MilvusConnectionParameters(
381380
uri=db.uri,
382381
user=db.user,
383382
password=db.password,
384383
db_id=db.id,
385384
token=db.token)
386-
collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
385+
collection_name = MilvusTestHelpers.initialize_db_with_data(
387386
connection_params)
388387
except Exception as e:
389388
raise TestContainerStartupError(
@@ -401,9 +400,9 @@ def pre_milvus_enrichment() -> MilvusDBContainerInfo:
401400
return db
402401

403402
@staticmethod
404-
def post_milvus_enrichment(db: MilvusDBContainerInfo):
403+
def post_milvus_enrichment(db: VectorDBContainerInfo):
405404
try:
406-
MilvusEnrichmentTestHelper.stop_db_container(db)
405+
MilvusTestHelpers.stop_db_container(db)
407406
except Exception as e:
408407
raise TestContainerTeardownError(
409408
f"Milvus container failed to tear down: {str(e)}")

sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py

Lines changed: 5 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,6 @@
4747
from pymilvus.exceptions import MilvusException
4848
from pymilvus.milvus_client import IndexParams
4949

50-
from apache_beam.transforms.enrichment import Enrichment
51-
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
52-
from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
5350
from apache_beam.ml.rag.enrichment.milvus_search import HybridSearchParameters
5451
from apache_beam.ml.rag.enrichment.milvus_search import KeywordSearchMetrics
5552
from apache_beam.ml.rag.enrichment.milvus_search import KeywordSearchParameters
@@ -59,6 +56,9 @@
5956
from apache_beam.ml.rag.enrichment.milvus_search import MilvusSearchParameters
6057
from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchMetrics
6158
from apache_beam.ml.rag.enrichment.milvus_search import VectorSearchParameters
59+
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
60+
from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
61+
from apache_beam.transforms.enrichment import Enrichment
6262
except ImportError as e:
6363
raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}')
6464

@@ -235,77 +235,6 @@ def __getitem__(self, key):
235235
}
236236

237237

238-
def initialize_db_with_data(connc_params: MilvusConnectionParameters):
239-
# Open the connection to the milvus db with retry.
240-
def create_client():
241-
return MilvusClient(**connc_params.__dict__)
242-
243-
client = retry_with_backoff(
244-
create_client,
245-
max_retries=3,
246-
retry_delay=1.0,
247-
operation_name="Test Milvus client connection",
248-
exception_types=(MilvusException, ))
249-
250-
# Configure schema.
251-
field_schemas: List[FieldSchema] = cast(
252-
List[FieldSchema], MILVUS_IT_CONFIG["fields"])
253-
schema = CollectionSchema(
254-
fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"])
255-
256-
# Create collection with the schema.
257-
collection_name = MILVUS_IT_CONFIG["collection_name"]
258-
index_function: Callable[[], IndexParams] = cast(
259-
Callable[[], IndexParams], MILVUS_IT_CONFIG["index"])
260-
client.create_collection(
261-
collection_name=collection_name,
262-
schema=schema,
263-
index_params=index_function())
264-
265-
# Assert that collection was created.
266-
collection_error = f"Expected collection '{collection_name}' to be created."
267-
assert client.has_collection(collection_name), collection_error
268-
269-
# Gather all fields we have excluding 'sparse_embedding_bm25' special field.
270-
fields = list(map(lambda field: field.name, field_schemas))
271-
272-
# Prep data for indexing. Currently we can't insert sparse vectors for BM25
273-
# sparse embedding field as it would be automatically generated by Milvus
274-
# through the registered BM25 function.
275-
data_ready_to_index = []
276-
for doc in MILVUS_IT_CONFIG["corpus"]:
277-
item = {}
278-
for field in fields:
279-
if field.startswith("dense_embedding"):
280-
item[field] = doc["dense_embedding"]
281-
elif field == "sparse_embedding_inner_product":
282-
item[field] = doc["sparse_embedding"]
283-
elif field == "sparse_embedding_bm25":
284-
# It is automatically generated by Milvus from the content field.
285-
continue
286-
else:
287-
item[field] = doc[field]
288-
data_ready_to_index.append(item)
289-
290-
# Index data.
291-
result = client.insert(
292-
collection_name=collection_name, data=data_ready_to_index)
293-
294-
# Assert that the intended data has been properly indexed.
295-
insertion_err = f'failed to insert the {result["insert_count"]} data points'
296-
assert result["insert_count"] == len(data_ready_to_index), insertion_err
297-
298-
# Release the collection from memory. It will be loaded lazily when the
299-
# enrichment handler is invoked.
300-
client.release_collection(collection_name)
301-
302-
# Close the connection to the Milvus database, as no further preparation
303-
# operations are needed before executing the enrichment handler.
304-
client.close()
305-
306-
return collection_name
307-
308-
309238
@pytest.mark.require_docker_in_docker
310239
@unittest.skipUnless(
311240
platform.system() == "Linux",
@@ -329,7 +258,8 @@ def setUpClass(cls):
329258
db_name=cls._db.id,
330259
token=cls._db.token)
331260
cls._collection_load_params = MilvusCollectionLoadParameters()
332-
cls._collection_name = initialize_db_with_data(cls._connection_params)
261+
cls._collection_name = initialize_db_with_data(
262+
cls._connection_params, MILVUS_IT_CONFIG)
333263

334264
@classmethod
335265
def tearDownClass(cls):

sdks/python/apache_beam/ml/rag/test_utils.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,24 @@
2121
import socket
2222
import tempfile
2323
from dataclasses import dataclass
24+
from typing import Callable
2425
from typing import List
2526
from typing import Optional
27+
from typing import cast
2628

2729
import yaml
2830
from testcontainers.core.config import testcontainers_config
2931
from testcontainers.core.generic import DbContainer
3032
from testcontainers.milvus import MilvusContainer
33+
from pymilvus import CollectionSchema
34+
from pymilvus import FieldSchema
35+
from pymilvus import MilvusClient
36+
from pymilvus.exceptions import MilvusException
37+
from pymilvus.milvus_client import IndexParams
3138

3239
from apache_beam.ml.rag.types import Chunk
40+
from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
41+
from apache_beam.ml.rag.utils import retry_with_backoff
3342

3443
_LOGGER = logging.getLogger(__name__)
3544

@@ -180,6 +189,76 @@ def stop_db_container(db_info: VectorDBContainerInfo):
180189
db_info.container.stop()
181190
_LOGGER.info("milvus db container stopped successfully.")
182191

192+
def initialize_db_with_data(
193+
connc_params: MilvusConnectionParameters, config: dict):
194+
# Open the connection to the milvus db with retry.
195+
def create_client():
196+
return MilvusClient(**connc_params.__dict__)
197+
198+
client = retry_with_backoff(
199+
create_client,
200+
max_retries=3,
201+
retry_delay=1.0,
202+
operation_name="Test Milvus client connection",
203+
exception_types=(MilvusException, ))
204+
205+
# Configure schema.
206+
field_schemas: List[FieldSchema] = cast(List[FieldSchema], config["fields"])
207+
schema = CollectionSchema(
208+
fields=field_schemas, functions=config["functions"])
209+
210+
# Create collection with the schema.
211+
collection_name = config["collection_name"]
212+
index_function: Callable[[], IndexParams] = cast(
213+
Callable[[], IndexParams], config["index"])
214+
client.create_collection(
215+
collection_name=collection_name,
216+
schema=schema,
217+
index_params=index_function())
218+
219+
# Assert that collection was created.
220+
collection_error = f"Expected collection '{collection_name}' to be created."
221+
assert client.has_collection(collection_name), collection_error
222+
223+
# Gather all fields we have excluding 'sparse_embedding_bm25' special field.
224+
fields = list(map(lambda field: field.name, field_schemas))
225+
226+
# Prep data for indexing. Currently we can't insert sparse vectors for BM25
227+
# sparse embedding field as it would be automatically generated by Milvus
228+
# through the registered BM25 function.
229+
data_ready_to_index = []
230+
for doc in config["corpus"]:
231+
item = {}
232+
for field in fields:
233+
if field.startswith("dense_embedding"):
234+
item[field] = doc["dense_embedding"]
235+
elif field == "sparse_embedding_inner_product":
236+
item[field] = doc["sparse_embedding"]
237+
elif field == "sparse_embedding_bm25":
238+
# It is automatically generated by Milvus from the content field.
239+
continue
240+
else:
241+
item[field] = doc[field]
242+
data_ready_to_index.append(item)
243+
244+
# Index data.
245+
result = client.insert(
246+
collection_name=collection_name, data=data_ready_to_index)
247+
248+
# Assert that the intended data has been properly indexed.
249+
insertion_err = f'failed to insert the {result["insert_count"]} data points'
250+
assert result["insert_count"] == len(data_ready_to_index), insertion_err
251+
252+
# Release the collection from memory. It will be loaded lazily when the
253+
# enrichment handler is invoked.
254+
client.release_collection(collection_name)
255+
256+
# Close the connection to the Milvus database, as no further preparation
257+
# operations are needed before executing the enrichment handler.
258+
client.close()
259+
260+
return collection_name
261+
183262
@staticmethod
184263
@contextlib.contextmanager
185264
def create_user_yaml(service_port: int, max_vector_field_num=5):

0 commit comments

Comments
 (0)