Skip to content

Commit 6cf8952

Browse files
sdks/python: itest CloudSQLEnrichmentHandler
1 parent 95ec739 commit 6cf8952

File tree

1 file changed

+396
-0
lines changed

1 file changed

+396
-0
lines changed
Lines changed: 396 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,396 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
import logging
18+
import unittest
19+
from unittest.mock import MagicMock
20+
import pytest
21+
import apache_beam as beam
22+
from apache_beam.coders import coders
23+
from apache_beam.testing.test_pipeline import TestPipeline
24+
from apache_beam.testing.util import BeamAssertException
25+
from apache_beam.transforms.enrichment import Enrichment
26+
from apache_beam.transforms.enrichment_handlers.cloudsql import (
27+
CloudSQLEnrichmentHandler,
28+
DatabaseTypeAdapter,
29+
ExceptionLevel,
30+
)
31+
from testcontainers.redis import RedisContainer
32+
from google.cloud.sql.connector import Connector
33+
import os
34+
35+
_LOGGER = logging.getLogger(__name__)
36+
37+
38+
def _row_key_fn(request: beam.Row, key_id="product_id") -> tuple[str, str]:
  """Build the lookup key for an enrichment request.

  Args:
    request: Input row; it must expose an attribute named ``key_id``.
    key_id: Name of the attribute to read from ``request``.

  Returns:
    A ``(key_id, value)`` pair where ``value`` is the attribute converted
    to ``str``.

  Raises:
    AttributeError: If ``request`` has no attribute named ``key_id``.
  """
  return key_id, str(getattr(request, key_id))
41+
42+
43+
class ValidateResponse(beam.DoFn):
  """Validates that each `beam.Row` of a PCollection has the required fields.

  `process` raises `BeamAssertException` on the first violation and emits
  nothing, so this DoFn is meant to sit at the end of a test pipeline.
  """
  def __init__(
      self,
      n_fields: int,
      fields: list[str],
      enriched_fields: dict[str, list[str]],
  ):
    """Stores the expectations checked against every element.

    Args:
      n_fields: Exact number of fields each element must contain.
      fields: Names of fields that must be present and not ``None``.
      enriched_fields: Names of fields that must be present in the element.
        It is only iterated for names, so a list of names or a dict keyed
        by name both work.
    """
    # NOTE(review): `enriched_fields` is annotated as a dict, but the tests in
    # this file pass plain lists (or an empty dict); only iteration is used.
    self.n_fields = n_fields
    self._fields = fields
    self._enriched_fields = enriched_fields

  def process(self, element: beam.Row, *args, **kwargs):
    """Checks a single element against the stored expectations.

    Args:
      element: The enriched row to validate.

    Raises:
      BeamAssertException: If the field count differs from ``n_fields``, a
        required field is missing or ``None``, or an enriched field is
        missing.
    """
    element_dict = element.as_dict()
    if len(element_dict) != self.n_fields:
      # Report what was actually received so a failure can be diagnosed
      # straight from the test log.
      raise BeamAssertException(
          "Expected %d fields in enriched PCollection: got %d (%s)" %
          (self.n_fields, len(element_dict), list(element_dict)))

    for field in self._fields:
      # A missing key and an explicit None are both treated as failures.
      if element_dict.get(field) is None:
        raise BeamAssertException(f"Expected a not None field: {field}")

    for key in self._enriched_fields:
      if key not in element_dict:
        raise BeamAssertException(
            f"Response from Cloud SQL should contain {key} column.")
70+
71+
72+
def create_rows(cursor):
  """Create the `products` table if needed and seed it with test rows.

  The seed rows carry explicit `product_id` values. Without them the
  `SERIAL` column would hand out a fresh id on every call, the
  `ON CONFLICT` clause could never match, and each test run would append
  another eight rows. Ids that the tests rely on being absent would then
  eventually exist. With explicit ids the function is idempotent.

  Args:
    cursor: Object exposing ``execute(sql)``; in this file it is the cursor
      of the Cloud SQL connection opened in ``setUpClass``. This function
      does not commit; that is left to the owner of the connection.
  """
  cursor.execute(
      """
      CREATE TABLE IF NOT EXISTS products (
        product_id SERIAL PRIMARY KEY,
        product_name VARCHAR(255),
        product_stock INT
      )
      """)
  cursor.execute(
      """
      INSERT INTO products (product_id, product_name, product_stock)
      VALUES
        (1, 'pixel 5', 2),
        (2, 'pixel 6', 4),
        (3, 'pixel 7', 20),
        (4, 'pixel 8', 10),
        (5, 'iphone 11', 3),
        (6, 'iphone 12', 7),
        (7, 'iphone 13', 8),
        (8, 'iphone 14', 3)
      ON CONFLICT (product_id) DO NOTHING
      """)
