Skip to content

Commit bb84826

Browse files
sdks/python: add missing utils
1 parent 57c093e commit bb84826

File tree

2 files changed

+404
-0
lines changed

2 files changed

+404
-0
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
import contextlib
2+
from dataclasses import dataclass
3+
import os
4+
import socket
5+
import tempfile
6+
import logging
7+
from typing import Dict, List, Optional
8+
from testcontainers.core.config import MAX_TRIES as TC_MAX_TRIES
9+
from testcontainers.core.config import testcontainers_config
10+
from testcontainers.core.generic import DbContainer
11+
from testcontainers.milvus import MilvusContainer
12+
import yaml
13+
14+
from apache_beam.ml.rag.types import Chunk
15+
16+
_LOGGER = logging.getLogger(__name__)
17+
18+
@dataclass
19+
class VectorDBContainerInfo:
20+
"""Container information for vector database test instances.
21+
22+
Holds connection details and container reference for testing with
23+
vector databases like Milvus in containerized environments.
24+
"""
25+
container: DbContainer
26+
host: str
27+
port: int
28+
user: str = ""
29+
password: str = ""
30+
token: str = ""
31+
db_id: str = "default"
32+
33+
@property
34+
def uri(self) -> str:
35+
return f"http://{self.host}:{self.port}"
36+
37+
class TestHelpers:
38+
@staticmethod
39+
def find_free_port():
40+
"""Find a free port on the local machine."""
41+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
42+
# Bind to port 0, which asks OS to assign a free port.
43+
s.bind(('', 0))
44+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
45+
# Return the port number assigned by OS.
46+
return s.getsockname()[1]
47+
48+
class CustomMilvusContainer(MilvusContainer):
49+
"""Custom Milvus container with configurable ports and environment setup.
50+
51+
Extends MilvusContainer to provide custom port binding and environment
52+
configuration for testing with standalone Milvus instances.
53+
"""
54+
def __init__(
55+
self,
56+
image: str,
57+
service_container_port,
58+
healthcheck_container_port,
59+
**kwargs,
60+
) -> None:
61+
# Skip the parent class's constructor and go straight to
62+
# GenericContainer.
63+
super(MilvusContainer, self).__init__(image=image, **kwargs)
64+
self.port = service_container_port
65+
self.healthcheck_port = healthcheck_container_port
66+
self.with_exposed_ports(service_container_port, healthcheck_container_port)
67+
68+
# Get free host ports.
69+
service_host_port = TestHelpers.find_free_port()
70+
healthcheck_host_port = TestHelpers.find_free_port()
71+
72+
# Bind container and host ports.
73+
self.with_bind_ports(service_container_port, service_host_port)
74+
self.with_bind_ports(healthcheck_container_port, healthcheck_host_port)
75+
self.cmd = "milvus run standalone"
76+
77+
# Set environment variables needed for Milvus.
78+
envs = {
79+
"ETCD_USE_EMBED": "true",
80+
"ETCD_DATA_DIR": "/var/lib/milvus/etcd",
81+
"COMMON_STORAGETYPE": "local",
82+
"METRICS_PORT": str(healthcheck_container_port)
83+
}
84+
for env, value in envs.items():
85+
self.with_env(env, value)
86+
87+
88+
class MilvusTestHelpers:
89+
"""Helper utilities for testing Milvus vector database operations.
90+
91+
Provides static methods for managing test containers, configuration files,
92+
and chunk comparison utilities for Milvus-based integration tests.
93+
"""
94+
@staticmethod
95+
def start_db_container(
96+
image="milvusdb/milvus:latest",
97+
max_vec_fields=5,
98+
vector_client_max_retries=3,
99+
tc_max_retries=TC_MAX_TRIES) -> Optional[VectorDBContainerInfo]:
100+
service_container_port = TestHelpers.find_free_port()
101+
healthcheck_container_port = TestHelpers.find_free_port()
102+
user_yaml_creator = MilvusTestHelpers.create_user_yaml
103+
with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
104+
info = None
105+
testcontainers_config.max_tries = tc_max_retries
106+
for i in range(vector_client_max_retries):
107+
try:
108+
vector_db_container = CustomMilvusContainer(
109+
image=image,
110+
service_container_port=service_container_port,
111+
healthcheck_container_port=healthcheck_container_port)
112+
vector_db_container = vector_db_container.with_volume_mapping(
113+
cfg, "/milvus/configs/user.yaml")
114+
vector_db_container.start()
115+
host = vector_db_container.get_container_host_ip()
116+
port = vector_db_container.get_exposed_port(service_container_port)
117+
info = VectorDBContainerInfo(vector_db_container, host, port)
118+
testcontainers_config.max_tries = TC_MAX_TRIES
119+
_LOGGER.info(
120+
"milvus db container started successfully on %s.", info.uri)
121+
break
122+
except Exception as e:
123+
stdout_logs, stderr_logs = vector_db_container.get_logs()
124+
stdout_logs = stdout_logs.decode("utf-8")
125+
stderr_logs = stderr_logs.decode("utf-8")
126+
_LOGGER.warning(
127+
"Retry %d/%d: Failed to start Milvus DB container. Reason: %s. "
128+
"STDOUT logs:\n%s\nSTDERR logs:\n%s",
129+
i + 1,
130+
vector_client_max_retries,
131+
e,
132+
stdout_logs,
133+
stderr_logs)
134+
if i == vector_client_max_retries - 1:
135+
_LOGGER.error(
136+
"Unable to start milvus db container for I/O tests after %d "
137+
"retries. Tests cannot proceed. STDOUT logs:\n%s\n"
138+
"STDERR logs:\n%s",
139+
vector_client_max_retries,
140+
stdout_logs,
141+
stderr_logs)
142+
raise e
143+
return info
144+
145+
@staticmethod
146+
def stop_db_container(db_info: VectorDBContainerInfo):
147+
if db_info is None:
148+
_LOGGER.warning("Milvus db info is None. Skipping stop operation.")
149+
return
150+
try:
151+
_LOGGER.debug("Stopping milvus db container.")
152+
db_info.container.stop()
153+
_LOGGER.info("milvus db container stopped successfully.")
154+
except Exception as e:
155+
_LOGGER.warning(
156+
"Error encountered while stopping milvus db container: %s", e)
157+
158+
@staticmethod
159+
@contextlib.contextmanager
160+
def create_user_yaml(service_port: int, max_vector_field_num=5):
161+
"""Creates a temporary user.yaml file for Milvus configuration.
162+
163+
This user yaml file overrides Milvus default configurations. It sets
164+
the Milvus service port to the specified container service port. It
165+
overrides the default max vector field number which is 4 as yet with the
166+
user-defined one.
167+
168+
Args:
169+
service_port: Port number for the Milvus service.
170+
max_vector_field_num: Max number of vec fields allowed per collection.
171+
172+
Yields:
173+
str: Path to the created temporary yaml file.
174+
"""
175+
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
176+
delete=False) as temp_file:
177+
# Define the content for user.yaml.
178+
user_config = {
179+
'proxy': {
180+
'maxVectorFieldNum': max_vector_field_num, 'port': service_port
181+
}
182+
}
183+
184+
# Write the content to the file.
185+
yaml.dump(user_config, temp_file, default_flow_style=False)
186+
path = temp_file.name
187+
188+
try:
189+
yield path
190+
finally:
191+
if os.path.exists(path):
192+
os.remove(path)
193+
194+
@staticmethod
195+
def assert_chunks_equivalent(
196+
actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
197+
"""assert_chunks_equivalent checks for presence rather than exact match"""
198+
# Sort both lists by ID to ensure consistent ordering.
199+
actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
200+
expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
201+
202+
actual_len = len(actual_sorted)
203+
expected_len = len(expected_sorted)
204+
err_msg = (
205+
f"Different number of chunks, actual: {actual_len}, "
206+
f"expected: {expected_len}")
207+
assert actual_len == expected_len, err_msg
208+
209+
for actual, expected in zip(actual_sorted, expected_sorted):
210+
# Assert that IDs match.
211+
assert actual.id == expected.id
212+
213+
# Assert that dense embeddings match.
214+
err_msg = f"Dense embedding mismatch for chunk {actual.id}"
215+
assert actual.dense_embedding == expected.dense_embedding, err_msg
216+
217+
# Assert that sparse embeddings match.
218+
err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
219+
assert actual.sparse_embedding == expected.sparse_embedding, err_msg
220+
221+
# Assert that text content match.
222+
err_msg = f"Text Content mismatch for chunk {actual.id}"
223+
assert actual.content.text == expected.content.text, err_msg
224+
225+
# For enrichment_data, be more flexible.
226+
# If "expected" has values for enrichment_data but actual doesn't, that's
227+
# acceptable since vector search results can vary based on many factors
228+
# including implementation details, vector database state, and slight
229+
# variations in similarity calculations.
230+
231+
# First ensure the enrichment data key exists.
232+
err_msg = f"Missing enrichment_data key in chunk {actual.id}"
233+
assert 'enrichment_data' in actual.metadata, err_msg
234+
235+
# For enrichment_data, ensure consistent ordering of results.
236+
actual_data = actual.metadata['enrichment_data']
237+
expected_data = expected.metadata['enrichment_data']
238+
239+
# If actual has enrichment data, then perform detailed validation.
240+
if actual_data:
241+
# Ensure the id key exist.
242+
err_msg = f"Missing id key in metadata {actual.id}"
243+
assert 'id' in actual_data, err_msg
244+
245+
# Validate IDs have consistent ordering.
246+
actual_ids = sorted(actual_data['id'])
247+
expected_ids = sorted(expected_data['id'])
248+
err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}"
249+
assert actual_ids == expected_ids, err_msg
250+
251+
# Ensure the distance key exist.
252+
err_msg = f"Missing distance key in metadata {actual.id}"
253+
assert 'distance' in actual_data, err_msg
254+
255+
# Validate distances exist and have same length as IDs.
256+
actual_distances = actual_data['distance']
257+
expected_distances = expected_data['distance']
258+
err_msg = (
259+
"Number of distances doesn't match number of IDs for "
260+
f"chunk {actual.id}")
261+
assert len(actual_distances) == len(expected_distances), err_msg
262+
263+
# Ensure the fields key exist.
264+
err_msg = f"Missing fields key in metadata {actual.id}"
265+
assert 'fields' in actual_data, err_msg
266+
267+
# Validate fields have consistent content.
268+
# Sort fields by 'id' to ensure consistent ordering.
269+
actual_fields_sorted = sorted(
270+
actual_data['fields'], key=lambda f: f.get('id', 0))
271+
expected_fields_sorted = sorted(
272+
expected_data['fields'], key=lambda f: f.get('id', 0))
273+
274+
# Compare field IDs.
275+
actual_field_ids = [f.get('id') for f in actual_fields_sorted]
276+
expected_field_ids = [f.get('id') for f in expected_fields_sorted]
277+
err_msg = f"Field IDs don't match for chunk {actual.id}"
278+
assert actual_field_ids == expected_field_ids, err_msg
279+
280+
# Compare field content.
281+
for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted):
282+
# Ensure the id key exist.
283+
err_msg = f"Missing id key in metadata.fields {actual.id}"
284+
assert 'id' in a_f, err_msg
285+
286+
err_msg = f"Field ID mismatch chunk {actual.id}"
287+
assert a_f['id'] == e_f['id'], err_msg
288+
289+
# Validate field metadata.
290+
err_msg = f"Field Metadata doesn't match for chunk {actual.id}"
291+
assert a_f['metadata'] == e_f['metadata'], err_msg

0 commit comments

Comments
 (0)