|
| 1 | +# |
| 2 | +# Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | +# contributor license agreements. See the NOTICE file distributed with |
| 4 | +# this work for additional information regarding copyright ownership. |
| 5 | +# The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | +# (the "License"); you may not use this file except in compliance with |
| 7 | +# the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# |
| 17 | +import logging |
| 18 | +import unittest |
| 19 | +from unittest.mock import MagicMock |
| 20 | +import pytest |
| 21 | +import apache_beam as beam |
| 22 | +from apache_beam.coders import coders |
| 23 | +from apache_beam.testing.test_pipeline import TestPipeline |
| 24 | +from apache_beam.testing.util import BeamAssertException |
| 25 | +from apache_beam.transforms.enrichment import Enrichment |
| 26 | +from apache_beam.transforms.enrichment_handlers.cloudsql import ( |
| 27 | + CloudSQLEnrichmentHandler, |
| 28 | + DatabaseTypeAdapter, |
| 29 | + ExceptionLevel, |
| 30 | +) |
| 31 | +from testcontainers.redis import RedisContainer |
| 32 | +from google.cloud.sql.connector import Connector |
| 33 | +import os |
| 34 | + |
| 35 | +_LOGGER = logging.getLogger(__name__) |
| 36 | + |
| 37 | + |
def _row_key_fn(request: "beam.Row", key_id: str = "product_id") -> tuple[str, str]:
  """Build a ``(column_name, value)`` key pair for CloudSQL enrichment lookups.

  Args:
    request: input element; must expose an attribute named ``key_id``.
    key_id: name of the attribute/column used as the lookup key.

  Returns:
    A 2-tuple ``(key_id, str(value))``. (The original annotation said
    ``tuple[str]`` — a 1-tuple type — which did not match the returned value.)

  Raises:
    AttributeError: if ``request`` has no attribute named ``key_id``.
  """
  key_value = str(getattr(request, key_id))
  return (key_id, key_value)
| 41 | + |
| 42 | + |
class ValidateResponse(beam.DoFn):
  """Asserts that every enriched `beam.Row` carries the expected fields.

  Raises `BeamAssertException` when the field count is wrong, when any
  required field is missing/None, or when an enriched column is absent.
  """
  def __init__(
      self,
      n_fields: int,
      fields: list[str],
      enriched_fields: dict[str, list[str]],
  ):
    self.n_fields = n_fields
    self._fields = fields
    self._enriched_fields = enriched_fields

  def process(self, element: beam.Row, *args, **kwargs):
    row = element.as_dict()

    # Field-count check first: a wrong width means enrichment went wrong.
    if self.n_fields != len(row):
      raise BeamAssertException(
          "Expected %d fields in enriched PCollection:" % self.n_fields)

    # Every required field must be present and non-None. `dict.get`
    # collapses the "missing" and "explicitly None" cases, matching the
    # original `not in` / `is None` pair.
    for name in self._fields:
      if row.get(name) is None:
        raise BeamAssertException(f"Expected a not None field: {name}")

    # Enriched columns only need to exist; their values may be anything.
    for key in self._enriched_fields:
      if key not in row:
        raise BeamAssertException(
            f"Response from Cloud SQL should contain {key} column.")
| 70 | + |
| 71 | + |
def create_rows(cursor):
  """Create the ``products`` table (if needed) and seed it with test rows.

  Idempotent: the DDL uses ``IF NOT EXISTS`` and the insert uses
  ``ON CONFLICT DO NOTHING``, so repeated runs leave the table unchanged.
  """
  statements = (
      """
      CREATE TABLE IF NOT EXISTS products (
        product_id SERIAL PRIMARY KEY,
        product_name VARCHAR(255),
        product_stock INT
      )
      """,
      """
      INSERT INTO products (product_name, product_stock)
      VALUES
      ('pixel 5', 2),
      ('pixel 6', 4),
      ('pixel 7', 20),
      ('pixel 8', 10),
      ('iphone 11', 3),
      ('iphone 12', 7),
      ('iphone 13', 8),
      ('iphone 14', 3)
      ON CONFLICT DO NOTHING
      """,
  )
  for statement in statements:
    cursor.execute(statement)