96+
97+
98+
@pytest.mark.uses_testcontainer
class TestCloudSQLEnrichment(unittest.TestCase):
  """Integration tests for `CloudSQLEnrichmentHandler` on PostgreSQL.

  `setUpClass` connects to the `beam-test` Cloud SQL instance of the
  `apache-beam-testing` project, seeds the `products` table and commits it.
  Credentials are read from the `BEAM_TEST_CLOUDSQL_PG_USER` and
  `BEAM_TEST_CLOUDSQL_PG_PASSWORD` environment variables; the whole class
  is skipped when either is missing.
  """

  # Fields of a fully enriched row: the request fields plus the table columns.
  _EXPECTED_FIELDS = [
      "sale_id",
      "customer_id",
      "product_id",
      "quantity",
      "product_name",
      "product_stock",
  ]
  # Columns contributed by the `products` table.
  _EXPECTED_ENRICHED_FIELDS = ["product_id", "product_name", "product_stock"]

  @classmethod
  def setUpClass(cls):
    """Opens the Cloud SQL connection and seeds the test table once."""
    cls.project_id = "apache-beam-testing"
    cls.region_id = "us-central1"
    cls.instance_id = "beam-test"
    cls.database_id = "postgres"
    cls.database_user = os.getenv("BEAM_TEST_CLOUDSQL_PG_USER")
    cls.database_password = os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD")
    cls.table_id = "products"
    cls.row_key = "product_id"
    cls.database_type_adapter = DatabaseTypeAdapter.POSTGRESQL
    cls.cache_client_retries = 3
    cls.req = [
        beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
        beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
        beam.Row(sale_id=5, customer_id=5, product_id=3, quantity=2),
        beam.Row(sale_id=7, customer_id=7, product_id=4, quantity=1),
    ]
    cls.connector = cls.client = cls.cursor = None

    # Skip with a clear reason instead of failing later with an opaque
    # connection error when the credentials are not configured.
    if not cls.database_user or not cls.database_password:
      raise unittest.SkipTest(
          "BEAM_TEST_CLOUDSQL_PG_USER and BEAM_TEST_CLOUDSQL_PG_PASSWORD "
          "must be set to run the Cloud SQL enrichment tests.")

    cls.connector = Connector()
    try:
      cls.client = cls.connector.connect(
          f"{cls.project_id}:{cls.region_id}:{cls.instance_id}",
          driver=cls.database_type_adapter.value,
          db=cls.database_id,
          user=cls.database_user,
          password=cls.database_password,
      )
      cls.cursor = cls.client.cursor()
      create_rows(cls.cursor)
      # Commit so the seeded rows are visible to the separate connections
      # opened by the enrichment handler inside the pipelines.
      cls.client.commit()
    except Exception:
      # tearDownClass is not invoked when setUpClass fails, so release
      # whatever was already opened before propagating the error.
      cls.tearDownClass()
      raise

  @classmethod
  def tearDownClass(cls):
    """Closes the cursor, connection and connector, if they were opened."""
    for resource in (cls.cursor, cls.client, cls.connector):
      if resource is not None:
        resource.close()
    cls.cursor, cls.client, cls.connector = None, None, None

  def _make_handler(self, **overrides):
    """Returns a handler for the test instance; `overrides` replace or
    extend the default keyword arguments (e.g. `row_key`, `table_id`)."""
    kwargs = dict(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
    )
    kwargs.update(overrides)
    return CloudSQLEnrichmentHandler(**kwargs)

  def _start_cache_container(self):
    """Starts a Redis container, retrying up to `cache_client_retries` times.

    On success sets `self.container`, `self.host`, `self.port` and
    `self.cache_client`. Re-raises the last error when every attempt fails.
    """
    for i in range(self.cache_client_retries):
      try:
        self.container = RedisContainer(image="redis:7.2.4")
        self.container.start()
        self.host = self.container.get_container_host_ip()
        self.port = self.container.get_exposed_port(6379)
        self.cache_client = self.container.get_client()
        break
      except Exception:
        if i == self.cache_client_retries - 1:
          _LOGGER.error(
              "Unable to start redis container for RRIO tests after %d "
              "retries.",
              self.cache_client_retries)
          # Bare raise keeps the original traceback intact.
          raise
        _LOGGER.warning(
            "Starting redis container failed (attempt %d of %d); retrying.",
            i + 1,
            self.cache_client_retries,
            exc_info=True)

  def test_enrichment_with_cloudsql(self):
    """Rows are enriched when the handler is keyed by `row_key`."""
    cloudsql = self._make_handler(row_key=self.row_key)
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(self._EXPECTED_FIELDS),
                  self._EXPECTED_FIELDS,
                  self._EXPECTED_ENRICHED_FIELDS,
              )))

  def test_enrichment_with_cloudsql_no_enrichment(self):
    """A request whose key matches no row passes through with its own
    fields only."""
    expected_fields = ["sale_id", "customer_id", "product_id", "quantity"]
    expected_enriched_fields = {}
    cloudsql = self._make_handler(row_key=self.row_key)
    req = [beam.Row(sale_id=1, customer_id=1, product_id=99, quantity=1)]
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields,
              )))

  def test_enrichment_with_cloudsql_raises_key_error(self):
    """A `row_key` that is not a field of the request raises `KeyError`."""
    cloudsql = self._make_handler(row_key="car_name")
    with self.assertRaises(KeyError):
      test_pipeline = TestPipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()

  def test_enrichment_with_cloudsql_raises_not_found(self):
    """Raises a database error when the GCP Cloud SQL table doesn't exist."""
    table_id = "invalid_table"
    cloudsql = self._make_handler(table_id=table_id, row_key=self.row_key)
    # The driver's error class is not imported in this module and the runner
    # may wrap it, so match on the message rather than on a concrete type.
    # assertRaises also makes the test fail when nothing is raised at all.
    with self.assertRaises(Exception) as ctx:
      test_pipeline = beam.Pipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()
    self.assertIn(
        f'relation "{table_id}" does not exist', str(ctx.exception))

  def test_enrichment_with_cloudsql_exception_level(self):
    """raises a `ValueError` exception when the GCP Cloud SQL query returns
    an empty row."""
    cloudsql = self._make_handler(
        row_key=self.row_key, exception_level=ExceptionLevel.RAISE)
    req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)]
    with self.assertRaises(ValueError):
      test_pipeline = beam.Pipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()

  def test_cloudsql_enrichment_with_lambda(self):
    """Rows are enriched when the key is produced by `row_key_fn`."""
    cloudsql = self._make_handler(row_key_fn=_row_key_fn)
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(self._EXPECTED_FIELDS),
                  self._EXPECTED_FIELDS,
                  self._EXPECTED_ENRICHED_FIELDS)))

  @pytest.fixture
  def cache_container(self):
    """Provides a running Redis container for the duration of one test."""
    # Setup phase: start the container.
    self._start_cache_container()

    # Hand control to the test.
    yield

    # Cleanup phase: stop the container. It runs after the test completion
    # even if it failed.
    self.container.stop()
    self.container = None

  @pytest.mark.usefixtures("cache_container")
  def test_cloudsql_enrichment_with_redis(self):
    """Responses are written to Redis and served from it on a second run."""
    cloudsql = self._make_handler(row_key_fn=_row_key_fn)
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create1" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL1" >> Enrichment(cloudsql).with_redis_cache(
              self.host, self.port, 300)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(self._EXPECTED_FIELDS),
                  self._EXPECTED_FIELDS,
                  self._EXPECTED_ENRICHED_FIELDS,
              )))

    # Manually check cache entry to verify entries were correctly stored.
    c = coders.StrUtf8Coder()
    for req in self.req:
      key = cloudsql.get_cache_key(req)
      response = self.cache_client.get(c.encode(key))
      if not response:
        raise ValueError("No cache entry found for %s" % key)

    # Mock the CloudSQL handler to avoid actual database calls.
    # This simulates a cache hit scenario by returning predefined data.
    actual = CloudSQLEnrichmentHandler.__call__
    CloudSQLEnrichmentHandler.__call__ = MagicMock(
        return_value=(
            beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
            beam.Row(),
        ))
    try:
      # Run a second pipeline to verify cache is being used.
      with TestPipeline(is_integration_test=True) as test_pipeline:
        _ = (
            test_pipeline
            | "Create2" >> beam.Create(self.req)
            | "Enrich W/ CloudSQL2" >> Enrichment(cloudsql).with_redis_cache(
                self.host, self.port)
            | "Validate Response" >> beam.ParDo(
                ValidateResponse(
                    len(self._EXPECTED_FIELDS),
                    self._EXPECTED_FIELDS,
                    self._EXPECTED_ENRICHED_FIELDS)))
    finally:
      # Always restore the real method so a failure here cannot leak the
      # mock into the other tests of this process.
      CloudSQLEnrichmentHandler.__call__ = actual
393+
394+
395+
# Entry point for running this module directly as a script.
# NOTE(review): the Redis test depends on a pytest fixture, so pytest is
# presumably the intended runner; confirm before relying on this path.
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)