Skip to content

Commit f29c481

Browse files
sdks/python: remove the examples from this PR
1 parent dbfd5f5 commit f29c481

File tree

2 files changed

+60
-85
lines changed

2 files changed

+60
-85
lines changed

sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -116,75 +116,3 @@ def enrichment_with_vertex_ai_legacy():
116116
| "Enrich W/ Vertex AI" >> Enrichment(vertex_ai_handler)
117117
| "Print" >> beam.Map(print))
118118
# [END enrichment_with_vertex_ai_legacy]
119-
120-
121-
def enrichment_with_milvus():
  """Example pipeline: enrich one query chunk with a Milvus vector search.

  Connection settings and the collection name are read from the
  ``MILVUS_VECTOR_DB_*`` environment variables. Each enriched element is
  printed.
  """
  # [START enrichment_with_milvus]
  import os
  import apache_beam as beam
  from apache_beam.ml.rag.types import Content
  from apache_beam.ml.rag.types import Chunk
  from apache_beam.ml.rag.types import Embedding
  from apache_beam.transforms.enrichment import Enrichment
  from apache_beam.ml.rag.enrichment.milvus_search import (
      MilvusSearchEnrichmentHandler,
      MilvusConnectionParameters,
      MilvusSearchParameters,
      MilvusCollectionLoadParameters,
      VectorSearchParameters,
      VectorSearchMetrics)

  # Connection details and the target collection come from the environment.
  read_env = os.environ.get
  uri = read_env("MILVUS_VECTOR_DB_URI")
  user = read_env("MILVUS_VECTOR_DB_USER")
  password = read_env("MILVUS_VECTOR_DB_PASSWORD")
  db_id = read_env("MILVUS_VECTOR_DB_ID")
  token = read_env("MILVUS_VECTOR_DB_TOKEN")
  collection_name = read_env("MILVUS_VECTOR_DB_COLLECTION_NAME")

  connection_parameters = MilvusConnectionParameters(
      uri, user, password, db_id, token)

  # A single query chunk carrying only a dense embedding.
  query_chunk = Chunk(
      id="query1",
      embedding=Embedding(dense_embedding=[0.1, 0.2, 0.3]),
      content=Content())
  data = [query_chunk]

  # The first condition (language == "en") excludes documents in other
  # languages, which initially leaves two documents. The second condition
  # (cost < 50) then narrows the search results down to the first document.
  filter_expr = 'metadata["language"] == "en" AND cost < 50'

  vector_search_params = VectorSearchParameters(
      anns_field="dense_embedding_cosine",
      limit=3,
      filter=filter_expr,
      search_params={
          "metric_type": VectorSearchMetrics.COSINE.value,
          "nprobe": 1,
      })

  search_parameters = MilvusSearchParameters(
      collection_name=collection_name,
      search_strategy=vector_search_params,
      output_fields=["id", "content", "domain", "cost", "metadata"],
      round_decimal=2)

  # MilvusCollectionLoadParameters is optional and gives fine-grained control
  # over how collections are loaded into memory. For simple use cases, or when
  # getting started, it can be omitted to use the default loading behavior.
  # Consider using it in resource-constrained environments to optimize memory
  # usage and query performance.
  collection_load_parameters = MilvusCollectionLoadParameters()

  milvus_search_handler = MilvusSearchEnrichmentHandler(
      connection_parameters=connection_parameters,
      search_parameters=search_parameters,
      collection_load_parameters=collection_load_parameters)

  with beam.Pipeline() as p:
    queries = p | "Create" >> beam.Create(data)
    enriched = queries | "Enrich W/ Milvus" >> Enrichment(milvus_search_handler)
    _ = enriched | "Print" >> beam.Map(print)
  # [END enrichment_with_milvus]
Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# coding=utf-8
12
#
23
# Licensed to the Apache Software Foundation (ASF) under one or more
34
# contributor license agreements. See the NOTICE file distributed with
@@ -14,28 +15,74 @@
1415
# See the License for the specific language governing permissions and
1516
# limitations under the License.
1617
#
18+
# pytype: skip-file
19+
# pylint: disable=line-too-long
1720

18-
import logging
1921
import unittest
22+
from io import StringIO
2023

21-
import apache_beam as beam
24+
import mock
2225

23-
# pylint: disable=ungrouped-imports
26+
# pylint: disable=unused-import
2427
try:
25-
from apache_beam.transforms.enrichment import cross_join
28+
from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_bigtable, \
29+
enrichment_with_vertex_ai_legacy
30+
from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_vertex_ai
31+
from apache_beam.io.requestresponse import RequestResponseIO
2632
except ImportError:
27-
raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
33+
raise unittest.SkipTest('RequestResponseIO dependencies are not installed')
2834

2935