| 96 | + |
| 97 | + |
@pytest.mark.uses_testcontainer
class TestCloudSQLEnrichment(unittest.TestCase):
  """Integration tests for `Enrichment` backed by a GCP Cloud SQL instance.

  Connects to the shared ``apache-beam-testing`` Postgres instance; database
  credentials are read from the ``BEAM_TEST_CLOUDSQL_PG_USER`` and
  ``BEAM_TEST_CLOUDSQL_PG_PASSWORD`` environment variables. The Redis cache
  test additionally starts a local Redis testcontainer.
  """
  @classmethod
  def setUpClass(cls):
    # Coordinates of the shared test Cloud SQL (Postgres) instance.
    cls.project_id = "apache-beam-testing"
    cls.region_id = "us-central1"
    cls.instance_id = "beam-test"
    cls.database_id = "postgres"
    cls.database_user = os.getenv("BEAM_TEST_CLOUDSQL_PG_USER")
    cls.database_password = os.getenv("BEAM_TEST_CLOUDSQL_PG_PASSWORD")
    cls.table_id = "products"
    cls.row_key = "product_id"
    cls.database_type_adapter = DatabaseTypeAdapter.POSTGRESQL
    # Requests whose product_id values all exist in the seeded table.
    cls.req = [
        beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
        beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
        beam.Row(sale_id=5, customer_id=5, product_id=3, quantity=2),
        beam.Row(sale_id=7, customer_id=7, product_id=4, quantity=1),
    ]
    cls.connector = Connector()
    cls.client = cls.connector.connect(
        f"{cls.project_id}:{cls.region_id}:{cls.instance_id}",
        driver=cls.database_type_adapter.value,
        db=cls.database_id,
        user=cls.database_user,
        password=cls.database_password,
    )
    cls.cursor = cls.client.cursor()
    create_rows(cls.cursor)
    cls.cache_client_retries = 3

  @classmethod
  def tearDownClass(cls):
    cls.cursor.close()
    cls.client.close()
    cls.connector.close()
    # Drop references so accidental reuse after teardown fails loudly.
    cls.cursor, cls.client, cls.connector = None, None, None

  def _make_handler(self, **overrides):
    """Build a CloudSQLEnrichmentHandler from the shared test config.

    Keyword overrides replace or extend the defaults (e.g. ``row_key=...``,
    ``row_key_fn=...``, ``table_id=...``, ``exception_level=...``). This
    removes the six near-identical kwargs blocks the tests otherwise repeat.
    """
    kwargs = dict(
        region_id=self.region_id,
        project_id=self.project_id,
        instance_id=self.instance_id,
        database_type_adapter=self.database_type_adapter,
        database_id=self.database_id,
        database_user=self.database_user,
        database_password=self.database_password,
        table_id=self.table_id,
    )
    kwargs.update(overrides)
    return CloudSQLEnrichmentHandler(**kwargs)

  def _start_cache_container(self):
    """Start a Redis testcontainer, retrying on flaky container startup."""
    for attempt in range(self.cache_client_retries):
      try:
        self.container = RedisContainer(image="redis:7.2.4")
        self.container.start()
        self.host = self.container.get_container_host_ip()
        self.port = self.container.get_exposed_port(6379)
        self.cache_client = self.container.get_client()
        break
      except Exception:
        if attempt == self.cache_client_retries - 1:
          # Lazy %-args: the message is only formatted if actually logged.
          _LOGGER.error(
              "Unable to start redis container for RRIO tests after "
              "%d retries.",
              self.cache_client_retries)
          # Bare `raise` preserves the original traceback.
          raise

  def test_enrichment_with_cloudsql(self):
    """Happy path: every request row gains the product columns."""
    expected_fields = [
        "sale_id",
        "customer_id",
        "product_id",
        "quantity",
        "product_name",
        "product_stock",
    ]
    expected_enriched_fields = ["product_id", "product_name", "product_stock"]
    cloudsql = self._make_handler(row_key=self.row_key)
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields,
              )))

  def test_enrichment_with_cloudsql_no_enrichment(self):
    """A key with no matching row passes through without added columns."""
    expected_fields = ["sale_id", "customer_id", "product_id", "quantity"]
    expected_enriched_fields = {}
    cloudsql = self._make_handler(row_key=self.row_key)
    # product_id=99 does not exist in the seeded table.
    req = [beam.Row(sale_id=1, customer_id=1, product_id=99, quantity=1)]
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields,
              )))

  def test_enrichment_with_cloudsql_raises_key_error(self):
    """Raises KeyError when row_key names a column absent from requests."""
    cloudsql = self._make_handler(row_key="car_name")
    with self.assertRaises(KeyError):
      test_pipeline = TestPipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()

  def test_enrichment_with_cloudsql_raises_not_found(self):
    """Raises a database error when the GCP Cloud SQL table doesn't exist."""
    table_id = "invalid_table"
    cloudsql = self._make_handler(table_id=table_id, row_key=self.row_key)
    try:
      test_pipeline = beam.Pipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()
    except Exception as e:
      # NOTE(review): the original code caught `PgDatabaseError`, a name
      # never imported in this module, so the except clause itself raised
      # NameError. The driver-specific DatabaseError cannot be imported
      # here without a hard dependency on the Postgres driver, so catch
      # broadly and verify the database error message instead.
      self.assertIn(f'relation "{table_id}" does not exist', str(e))

  def test_enrichment_with_cloudsql_exception_level(self):
    """Raises `ValueError` when the GCP Cloud SQL query returns an empty row
    and the handler's exception level is RAISE."""
    cloudsql = self._make_handler(
        row_key=self.row_key, exception_level=ExceptionLevel.RAISE)
    # product_id=11 does not exist in the seeded table.
    req = [beam.Row(sale_id=1, customer_id=1, product_id=11, quantity=1)]
    with self.assertRaises(ValueError):
      test_pipeline = beam.Pipeline()
      _ = (
          test_pipeline
          | "Create" >> beam.Create(req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql))
      res = test_pipeline.run()
      res.wait_until_finish()

  def test_cloudsql_enrichment_with_lambda(self):
    """Same happy path, but the lookup key comes from a custom row_key_fn."""
    expected_fields = [
        "sale_id",
        "customer_id",
        "product_id",
        "quantity",
        "product_name",
        "product_stock",
    ]
    expected_enriched_fields = ["product_id", "product_name", "product_stock"]
    cloudsql = self._make_handler(row_key_fn=_row_key_fn)
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL" >> Enrichment(cloudsql)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields)))

  @pytest.fixture
  def cache_container(self):
    # Setup phase: start the container.
    self._start_cache_container()

    # Hand control to the test.
    yield

    # Cleanup phase: stop the container. It runs after the test completion
    # even if it failed.
    self.container.stop()
    self.container = None

  @pytest.mark.usefixtures("cache_container")
  def test_cloudsql_enrichment_with_redis(self):
    """First pipeline populates the Redis cache; the second must be served
    entirely from it (verified by mocking out the database handler)."""
    expected_fields = [
        "sale_id",
        "customer_id",
        "product_id",
        "quantity",
        "product_name",
        "product_stock",
    ]
    expected_enriched_fields = ["product_id", "product_name", "product_stock"]
    cloudsql = self._make_handler(row_key_fn=_row_key_fn)
    with TestPipeline(is_integration_test=True) as test_pipeline:
      _ = (
          test_pipeline
          | "Create1" >> beam.Create(self.req)
          | "Enrich W/ CloudSQL1" >> Enrichment(cloudsql).with_redis_cache(
              self.host, self.port, 300)
          | "Validate Response" >> beam.ParDo(
              ValidateResponse(
                  len(expected_fields),
                  expected_fields,
                  expected_enriched_fields,
              )))

    # Manually check cache entries to verify they were correctly stored.
    c = coders.StrUtf8Coder()
    for req in self.req:
      key = cloudsql.get_cache_key(req)
      response = self.cache_client.get(c.encode(key))
      if not response:
        raise ValueError("No cache entry found for %s" % key)

    # Mock the CloudSQL handler to avoid actual database calls. If any
    # element misses the cache, the mock's mismatched response will fail
    # ValidateResponse — proving the second run was served from Redis.
    actual = CloudSQLEnrichmentHandler.__call__
    CloudSQLEnrichmentHandler.__call__ = MagicMock(
        return_value=(
            beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
            beam.Row(),
        ))
    try:
      # Run a second pipeline to verify cache is being used.
      with TestPipeline(is_integration_test=True) as test_pipeline:
        _ = (
            test_pipeline
            | "Create2" >> beam.Create(self.req)
            | "Enrich W/ CloudSQL2" >> Enrichment(cloudsql).with_redis_cache(
                self.host, self.port)
            | "Validate Response" >> beam.ParDo(
                ValidateResponse(
                    len(expected_fields),
                    expected_fields,
                    expected_enriched_fields)))
    finally:
      # Always restore the real handler, even if the pipeline fails;
      # otherwise the mock leaks into every later test in the process.
      CloudSQLEnrichmentHandler.__call__ = actual
| 393 | + |
| 394 | + |
# Allow running this test module directly as a script.
if __name__ == "__main__":
  unittest.main()
0 commit comments