30-
class TestEnrichmentTransform(unittest.TestCase):
  """Unit tests for the enrichment join helper."""
  def test_cross_join(self):
    """cross_join of two dicts sharing 'id' equals one combined beam.Row."""
    joined = cross_join(
        {'id': 1, 'key': 'city'},
        {'id': 1, 'value': 'durham'},
    )
    self.assertEqual(beam.Row(id=1, key='city', value='durham'), joined)
36+
def validate_enrichment_with_bigtable():
  """Return the output lines expected from the Bigtable enrichment snippet.

  The sample output is wrapped in [START]/[END] marker lines; those two
  lines are dropped so only the ``Row(...)`` lines are returned.
  """
  sample_output = '''[START enrichment_with_bigtable]
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
Row(sale_id=3, customer_id=3, product_id=2, quantity=3, product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'})
Row(sale_id=5, customer_id=5, product_id=4, quantity=2, product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'})
[END enrichment_with_bigtable]'''
  all_lines = sample_output.splitlines()
  # Strip the leading [START ...] and trailing [END ...] marker lines.
  return all_lines[1:-1]
43+
44+
45+
def validate_enrichment_with_vertex_ai():
  """Return the output lines expected from the Vertex AI enrichment snippet.

  The sample output is wrapped in [START]/[END] marker lines; those two
  lines are dropped so only the ``Row(...)`` lines are returned.
  """
  sample_output = '''[START enrichment_with_vertex_ai]
Row(user_id='2963', product_id=14235, sale_price=15.0, age=12.0, state='1', gender='1', country='1')
Row(user_id='21422', product_id=11203, sale_price=12.0, age=12.0, state='0', gender='0', country='0')
Row(user_id='20592', product_id=8579, sale_price=9.0, age=12.0, state='2', gender='1', country='2')
[END enrichment_with_vertex_ai]'''
  all_lines = sample_output.splitlines()
  # Strip the leading [START ...] and trailing [END ...] marker lines.
  return all_lines[1:-1]
52+
53+
54+
def validate_enrichment_with_vertex_ai_legacy():
  """Return the output lines expected from the legacy Vertex AI snippet.

  The sample output is wrapped in [START]/[END] marker lines; those two
  lines are dropped so only the ``Row(...)`` lines are returned.
  """
  sample_output = '''[START enrichment_with_vertex_ai_legacy]
Row(entity_id='movie_01', title='The Shawshank Redemption', genres='Drama')
Row(entity_id='movie_02', title='The Shining', genres='Horror')
Row(entity_id='movie_04', title='The Dark Knight', genres='Action')
[END enrichment_with_vertex_ai_legacy]'''
  all_lines = sample_output.splitlines()
  # Strip the leading [START ...] and trailing [END ...] marker lines.
  return all_lines[1:-1]
61+
62+
63+
@mock.patch('sys.stdout', new_callable=StringIO)
class EnrichmentTest(unittest.TestCase):
  """Runs each enrichment snippet and checks what it prints.

  The class-level patch swaps ``sys.stdout`` for a ``StringIO`` in every test
  method and passes that buffer in as ``mock_stdout``, so the printed rows
  can be compared with the expected sample output.
  """
  def test_enrichment_with_bigtable(self, mock_stdout):
    """Printed Bigtable rows must match the expected lines exactly."""
    enrichment_with_bigtable()
    output = mock_stdout.getvalue().splitlines()
    expected = validate_enrichment_with_bigtable()
    self.assertEqual(output, expected)

  def test_enrichment_with_vertex_ai(self, mock_stdout):
    """Printed Vertex AI rows must match line by line, ignoring piece order."""
    enrichment_with_vertex_ai()
    output = mock_stdout.getvalue().splitlines()
    expected = validate_enrichment_with_vertex_ai()

    # Check the line count first: otherwise extra printed lines would pass
    # unnoticed and missing lines would surface as an IndexError rather than
    # as an assertion failure.
    self.assertEqual(len(output), len(expected))

    for actual_line, expected_line in zip(output, expected):
      # Compare the comma-separated pieces as sets so the order of pieces
      # within a line does not affect the result.
      self.assertEqual(
          set(actual_line.split(',')), set(expected_line.split(',')))

  def test_enrichment_with_vertex_ai_legacy(self, mock_stdout):
    """Printed legacy Vertex AI rows must match the expected lines exactly."""
    enrichment_with_vertex_ai_legacy()
    output = mock_stdout.getvalue().splitlines()
    expected = validate_enrichment_with_vertex_ai_legacy()
    # Show the full diff on failure instead of a truncated one.
    self.maxDiff = None
    self.assertEqual(output, expected)
3785

3886

3987
if __name__ == '__main__':
40-
logging.getLogger().setLevel(logging.INFO)
4188
unittest.main()

0 commit comments

Comments
 (0)