diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py index d71faa6d8477..bbf85329050b 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py @@ -24,24 +24,26 @@ def enrichment_with_bigtable(): # [START enrichment_with_bigtable] import apache_beam as beam from apache_beam.transforms.enrichment import Enrichment - from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.bigtable import ( + BigTableEnrichmentHandler, ) - project_id = 'apache-beam-testing' - instance_id = 'beam-test' - table_id = 'bigtable-enrichment-test' - row_key = 'product_id' + project_id = "apache-beam-testing" + instance_id = "beam-test" + table_id = "bigtable-enrichment-test" + row_key = "product_id" data = [ beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1), beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3), - beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2) + beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2), ] bigtable_handler = BigTableEnrichmentHandler( project_id=project_id, instance_id=instance_id, table_id=table_id, - row_key=row_key) + row_key=row_key, + ) with beam.Pipeline() as p: _ = ( p @@ -55,16 +57,16 @@ def enrichment_with_vertex_ai(): # [START enrichment_with_vertex_ai] import apache_beam as beam from apache_beam.transforms.enrichment import Enrichment - from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \ - import VertexAIFeatureStoreEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import ( + VertexAIFeatureStoreEnrichmentHandler, ) - project_id = 'apache-beam-testing' - location = 'us-central1' + project_id = "apache-beam-testing" + location = 
"us-central1" api_endpoint = f"{location}-aiplatform.googleapis.com" data = [ - beam.Row(user_id='2963', product_id=14235, sale_price=15.0), - beam.Row(user_id='21422', product_id=11203, sale_price=12.0), - beam.Row(user_id='20592', product_id=8579, sale_price=9.0), + beam.Row(user_id="2963", product_id=14235, sale_price=15.0), + beam.Row(user_id="21422", product_id=11203, sale_price=12.0), + beam.Row(user_id="20592", product_id=8579, sale_price=9.0), ] vertex_ai_handler = VertexAIFeatureStoreEnrichmentHandler( @@ -88,23 +90,23 @@ def enrichment_with_vertex_ai_legacy(): # [START enrichment_with_vertex_ai_legacy] import apache_beam as beam from apache_beam.transforms.enrichment import Enrichment - from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \ - import VertexAIFeatureStoreLegacyEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import ( + VertexAIFeatureStoreLegacyEnrichmentHandler, ) - project_id = 'apache-beam-testing' - location = 'us-central1' + project_id = "apache-beam-testing" + location = "us-central1" api_endpoint = f"{location}-aiplatform.googleapis.com" data = [ - beam.Row(entity_id="movie_01", title='The Shawshank Redemption'), + beam.Row(entity_id="movie_01", title="The Shawshank Redemption"), beam.Row(entity_id="movie_02", title="The Shining"), - beam.Row(entity_id="movie_04", title='The Dark Knight'), + beam.Row(entity_id="movie_04", title="The Dark Knight"), ] vertex_ai_handler = VertexAIFeatureStoreLegacyEnrichmentHandler( project=project_id, location=location, api_endpoint=api_endpoint, - entity_type_id='movies', + entity_type_id="movies", feature_store_id="movie_prediction_unique", feature_ids=["title", "genres"], row_key="entity_id", @@ -118,6 +120,121 @@ def enrichment_with_vertex_ai_legacy(): # [END enrichment_with_vertex_ai_legacy] +def enrichment_with_bigquery_storage_basic(): + # [START enrichment_with_bigquery_storage_basic] + import apache_beam as beam + from 
apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigquery_storage_read import ( + BigQueryStorageEnrichmentHandler, ) + + project_id = "apache-beam-testing" + dataset = "beam-test" + table_name = "bigquery-enrichment-test-products" + # Sample sales data to enrich + sales_data = [ + beam.Row(sale_id=1001, product_id=101, customer_id=501, quantity=2), + beam.Row(sale_id=1002, product_id=102, customer_id=502, quantity=1), + beam.Row(sale_id=1003, product_id=103, customer_id=503, quantity=5), + ] + + # Basic enrichment - enrich sales data with product information + handler = BigQueryStorageEnrichmentHandler( + project=project_id, + table_name=f"{project_id}.{dataset}.{table_name}", + row_restriction_template="id = {product_id}", + fields=["product_id"], + column_names=[ + "id as product_id", "product_name", "category", "unit_price" + ], + ) + + with beam.Pipeline() as p: + _ = ( + p + | "Create Sales Data" >> beam.Create(sales_data) + | "Enrich with Product Info" >> Enrichment(handler) + | "Print Results" >> beam.Map(print)) + # [END enrichment_with_bigquery_storage_basic] + + +def enrichment_with_bigquery_storage_custom_function(): + # [START enrichment_with_bigquery_storage_custom_function] + import apache_beam as beam + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigquery_storage_read import ( + BigQueryStorageEnrichmentHandler, ) + + project_id = "apache-beam-testing" + dataset = "beam-test" + table_name = "bigquery-enrichment-test-products" + # Advanced sales data with category and quantity + sales_data = [ + beam.Row( + sale_id=1001, + product_id=101, + category="Electronics", + customer_id=501, + quantity=2, + ), + beam.Row( + sale_id=1002, + product_id=102, + category="Electronics", + customer_id=502, + quantity=4, + ), + beam.Row( + sale_id=1003, + product_id=103, + category="Furniture", + customer_id=503, + quantity=5, + ), + beam.Row( + 
sale_id=1004, + product_id=101, + category="Electronics", + customer_id=504, + quantity=6, + ), + ] + + def build_row_restriction(condition_values, primary_keys, req_row): + # Only enrich if quantity > 2 and category is Electronics + if req_row.quantity > 2 and req_row.category == "Electronics": + return f'id = {req_row.product_id} AND category = "{req_row.category}"' + else: + return None # Skip enrichment for this row + + def extract_condition_values(req_row): + return { + "product_id": req_row.product_id, + "category": req_row.category, + "quantity": req_row.quantity, + } + + handler = BigQueryStorageEnrichmentHandler( + project=project_id, + table_name=f"{project_id}.{dataset}.{table_name}", + row_restriction_template_fn=build_row_restriction, + condition_value_fn=extract_condition_values, + column_names=[ + "id as prod_id", + "product_name as name", + "category", + "unit_price as price", + ], + ) + + with beam.Pipeline() as p: + _ = ( + p + | "Create Sales Data" >> beam.Create(sales_data) + | "Enrich with Product Info (Advanced)" >> Enrichment(handler) + | "Print Results" >> beam.Map(print)) + + + # [END enrichment_with_bigquery_storage_custom_function] def enrichment_with_google_cloudsql_pg(): # [START enrichment_with_google_cloudsql_pg] import apache_beam as beam diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index 904b90710225..be69e2609eac 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -17,6 +17,7 @@ # # pytype: skip-file # pylint: disable=line-too-long +# ruff: noqa: E501 import os import unittest @@ -29,12 +30,19 @@ import mock import pytest -from sqlalchemy.engine import Connection as DBAPIConnection # pylint: disable=unused-import try: - from sqlalchemy import ( - Column, 
Integer, VARCHAR, Engine, MetaData, create_engine) + from apache_beam.examples.snippets.transforms.elementwise.enrichment import ( + enrichment_with_bigtable, + enrichment_with_vertex_ai_legacy, + ) + from apache_beam.examples.snippets.transforms.elementwise.enrichment import ( + enrichment_with_vertex_ai, ) + from apache_beam.examples.snippets.transforms.elementwise.enrichment import ( + enrichment_with_bigquery_storage_basic, + enrichment_with_bigquery_storage_custom_function, + ) from apache_beam.examples.snippets.transforms.elementwise.enrichment import ( enrichment_with_bigtable, enrichment_with_vertex_ai_legacy) from apache_beam.examples.snippets.transforms.elementwise.enrichment import ( @@ -52,37 +60,73 @@ CloudSQLConnectionConfig, ExternalSQLDBConnectionConfig) from apache_beam.io.requestresponse import RequestResponseIO + from sqlalchemy.engine import Connection as DBAPIConnection + from sqlalchemy import ( + Column, Integer, VARCHAR, Engine, MetaData, create_engine) except ImportError as e: raise unittest.SkipTest(f'RequestResponseIO dependencies not installed: {e}') def validate_enrichment_with_bigtable(): - expected = '''[START enrichment_with_bigtable] -Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) -Row(sale_id=3, customer_id=3, product_id=2, quantity=3, product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'}) -Row(sale_id=5, customer_id=5, product_id=4, quantity=2, product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'}) - [END enrichment_with_bigtable]'''.splitlines()[1:-1] + expected = ( + """[START enrichment_with_bigtable] +Row(sale_id=1, customer_id=1, product_id=1, quantity=1, """ + """product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) +Row(sale_id=3, customer_id=3, product_id=2, quantity=3, """ + """product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'}) 
+Row(sale_id=5, customer_id=5, product_id=4, quantity=2, """ + """product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'}) + [END enrichment_with_bigtable]""").splitlines()[1:-1] return expected def validate_enrichment_with_vertex_ai(): - expected = '''[START enrichment_with_vertex_ai] -Row(user_id='2963', product_id=14235, sale_price=15.0, age=12.0, state='1', gender='1', country='1') -Row(user_id='21422', product_id=11203, sale_price=12.0, age=12.0, state='0', gender='0', country='0') -Row(user_id='20592', product_id=8579, sale_price=9.0, age=12.0, state='2', gender='1', country='2') - [END enrichment_with_vertex_ai]'''.splitlines()[1:-1] + expected = ( + """[START enrichment_with_vertex_ai] +Row(user_id='2963', product_id=14235, sale_price=15.0, """ + """age=12.0, state='1', gender='1', country='1') +Row(user_id='21422', product_id=11203, sale_price=12.0, """ + """age=12.0, state='0', gender='0', country='0') +Row(user_id='20592', product_id=8579, sale_price=9.0, """ + """age=12.0, state='2', gender='1', country='2') + [END enrichment_with_vertex_ai]""").splitlines()[1:-1] return expected def validate_enrichment_with_vertex_ai_legacy(): - expected = '''[START enrichment_with_vertex_ai_legacy] + expected = """[START enrichment_with_vertex_ai_legacy] Row(entity_id='movie_01', title='The Shawshank Redemption', genres='Drama') Row(entity_id='movie_02', title='The Shining', genres='Horror') Row(entity_id='movie_04', title='The Dark Knight', genres='Action') - [END enrichment_with_vertex_ai_legacy]'''.splitlines()[1:-1] + [END enrichment_with_vertex_ai_legacy]""".splitlines()[1:-1] + return expected + + +def validate_enrichment_with_bigquery_storage_basic(): + expected = ( + """[START enrichment_with_bigquery_storage_basic] +Row(sale_id=1001, product_id=101, customer_id=501, quantity=2, """ + """product_id=101, product_name='Laptop Pro', category='Electronics', unit_price=999.99) +Row(sale_id=1002, product_id=102, customer_id=502, quantity=1, """ 
+ """product_id=102, product_name='Wireless Mouse', category='Electronics', unit_price=29.99) +Row(sale_id=1003, product_id=103, customer_id=503, quantity=5, """ + """product_id=103, product_name='Office Chair', category='Furniture', unit_price=199.99) + [END enrichment_with_bigquery_storage_basic]""").splitlines()[1:-1] return expected +def validate_enrichment_with_bigquery_storage_custom_function(): + expected = ( + """[START enrichment_with_bigquery_storage_custom_function] +Row(sale_id=1002, product_id=102, category='Electronics', customer_id=502, """ + """quantity=4, prod_id=102, name='Wireless Mouse', category='Electronics', price=29.99) +Row(sale_id=1004, product_id=101, category='Electronics', customer_id=504, """ + """quantity=6, prod_id=101, name='Laptop Pro', category='Electronics', price=999.99) + [END enrichment_with_bigquery_storage_custom_function]""").splitlines()[1:-1] + return expected + + +@mock.patch("sys.stdout", new_callable=StringIO) def validate_enrichment_with_google_cloudsql_pg(): expected = '''[START enrichment_with_google_cloudsql_pg] Row(product_id=1, name='A', quantity=2, region_id=3) @@ -134,7 +178,7 @@ def test_enrichment_with_vertex_ai(self, mock_stdout): expected = sorted(validate_enrichment_with_vertex_ai()) for i in range(len(expected)): - self.assertEqual(set(output[i].split(',')), set(expected[i].split(','))) + self.assertEqual(set(output[i].split(",")), set(expected[i].split(","))) def test_enrichment_with_vertex_ai_legacy(self, mock_stdout): enrichment_with_vertex_ai_legacy() @@ -310,6 +354,19 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct): os.environ.pop('GOOGLE_CLOUD_SQL_DB_PASSWORD', None) os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None) + def test_enrichment_with_bigquery_storage_basic(self, mock_stdout): + enrichment_with_bigquery_storage_basic() + output = mock_stdout.getvalue().splitlines() + expected = validate_enrichment_with_bigquery_storage_basic() + self.maxDiff = None + 
self.assertEqual(output, expected) + + def test_enrichment_with_bigquery_storage_custom_function(self, mock_stdout): + enrichment_with_bigquery_storage_custom_function() + output = mock_stdout.getvalue().splitlines() + expected = validate_enrichment_with_bigquery_storage_custom_function() + self.assertEqual(output, expected) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read.py new file mode 100644 index 000000000000..7a70cd3381a1 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read.py @@ -0,0 +1,765 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC & Apache Software Foundation (Original License +# Header) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +BigQuery Enrichment Source Handler using the BigQuery Storage Read API +with support for field renaming via aliases in `column_names`, +additional non-key fields for filtering, dynamic row restriction templates, +experimental parallel stream reading using ThreadPoolExecutor, and custom +row selection. 
+""" + +import concurrent.futures # For parallel stream reading +import logging +import re +import pyarrow as pa + +from collections.abc import Callable +from collections.abc import Mapping +from typing import Any +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from google.api_core.exceptions import BadRequest +from google.api_core.exceptions import GoogleAPICallError +from google.api_core.exceptions import NotFound +from google.cloud.bigquery_storage import BigQueryReadClient + +try: + from google.cloud.bigquery_storage import types + from google.cloud.bigquery_storage.types import ( + DataFormat, ReadRowsResponse, ReadSession) +except ImportError: + # Fallback for older versions where types might be in different location + from google.cloud.bigquery_storage import types + ReadRowsResponse = types.ReadRowsResponse + ReadSession = types.ReadSession + DataFormat = types.DataFormat + +from apache_beam.pvalue import Row as BeamRow +from apache_beam.transforms.enrichment import EnrichmentSourceHandler + +# --- Configure Logging --- +logger = logging.getLogger(__name__) + +# Type hints for functions +# Input functions expect beam.Row for clarity, use beam.Row.as_dict inside if +# needed +ConditionValueFn = Callable[[BeamRow], Dict[str, Any]] +# Updated RowRestrictionTemplateFn signature based on user provided code +RowRestrictionTemplateFn = Callable[ + [Dict[str, Any], Optional[List[str]], BeamRow], str] +BQRowDict = Dict[str, Any] +# Callback for selecting the "latest" or desired row from multiple BQ results +LatestValueSelectorFn = Optional[Callable[[List[BeamRow], BeamRow], + Optional[BeamRow]]] + +# Regex to parse "column as alias" format, ignoring case for "as" +ALIAS_REGEX = re.compile(r"^(.*?)\s+as\s+(.*)$", re.IGNORECASE) + + +def _validate_bigquery_metadata( + project, + table_name, + row_restriction_template, + 
row_restriction_template_fn, + fields, + condition_value_fn, + additional_condition_fields, +): + """Validates parameters for Storage API usage.""" + if not project: + raise ValueError("`project` must be provided.") + if not table_name: + raise ValueError("`table_name` must be provided.") + if (row_restriction_template and + row_restriction_template_fn) or (not row_restriction_template and + not row_restriction_template_fn): + raise ValueError( + "Provide exactly one of `row_restriction_template` or " + "`row_restriction_template_fn`.") + if (fields and condition_value_fn) or (not fields and not condition_value_fn): + raise ValueError("Provide exactly one of `fields` or `condition_value_fn`.") + if additional_condition_fields and condition_value_fn: + raise ValueError( + "`additional_condition_fields` cannot be used with " + "`condition_value_fn`.") + + +class BigQueryStorageEnrichmentHandler( + EnrichmentSourceHandler[Union[BeamRow, list[BeamRow]], + Union[BeamRow, list[BeamRow]]]): + """Enrichment handler for Google Cloud BigQuery using the Storage Read API. 
+ (Refer to __init__ for full list of features and arguments) + """ + def __init__( + self, + project: str, + table_name: str, + *, + row_restriction_template: Optional[str] = None, + row_restriction_template_fn: Optional[RowRestrictionTemplateFn] = None, + fields: Optional[list[str]] = None, # Fields for KEY and filtering + additional_condition_fields: Optional[list[str]] = None, # Fields ONLY + # for filtering + column_names: Optional[list[str]] = None, # Columns to select + aliases + condition_value_fn: Optional[ConditionValueFn] = None, # Alt way to get + # filter/key values + min_batch_size: Optional[int] = 1, + max_batch_size: Optional[int] = 1000, # Batching enabled by default + max_batch_duration_secs: Optional[int] = None, + max_parallel_streams: Optional[int] = None, # Max workers for + # ThreadPoolExecutor + max_stream_count: int = 100, # Max streams for BigQuery Storage Read + # --- Added latest_value_selector and primary_keys from user code --- + latest_value_selector: LatestValueSelectorFn = None, + primary_keys: Optional[list[str]] = None, # --- End added parameters --- + ): + """ + Initializes the BigQueryStorageEnrichmentHandler. + + Args: + project: Google Cloud project ID. + table_name: Fully qualified BigQuery table name. + row_restriction_template: (Optional[str]) Template string for a + single row's filter condition. If `row_restriction_template_fn` + is not provided, this template will be formatted with values + from `fields` and `additional_condition_fields`. + row_restriction_template_fn: (Optional[Callable]) Function that + takes (condition_values_dict, primary_keys, request_row) and + returns a fully formatted filter string or template to + be formatted. + fields: (Optional[list[str]]) Input `beam.Row` field names used to + generate the dictionary for formatting the row restriction + template AND for generating the internal join/cache key. 
+ additional_condition_fields: (Optional[list[str]]) Additional input + `beam.Row` field names used ONLY for formatting the row + restriction template. Not part of join/cache key. + column_names: (Optional[list[str]]) Names/aliases of columns to + select. Supports "original_col as alias_col" format. If None, + selects '*'. + condition_value_fn: (Optional[Callable]) Function returning a + dictionary for formatting row restriction template and for + join/cache key. Takes precedence over `fields`. + min_batch_size (Optional[int]): Minimum elements per batch. + Defaults to 1. + max_batch_size (Optional[int]): Maximum elements per batch. + Defaults to 1000 for batching. Set to 1 for single element + processing to disable batching. + max_batch_duration_secs (Optional[int]): Maximum batch buffering + time in seconds. Defaults to 5 seconds. + max_parallel_streams (Optional[int]): Max worker threads for + ThreadPoolExecutor for reading streams in parallel within a + single `__call__`. + max_stream_count (int): Maximum number of streams for BigQuery + Storage Read API. Defaults to 100. Setting to 0 lets BigQuery + decide the optimal number of streams. + latest_value_selector: (Optional) Callback function to select the + desired row when multiple BQ rows match a key. Takes + `List[beam.Row]` (BQ results) and the original `beam.Row` + (request) and returns one `beam.Row` or None. + primary_keys: (Optional[list[str]]) Primary key fields used + potentially by `row_restriction_template_fn` or + `latest_value_selector`. 
+ """ + _validate_bigquery_metadata( + project, + table_name, + row_restriction_template, + row_restriction_template_fn, + fields, + condition_value_fn, + additional_condition_fields, + ) + self.project = project + self.table_name = table_name + self.row_restriction_template = row_restriction_template + self.row_restriction_template_fn = row_restriction_template_fn + self.fields = fields + self.additional_condition_fields = additional_condition_fields or [] + self.condition_value_fn = condition_value_fn + self.max_parallel_streams = max_parallel_streams + self.max_stream_count = max_stream_count + # --- Store new parameters --- + self._latest_value_callback = latest_value_selector + self.primary_keys = primary_keys + # --- End store --- + + self._rename_map: Dict[str, str] = {} + bq_columns_to_select_set: Set[str] = set() + self._select_all_columns = False + if column_names: + for name_or_alias in column_names: + match = ALIAS_REGEX.match(name_or_alias) + if match: + original_col, alias_col = ( + match.group(1).strip(), + match.group(2).strip(), + ) + if not original_col or not alias_col: + raise ValueError(f"Invalid alias: '{name_or_alias}'") + bq_columns_to_select_set.add(original_col) + self._rename_map[original_col] = alias_col + else: + col = name_or_alias.strip() + if not col: + raise ValueError("Empty column name.") + if col == "*": + self._select_all_columns = True + break + bq_columns_to_select_set.add(col) + else: + self._select_all_columns = True + + key_gen_fields_set = set(self.fields or []) + if self._select_all_columns: + self._bq_select_columns = ["*"] + if key_gen_fields_set: + logger.debug( + "Selecting all columns ('*'). 
Key fields %s assumed present.", + key_gen_fields_set, + ) + else: + fields_to_ensure_selected = set() + if self.fields: + reverse_rename_map = {v: k for k, v in self._rename_map.items()} + for field in self.fields: + original_name = reverse_rename_map.get(field, field) + fields_to_ensure_selected.add(original_name) + # Ensure primary keys (if defined for callback use) are selected if not + # already + if self.primary_keys: + for pk_field in self.primary_keys: + original_pk_name = { + v: k + for k, v in self._rename_map.items() + }.get(pk_field, pk_field) + fields_to_ensure_selected.add(original_pk_name) + + final_select_set = bq_columns_to_select_set.union( + fields_to_ensure_selected) + self._bq_select_columns = sorted(list(final_select_set)) + if not self._bq_select_columns: + raise ValueError("No columns determined for selection.") + + logger.info( + "Handler Initialized. Selecting BQ Columns: %s. Renaming map: %s", + self._bq_select_columns, + self._rename_map, + ) + + self._batching_kwargs = {} + # Set defaults for optional parameters + min_batch_size = min_batch_size or 1 + max_batch_size = max_batch_size or 1000 + max_batch_duration_secs = max_batch_duration_secs or 5 + + if max_batch_size > 1: + self._batching_kwargs["min_batch_size"] = min_batch_size + self._batching_kwargs["max_batch_size"] = max_batch_size + self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs + else: + self._batching_kwargs["min_batch_size"] = 1 + self._batching_kwargs["max_batch_size"] = 1 + + self._client: Optional[BigQueryReadClient] = None + self._arrow_schema: Optional[pa.Schema] = None + + def __enter__(self): + if not self._client: + self._client = BigQueryReadClient() + logger.info("BigQueryStorageEnrichmentHandler: Client created.") + self._arrow_schema = None + + def _get_condition_values_dict(self, + req: BeamRow) -> Optional[Dict[str, Any]]: + try: + if self.condition_value_fn: + values_dict = self.condition_value_fn(req) + if values_dict is None or 
any(v is None for v in values_dict.values()): + logger.warning( + "condition_value_fn returned None or None value(s). " + "Skipping: %s. Values: %s", + req, + values_dict, + ) + return None + return values_dict + elif self.fields is not None: + req_dict = req._asdict() + values_dict = {} + all_req_fields = (self.fields or []) + self.additional_condition_fields + for field in all_req_fields: + # User's provided logic for row_restriction_template_fn handling: + if not self.row_restriction_template_fn: + if field not in req_dict or req_dict[field] is None: + logger.warning( + "Input row missing field '%s' or None (needed " + "for filter). Skipping: %s", + field, + req, + ) + return None + values_dict[field] = req_dict.get(field) # Use get for safety + return values_dict + else: + raise ValueError( + "Internal error: Neither fields nor condition_value_fn.") + except AttributeError: # Specifically for _asdict() + logger.error( + "Failed to call _asdict() on element. Type: %s. Element: " + "%s. Ensure input is beam.Row.", + type(req), + req, + ) + return None + except Exception as e: + logger.error( + "Error getting condition values for row %s: %s", + req, + e, + exc_info=True) + return None + + def _build_single_row_filter( + self, req_row: BeamRow, condition_values_dict: Dict[str, Any]) -> str: + """Builds the filter string part for a single row.""" + try: + if self.row_restriction_template_fn: + # User's provided signature for row_restriction_template_fn + template_or_filter = self.row_restriction_template_fn( + condition_values_dict, self.primary_keys, req_row) + if not isinstance(template_or_filter, str): + raise TypeError( + "row_restriction_template_fn must return a string " + "(filter or template to be formatted)") + # Assuming if it takes condition_values_dict, it might be returning + # the final filter or a template. If it's a template, it still needs + # .format(). For now, assume it's a template that might still need + # formatting OR the final filter string. 
Let's assume it's the final + # filter string as per user's code. + # Directly return what the user's function gives. + return template_or_filter + elif self.row_restriction_template: + return self.row_restriction_template.format(**condition_values_dict) + else: + raise ValueError( + "Internal Error: No template or template function available.") + except KeyError as e: # if user's fn returns template and format fails + raise ValueError( + f"Placeholder {{{e}}} in template not found in " + f"condition values: {condition_values_dict.keys()}") + except Exception as e: + logger.error( + "Error building filter for row %s with values %s: %s", + req_row, + condition_values_dict, + e, + exc_info=True, + ) + return "" + + def _apply_renaming(self, bq_row_dict: BQRowDict) -> BQRowDict: + if not self._rename_map: + return bq_row_dict + return {self._rename_map.get(k, k): v for k, v in bq_row_dict.items()} + + def _arrow_to_dicts(self, response: ReadRowsResponse) -> Iterator[BQRowDict]: + # Now uses self._arrow_schema directly + if response.arrow_record_batch: + if not self._arrow_schema: + logger.error( + "Cannot process Arrow batch: Schema not " + "available/cached in handler.") + return + try: + serialized_batch = response.arrow_record_batch.serialized_record_batch + record_batch = pa.ipc.read_record_batch( + pa.py_buffer(serialized_batch), self._arrow_schema) + arrow_table = pa.Table.from_batches([record_batch]) + yield from arrow_table.to_pylist() + except Exception as e: + logger.error( + "Error converting Arrow batch to dicts: %s", e, exc_info=True) + + def _execute_storage_read(self, combined_row_filter: str) -> List[BQRowDict]: + if not self._client: + self.__enter__() + if not self._client: + raise RuntimeError("BQ Client failed to initialize.") + if not combined_row_filter: + logger.warning("Empty filter, skipping BQ read.") + return [] + + try: + table_project, dataset_id, table_id = self.table_name.split(".") + except ValueError: + raise ValueError( + f"Invalid 
table_name: '{self.table_name}'. Expected " + "'project.dataset.table'.") + parent_project = self.project + table_resource = ( + f"projects/{table_project}/datasets/{dataset_id}/tables/{table_id}") + + session = None + try: + # TODO: Improve max_stream_count to be dynamic based on input size, + # data volume, and query complexity for optimal performance + req = { + "parent": f"projects/{parent_project}", + "read_session": ReadSession( + table=table_resource, + data_format=DataFormat.ARROW, + read_options=ReadSession.TableReadOptions( + row_restriction=combined_row_filter, + selected_fields=self._bq_select_columns, + ), + ), + "max_stream_count": self.max_stream_count, + } + session = self._client.create_read_session(request=req) + logger.debug( + "Session with %s streams. Filter: %s", + len(session.streams), + combined_row_filter, + ) + if session.streams and session.arrow_schema: + if not self._arrow_schema: + self._arrow_schema = pa.ipc.read_schema( + pa.py_buffer(session.arrow_schema.serialized_schema)) + logger.debug("Deserialized Arrow schema for current call.") + elif session.streams: + logger.error("Session has streams but no schema.") + return [] + except (BadRequest, NotFound, GoogleAPICallError) as e: + logger.error( + "BQ API error creating session. Filter: '%s'. Error: %s", + combined_row_filter, + e, + ) + return [] + except Exception as e: + logger.error( + "Unexpected error creating session. Filter: '%s'. 
Error: %s", + combined_row_filter, + e, + exc_info=True, + ) + return [] + + if not session or not session.streams: + logger.warning("No streams for filter: %s", combined_row_filter) + return [] + + def _read_single_stream_worker(stream_name: str) -> List[BQRowDict]: + worker_results = [] + if not self._client or not self._arrow_schema: + logger.error("Stream %s: Client/schema missing in worker.", stream_name) + return worker_results + try: + reader = self._client.read_rows(stream_name) + for response in reader: + worker_results.extend(self._arrow_to_dicts(response)) # Uses + # self._arrow_schema + except Exception as e: + logger.error( + "Error reading stream %s in worker: %s", + stream_name, + e, + exc_info=True, + ) + return worker_results + + all_bq_rows_original_keys = [] + num_api_streams = len(session.streams) + max_workers = num_api_streams + if self.max_parallel_streams is not None and self.max_parallel_streams > 0: + max_workers = min(num_api_streams, self.max_parallel_streams) + if max_workers <= 0: + max_workers = 1 + logger.debug( + "Reading %s API streams using %s threads.", + num_api_streams, + max_workers) + futures = [] + try: + with concurrent.futures.ThreadPoolExecutor( + max_workers=max_workers) as executor: + for stream in session.streams: + futures.append( + executor.submit(_read_single_stream_worker, stream.name)) + for future in concurrent.futures.as_completed(futures): + try: + all_bq_rows_original_keys.extend(future.result()) + except Exception as e: + logger.error("Error processing future result: %s", e, exc_info=True) + except Exception as pool_error: + logger.error("ThreadPool error: %s", pool_error, exc_info=True) + logger.debug("Fetched %s rows from BQ.", len(all_bq_rows_original_keys)) + return all_bq_rows_original_keys + + def create_row_key(self, + row: BeamRow, + is_bq_result: bool = False) -> Optional[tuple]: + try: + if self.condition_value_fn: + key_values_dict = self.condition_value_fn(row) + elif self.fields is not None: + 
row_dict = row._asdict() # Assumes row is BeamRow + + # If this is a BQ result row with aliased columns, map field names + # to aliases + fields_to_use = self.fields + if is_bq_result and self._rename_map: + # Map original field names to their aliases + fields_to_use = [self._rename_map.get(f, f) for f in self.fields] + + key_values_dict = { + # Use original field name as key, but get value using appropriate + # field name + self.fields[i]: row_dict[field_name] + for i, field_name in enumerate(fields_to_use) + if field_name in row_dict and row_dict[field_name] is not None + } + if len(key_values_dict) != len(self.fields): # Ensure all key fields + # found and not None + logger.debug( + "Row missing key field(s) or None. Cannot generate key: %s", row) + return None + else: + raise ValueError( + "Internal error: Neither fields nor condition_value_fn for key.") + if key_values_dict is None: + return None + return tuple(sorted(key_values_dict.items())) + except AttributeError: + logger.error( + "Failed _asdict() for key gen. Type: %s. Ensure input is beam.Row.", + type(row), + ) + return None + except Exception as e: + logger.error("Error generating key for row %s: %s", row, e, exc_info=True) + return None + + def _process_bq_results_for_batch( + self, + bq_results_list_orig_keys: List[BQRowDict]) -> Dict[tuple, List[BeamRow]]: + """Process BQ results and create a mapping from keys to renamed rows.""" + bq_results_key_map: Dict[tuple, List[BeamRow]] = {} + for bq_row_dict_orig_keys in bq_results_list_orig_keys: + try: + renamed_bq_row_dict = self._apply_renaming(bq_row_dict_orig_keys) + bq_row_renamed_keys_temp = BeamRow(**renamed_bq_row_dict) + resp_key = self.create_row_key( + bq_row_renamed_keys_temp, is_bq_result=True) + if resp_key: + if resp_key not in bq_results_key_map: + bq_results_key_map[resp_key] = [] + bq_results_key_map[resp_key].append(bq_row_renamed_keys_temp) + except Exception as e: + logger.warning( + "Error processing BQ response row %s: %s. 
Cannot map.", + bq_row_dict_orig_keys, + e, + ) + return bq_results_key_map + + def _select_response_row( + self, matching_bq_rows: List[BeamRow], req_row: BeamRow) -> BeamRow: + """Select the appropriate response row from matching BQ rows.""" + if not matching_bq_rows: + return BeamRow() + + if self._latest_value_callback: + try: + return ( + self._latest_value_callback(matching_bq_rows, req_row) or BeamRow()) + except Exception as cb_error: + logger.error( + "Error in latest_value_selector: %s. Using first BQ row.", + cb_error, + exc_info=True, + ) + return matching_bq_rows[0] + else: + return matching_bq_rows[0] # Default to first + + def _process_batch_request( + self, request: list[BeamRow]) -> List[Tuple[BeamRow, BeamRow]]: + """ + Process a batch of requests efficiently using a single BigQuery query. + + This method optimizes batch processing by: + 1. Deduplicating requests with identical keys + 2. Building individual row filters for each unique request + 3. Combining all filters into a single OR query to minimize API calls + 4. 
Mapping BigQuery results back to original requests + + Args: + request: List of BeamRow objects to enrich with BigQuery data + + Returns: + List of tuples containing (original_request, enriched_response) + pairs + """ + # Initialize collections for processing + batch_responses: List[Tuple[BeamRow, BeamRow]] = [] # Final results + requests_map: Dict[tuple, BeamRow] = {} # Unique key -> request mapping + single_row_filters: List[str] = [] # Individual SQL filter conditions + + # Phase 1: Process each request row and build individual filters + for req_row in request: + # Extract condition values (e.g., key fields) from the request row + condition_values = self._get_condition_values_dict(req_row) + if condition_values is None: + # Missing required fields - add empty response and skip processing + batch_responses.append((req_row, BeamRow())) + continue + + # Generate a unique key for this request (used for deduplication) + req_key = self.create_row_key(req_row) + if req_key is None: + # Cannot generate key - add empty response and skip processing + batch_responses.append((req_row, BeamRow())) + continue + + # Handle duplicate detection and filter building + if req_key not in requests_map: + # New unique request - store it and build its filter + requests_map[req_key] = req_row + single_filter = self._build_single_row_filter(req_row, condition_values) + if single_filter: + # Wrap in parentheses for safe OR combination + single_row_filters.append(f"({single_filter})") + else: + # Filter generation failed - add empty response + batch_responses.append((req_row, BeamRow())) + del requests_map[req_key] # Clean up + else: + # Duplicate key detected - log warning and return empty response + logger.warning( + "Duplicate key '%s' in batch. 
Processing first instance.", req_key) + batch_responses.append((req_row, BeamRow())) + + # Phase 2: Execute combined BigQuery query if we have valid filters + bq_results_key_map: Dict[tuple, + List[BeamRow]] = {} # Key -> BQ results mapping + if single_row_filters: + # Combine all individual filters with OR to create single query + # Example: "(id = 1) OR (id = 2) OR (id = 3)" + combined_filter = " OR ".join(single_row_filters) + + # Execute single BigQuery Storage Read API call + bq_results_list_orig_keys = self._execute_storage_read(combined_filter) + + # Process raw BigQuery results: apply column renaming and group by key + bq_results_key_map = self._process_bq_results_for_batch( + bq_results_list_orig_keys) + + # Phase 3: Match BigQuery results back to original requests + for req_key, req_row in requests_map.items(): + # Find all BigQuery rows that match this request's key + matching_bq_rows = bq_results_key_map.get(req_key, []) + + # Select the best response row (first match or custom selector result) + selected_response_row = self._select_response_row( + matching_bq_rows, req_row) + + # Add the (request, response) pair to final results + batch_responses.append((req_row, selected_response_row)) + + return batch_responses + + def _process_single_request(self, + request: BeamRow) -> Tuple[BeamRow, BeamRow]: + """Process a single request using a direct BQ query.""" + req_row = request + condition_values = self._get_condition_values_dict(req_row) + if condition_values is None: + return (req_row, BeamRow()) + single_filter = self._build_single_row_filter(req_row, condition_values) + if not single_filter: + return (req_row, BeamRow()) + bq_results_orig_keys = self._execute_storage_read(single_filter) + response_row = BeamRow() + if bq_results_orig_keys: + # For single request, apply selector if provided, else take first + renamed_bq_rows = [ + BeamRow(**self._apply_renaming(d)) for d in bq_results_orig_keys + ] + if self._latest_value_callback and renamed_bq_rows: + 
try: + response_row = ( + self._latest_value_callback(renamed_bq_rows, req_row) or + BeamRow()) + except Exception as cb_error: + logger.error( + "Error in latest_value_selector for single req: %s. " + "Using first BQ row.", + cb_error, + exc_info=True, + ) + response_row = renamed_bq_rows[0] + elif renamed_bq_rows: + response_row = renamed_bq_rows[0] + if len(bq_results_orig_keys) > 1 and not ( + self._latest_value_callback and + response_row != BeamRow()): # Log if multiple and + # default/callback didn't pick one specifically + logger.warning( + "Single request -> %s BQ rows. Used selected/first. Filter:'%s'", + len(bq_results_orig_keys), + single_filter, + ) + return (req_row, response_row) + + def __call__( # type: ignore[override] + self, request: Union[BeamRow, list[BeamRow]], *args, **kwargs + ) -> Union[Tuple[BeamRow, BeamRow], List[Tuple[BeamRow, BeamRow]]]: + self._arrow_schema = None # Reset schema + + if isinstance(request, list): + return self._process_batch_request(request) + else: + return self._process_single_request(request) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._client: + logger.info("BigQueryStorageEnrichmentHandler: Releasing client.") + self._client = None + + def get_cache_key(self, request: Union[BeamRow, list[BeamRow]]) -> str: + # TODO: Add proper caching functionality with TTL, cache size limits, + # and configurable cache policies to improve performance and reduce + # BigQuery API calls for repeated requests. 
+ if isinstance(request, list): + # For batch requests, create a composite key + keys = [ + str(self.create_row_key(req) or "__invalid_key__") for req in request + ] + return "|".join(keys) + else: + return str(self.create_row_key(request) or "__invalid_key__") + + def batch_elements_kwargs(self) -> Mapping[str, Any]: + return self._batching_kwargs diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read_it_test.py new file mode 100644 index 000000000000..5611cf0bce7c --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read_it_test.py @@ -0,0 +1,494 @@ +# + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import logging +import secrets +import time +import unittest +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +import pytest + +import apache_beam as beam +from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper +from apache_beam.io.gcp.internal.clients import bigquery +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +# pylint: disable=ungrouped-imports +try: + from apitools.base.py.exceptions import HttpError + from google.api_core.exceptions import BadRequest, GoogleAPICallError + + # Removed NotFound from import as it is unused + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.bigquery_storage_read import \ + BigQueryStorageEnrichmentHandler +except ImportError: + raise unittest.SkipTest( + "Google Cloud BigQuery dependencies are not installed.") + +_LOGGER = logging.getLogger(__name__) + + +@pytest.mark.uses_testcontainer +class BigQueryStorageEnrichmentIT(unittest.TestCase): + bigquery_dataset_id_prefix = "py_bq_storage_enrich_it_" + project = "apache-beam-testing" # Ensure this project is configured for tests + + @classmethod + def setUpClass(cls): + cls.bigquery_client = BigQueryWrapper() + # Generate a unique dataset ID for this test run + cls.dataset_id = "%s%d%s" % ( + cls.bigquery_dataset_id_prefix, + int(time.time()), + secrets.token_hex(3), + ) + cls.bigquery_client.get_or_create_dataset(cls.project, cls.dataset_id) + _LOGGER.info( + "Created dataset %s in project %s", cls.dataset_id, cls.project) + + @classmethod + def tearDownClass(cls): + request = bigquery.BigqueryDatasetsDeleteRequest( + projectId=cls.project, datasetId=cls.dataset_id, deleteContents=True) + try: + _LOGGER.info( + "Deleting dataset %s in project %s", cls.dataset_id, cls.project) + cls.bigquery_client.client.datasets.Delete(request) + except HttpError 
as e: + _LOGGER.warning( + "Failed to clean up dataset %s in project %s: %s", + cls.dataset_id, + cls.project, + e, + ) + + +@pytest.mark.uses_testcontainer +class TestBigQueryStorageEnrichmentIT(BigQueryStorageEnrichmentIT): + product_details_table_data = [ + { + "id": 1, "name": "A", "quantity": 2, "distribution_center_id": 3 + }, + { + "id": 2, "name": "B", "quantity": 3, "distribution_center_id": 1 + }, + { + "id": 3, "name": "C", "quantity": 10, "distribution_center_id": 4 + }, + { + "id": 4, "name": "D", "quantity": 1, "distribution_center_id": 3 + }, + { + "id": 5, "name": "C", "quantity": 100, "distribution_center_id": 4 + }, + { + "id": 6, "name": "D", "quantity": 11, "distribution_center_id": 3 + }, + { + "id": 7, "name": "C", "quantity": 7, "distribution_center_id": 1 + }, + ] + + product_updates_table_data = [ + { + "id": 10, + "value": "old_value_10", + "update_ts": "2023-01-01T00:00:00Z" + }, + { + "id": 10, + "value": "new_value_10", + "update_ts": "2023-01-02T00:00:00Z" + }, + { + "id": 11, + "value": "current_value_11", + "update_ts": "2023-01-05T00:00:00Z" + }, + { + "id": 10, + "value": "latest_value_10", + "update_ts": "2023-01-03T00:00:00Z" + }, + ] + + @classmethod + def create_table(cls, table_id_suffix, schema_fields, data): + table_id = f"table_{table_id_suffix}_{secrets.token_hex(2)}" + table_schema = bigquery.TableSchema() + for name, field_type in schema_fields: + table_field = bigquery.TableFieldSchema() + table_field.name = name + table_field.type = field_type + table_schema.fields.append(table_field) + + table = bigquery.Table( + tableReference=bigquery.TableReference( + projectId=cls.project, datasetId=cls.dataset_id, tableId=table_id), + schema=table_schema, + ) + request = bigquery.BigqueryTablesInsertRequest( + projectId=cls.project, datasetId=cls.dataset_id, table=table) + cls.bigquery_client.client.tables.Insert(request) + if data: + cls.bigquery_client.insert_rows( + cls.project, cls.dataset_id, table_id, data) + + 
fq_table_name = f"{cls.project}.{cls.dataset_id}.{table_id}" + _LOGGER.info("Created table %s", fq_table_name) + return fq_table_name + + @classmethod + def setUpClass(cls): + super().setUpClass() + product_schema = [ + ("id", "INTEGER"), + ("name", "STRING"), + ("quantity", "INTEGER"), + ("distribution_center_id", "INTEGER"), + ] + cls.product_details_table_fq = cls.create_table( + "product_details", product_schema, cls.product_details_table_data) + + updates_schema = [ + ("id", "INTEGER"), + ("value", "STRING"), + ("update_ts", "TIMESTAMP"), + ] + cls.product_updates_table_fq = cls.create_table( + "product_updates", updates_schema, cls.product_updates_table_data) + + def setUp(self): + self.default_row_restriction = "id = {}" + self.default_fields = ["id"] + + # [START test_enrichment_single_element] + def test_enrichment_single_element(self): + requests = [beam.Row(id=1, source_field="SourceA")] + handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template=self.default_row_restriction, + fields=self.default_fields, + column_names=["id", "name", "quantity"], + min_batch_size=1, + max_batch_size=1, + ) + + expected_output = [ + beam.Row(id=1, source_field="SourceA", name="A", quantity=2) + ] + + with TestPipeline(is_integration_test=True) as p: + input_pcoll = p | "CreateRequests" >> beam.Create(requests) + enriched_pcoll = input_pcoll | "Enrich" >> Enrichment(handler) + assert_that(enriched_pcoll, equal_to(expected_output)) + + # [END test_enrichment_single_element] + + def test_enrichment_batch_elements(self): + requests = [ + beam.Row(id=1, source_field="Item1"), + beam.Row(id=2, source_field="Item2"), + ] + handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template=self.default_row_restriction, + fields=self.default_fields, + column_names=["id", "name", "quantity", "distribution_center_id"], + 
min_batch_size=2, + max_batch_size=10, + ) + + expected_output = [ + beam.Row( + id=1, + source_field="Item1", + name="A", + quantity=2, + distribution_center_id=3, + ), + beam.Row( + id=2, + source_field="Item2", + name="B", + quantity=3, + distribution_center_id=1, + ), + ] + + with TestPipeline(is_integration_test=True) as p: + input_pcoll = p | beam.Create(requests) + enriched_pcoll = input_pcoll | Enrichment(handler) + assert_that(enriched_pcoll, equal_to(expected_output)) + + def test_enrichment_column_aliasing(self): + requests = [beam.Row(id=3, source_field="ItemC")] + handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template=self.default_row_restriction, + fields=self.default_fields, + column_names=["id", "name as product_name", "quantity as stock_count"], + ) + expected_output = [ + beam.Row(id=3, source_field="ItemC", product_name="C", stock_count=10) + ] + with TestPipeline(is_integration_test=True) as p: + enriched = p | beam.Create(requests) | Enrichment(handler) + assert_that(enriched, equal_to(expected_output)) + + def test_enrichment_no_match_passes_through(self): + requests = [ + beam.Row(id=1, source_field="ItemA"), # Match + beam.Row(id=99, source_field="ItemZ"), # No Match + ] + handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template=self.default_row_restriction, + fields=self.default_fields, + column_names=["id", "name", "quantity"], + ) + + expected_output = [ + beam.Row(id=1, source_field="ItemA", name="A", quantity=2), + beam.Row(id=99, + source_field="ItemZ"), # Original row, no enrichment fields + ] + with TestPipeline(is_integration_test=True) as p: + enriched = p | beam.Create(requests) | Enrichment(handler) + assert_that(enriched, equal_to(expected_output)) + + def test_enrichment_select_all_columns_asterisk(self): + requests = [beam.Row(id=4, source_field="ItemD")] + 
handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template=self.default_row_restriction, + fields=self.default_fields, + column_names=["*"], + ) + + expected_output = [ + beam.Row( + id=4, + source_field="ItemD", + name="D", + quantity=1, + distribution_center_id=3, + ) + ] + with TestPipeline(is_integration_test=True) as p: + enriched = p | beam.Create(requests) | Enrichment(handler) + assert_that(enriched, equal_to(expected_output)) + + def test_enrichment_row_restriction_template_fn(self): + def custom_template_fn( + condition_values: Dict[str, Any], + primary_keys: Optional[List[str]], + request_row: beam.Row, + ) -> str: + # request_row has 'lookup_id' and 'lookup_name' + # condition_values will have 'id_val' and 'name_val' from + # condition_value_fn + return ( + f"id = {condition_values['id_val']} AND " + f"name = '{condition_values['name_val']}'") + + def custom_cond_val_fn(req_row: beam.Row) -> Dict[str, Any]: + return {"id_val": req_row.lookup_id, "name_val": req_row.lookup_name} + + requests = [beam.Row(lookup_id=5, lookup_name="C")] + handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template_fn=custom_template_fn, + condition_value_fn=custom_cond_val_fn, + column_names=["quantity", "distribution_center_id"], + ) + + expected_output = [ + beam.Row( + lookup_id=5, + lookup_name="C", + quantity=100, + distribution_center_id=4) + ] + with TestPipeline(is_integration_test=True) as p: + enriched = p | beam.Create(requests) | Enrichment(handler) + assert_that(enriched, equal_to(expected_output)) + + def test_enrichment_condition_value_fn(self): + def custom_cond_val_fn(req_row: beam.Row) -> Dict[str, Any]: + # req_row has 'product_identifier' + return {"the_id": req_row.product_identifier} + + requests = [beam.Row(product_identifier=6)] + handler = BigQueryStorageEnrichmentHandler( + 
project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template="id = {the_id}", # Uses key from cond_val_fn + condition_value_fn=custom_cond_val_fn, + column_names=["id", "name", "quantity"], + ) + + expected_output = [ + beam.Row(product_identifier=6, id=6, name="D", quantity=11) + ] + with TestPipeline(is_integration_test=True) as p: + enriched = p | beam.Create(requests) | Enrichment(handler) + assert_that(enriched, equal_to(expected_output)) + + def test_enrichment_additional_condition_fields(self): + requests = [beam.Row(target_id=7, filter_on_name="C")] + + # The handler will try to format "id = {} AND name = '{}'" + # with (requests_row.target_id, requests_row.filter_on_name) + # This requires careful alignment of fields and template. + # Let's adjust the handler to use named placeholders for clarity with + # additional_fields + # Or, ensure the template matches the order of fields + + # additional_condition_fields + + # For this test, let's assume the template is designed for positional + # formatting where the first {} takes from `fields` and subsequent {} + # take from `additional_condition_fields` + # The current implementation of _get_condition_values_dict for `fields` + + # `additional_condition_fields` creates a dictionary. So the template + # should use named placeholders. + + # Re-designing this specific test for clarity with named placeholders: + # Let condition_value_fn handle the mapping if complex. + # If using `fields` and `additional_condition_fields`, the template + # should use the field names directly if they are the keys in the dict + # passed to format. 
+ + # Let's use a condition_value_fn for this scenario to be explicit + def complex_cond_fn(req: beam.Row) -> Dict[str, Any]: + return {"id_val": req.target_id, "name_val": req.filter_on_name} + + handler_revised = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + condition_value_fn=complex_cond_fn, + row_restriction_template="id = {id_val} AND name = '{name_val}'", + column_names=["quantity", "distribution_center_id"], + ) + + expected_output = [ + beam.Row( + target_id=7, + filter_on_name="C", + quantity=7, + distribution_center_id=1) + ] + with TestPipeline(is_integration_test=True) as p: + enriched = p | beam.Create(requests) | Enrichment(handler_revised) + assert_that(enriched, equal_to(expected_output)) + + def test_enrichment_latest_value_selector(self): + def select_latest_by_ts(bq_results: List[beam.Row], + request_row: beam.Row) -> Optional[beam.Row]: + if not bq_results: + return None + # Assuming 'update_ts' is a field in bq_results and is comparable + return max(bq_results, key=lambda r: r.update_ts) + + requests = [ + beam.Row(lookup_id=10) + ] # This ID has multiple entries in updates table + handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_updates_table_fq, + fields=["lookup_id" + ], # 'lookup_id' from request_row will map to 'id' in template + row_restriction_template="id = {}", + column_names=[ + "value", + "update_ts", + ], # Select value and the timestamp itself + latest_value_selector=select_latest_by_ts, + primary_keys=["id"], # For the selector's context if needed + ) + + expected_output = [ + # The latest_value_10 has ts 2023-01-03 + beam.Row( + lookup_id=10, + value="latest_value_10", + update_ts="2023-01-03T00:00:00Z") + ] + with TestPipeline(is_integration_test=True) as p: + enriched = p | beam.Create(requests) | Enrichment(handler) + assert_that(enriched, equal_to(expected_output)) + + def 
test_enrichment_bad_request_invalid_column_in_template(self): + requests = [beam.Row(id=1)] + # Using a field in template that won't be provided by 'fields' + handler = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template= + "non_existent_field = {}", # This will cause KeyError during formatting + fields=["id"], + column_names=["name"], + ) + + with TestPipeline(is_integration_test=True) as p: + _ = p | beam.Create(requests) | Enrichment(handler) + # The error might manifest as a KeyError when formatting the template, + # or a BadRequest from BQ if the query is malformed but syntactically + # valid enough to send. + # The handler's internal _build_single_row_filter catches KeyError. + # If the query sent to BQ is invalid (e.g. "SELECT name FROM ... WHERE + # "), BQ returns BadRequest. + # Let's test for a BQ BadRequest due to a bad query structure. + # Example: selecting a column that doesn't exist. + handler_bad_select = BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.product_details_table_fq, + row_restriction_template="id = {}", + fields=["id"], + column_names=["non_existent_column_in_bq_table"], + ) # BQ error + + with self.assertRaises( + GoogleAPICallError) as e_ctx: # Or specifically BadRequest + p_bad = TestPipeline(is_integration_test=True) + input_pcoll = p_bad | "CreateBad" >> beam.Create(requests) + _ = input_pcoll | "EnrichBad" >> Enrichment(handler_bad_select) + res = p_bad.run() + res.wait_until_finish() + + self.assertTrue( + isinstance(e_ctx.exception, BadRequest) or + "NoSuchFieldError" in str(e_ctx.exception) or + "not found in table" in str(e_ctx.exception).lower() or + "unrecognized name" in str(e_ctx.exception).lower()) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read_test.py 
b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read_test.py new file mode 100644 index 000000000000..e626808a3269 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery_storage_read_test.py @@ -0,0 +1,340 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from unittest import mock + +from apache_beam.pvalue import Row as BeamRow + +try: + from apache_beam.transforms.enrichment_handlers import bigquery_storage_read +except ImportError: + raise unittest.SkipTest( + "Google Cloud BigQuery Storage dependencies are not installed.") + + +class TestBigQueryStorageEnrichmentHandler(unittest.TestCase): + def setUp(self): + self.project = "test-project" + self.table_name = "test-project.test_dataset.test_table" + self.fields = ["id"] + self.row_restriction_template = 'id = "{id}"' + self.column_names = ["id", "value"] + + def make_handler(self, **kwargs): + handler_kwargs = { + "project": self.project, + "table_name": self.table_name, + "row_restriction_template": self.row_restriction_template, + "fields": self.fields, + "column_names": self.column_names, + } + handler_kwargs.update(kwargs) # Override defaults with provided kwargs + return bigquery_storage_read.BigQueryStorageEnrichmentHandler( + **handler_kwargs) + + def test_init_invalid_args(self): + # Both row_restriction_template and row_restriction_template_fn + with self.assertRaises(ValueError): + bigquery_storage_read.BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.table_name, + row_restriction_template="foo", + row_restriction_template_fn=lambda d, p, r: "bar", + fields=self.fields, + ) + # Neither row_restriction_template nor row_restriction_template_fn + with self.assertRaises(ValueError): + bigquery_storage_read.BigQueryStorageEnrichmentHandler( + project=self.project, table_name=self.table_name, fields=self.fields) + # Both fields and condition_value_fn + with self.assertRaises(ValueError): + bigquery_storage_read.BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.table_name, + row_restriction_template="foo", + fields=self.fields, + condition_value_fn=lambda r: {"id": 1}, + ) + # Neither fields nor condition_value_fn + with self.assertRaises(ValueError): + 
bigquery_storage_read.BigQueryStorageEnrichmentHandler( + project=self.project, + table_name=self.table_name, + row_restriction_template="foo", + ) + + def test_get_condition_values_dict_fields(self): + handler = self.make_handler() + row = BeamRow(id=1, value="a") + self.assertEqual(handler._get_condition_values_dict(row), {"id": 1}) + + def test_get_condition_values_dict_missing_field(self): + handler = self.make_handler() + row = BeamRow(value="a") + self.assertIsNone(handler._get_condition_values_dict(row)) + + def test_get_condition_values_dict_condition_value_fn(self): + handler = self.make_handler( + fields=None, condition_value_fn=lambda r: {"id": 2}) + row = BeamRow(id=2, value="b") + self.assertEqual(handler._get_condition_values_dict(row), {"id": 2}) + + def test_build_single_row_filter_template(self): + handler = self.make_handler() + row = BeamRow(id=3, value="c") + cond = {"id": 3} + self.assertEqual(handler._build_single_row_filter(row, cond), 'id = "3"') + + def test_build_single_row_filter_fn(self): + fn = lambda d, p, r: f"id = '{d['id']}'" + handler = self.make_handler( + row_restriction_template=None, row_restriction_template_fn=fn) + row = BeamRow(id=4, value="d") + cond = {"id": 4} + self.assertEqual(handler._build_single_row_filter(row, cond), "id = '4'") + + def test_apply_renaming(self): + handler = self.make_handler(column_names=["id as new_id", "value"]) + bq_row = {"id": 1, "value": "foo"} + self.assertEqual( + handler._apply_renaming(bq_row), { + "new_id": 1, "value": "foo" + }) + + def test_apply_renaming_all_columns_aliased(self): + """Test column aliasing when all columns are aliased - expected output + should have aliased keys.""" + handler = self.make_handler( + column_names=["id as user_id", "value as user_value"]) + bq_row = {"id": 42, "value": "test_data"} + # When all columns are aliased, the expected output should only contain + # aliased keys + expected_result = {"user_id": 42, "user_value": "test_data"} + actual_result = 
handler._apply_renaming(bq_row) + self.assertEqual(actual_result, expected_result) + + # Verify that no original column names remain in the output + self.assertNotIn("id", actual_result) + self.assertNotIn("value", actual_result) + + # Verify that all expected aliased keys are present + self.assertIn("user_id", actual_result) + self.assertIn("user_value", actual_result) + + def test_create_row_key(self): + handler = self.make_handler() + row = BeamRow(id=5, value="e") + self.assertEqual(handler.create_row_key(row), (("id", 5), )) + + @mock.patch.object( + bigquery_storage_read.BigQueryStorageEnrichmentHandler, + "_execute_storage_read") + def test_call_single_match(self, mock_exec): + handler = self.make_handler() + row = BeamRow(id=6, value="f") + mock_exec.return_value = [{"id": 6, "value": "fetched"}] + req, resp = handler(row) + self.assertEqual(req, row) + self.assertEqual(resp.id, 6) + self.assertEqual(resp.value, "fetched") + + @mock.patch.object( + bigquery_storage_read.BigQueryStorageEnrichmentHandler, + "_execute_storage_read") + def test_call_single_no_match(self, mock_exec): + handler = self.make_handler() + row = BeamRow(id=7, value="g") + mock_exec.return_value = [] + req, resp = handler(row) + self.assertEqual(req, row) + self.assertEqual(resp, BeamRow()) + + @mock.patch.object( + bigquery_storage_read.BigQueryStorageEnrichmentHandler, + "_execute_storage_read") + def test_call_batch(self, mock_exec): + handler = self.make_handler() + rows = [BeamRow(id=8, value="h"), BeamRow(id=9, value="i")] + mock_exec.return_value = [ + { + "id": 8, "value": "h_bq" + }, + { + "id": 9, "value": "i_bq" + }, + ] + result = handler(rows) + self.assertEqual(result[0][0], rows[0]) + self.assertEqual(result[0][1].id, 8) + self.assertEqual(result[0][1].value, "h_bq") + self.assertEqual(result[1][0], rows[1]) + self.assertEqual(result[1][1].id, 9) + self.assertEqual(result[1][1].value, "i_bq") + + @mock.patch.object( + 
bigquery_storage_read.BigQueryStorageEnrichmentHandler, + "_execute_storage_read") + def test_call_batch_no_match(self, mock_exec): + handler = self.make_handler() + rows = [BeamRow(id=10, value="j"), BeamRow(id=11, value="k")] + mock_exec.return_value = [] + result = handler(rows) + self.assertEqual(result[0][0], rows[0]) + self.assertEqual(result[0][1], BeamRow()) + self.assertEqual(result[1][0], rows[1]) + self.assertEqual(result[1][1], BeamRow()) + + # def test_get_cache_key(self): + # handler = self.make_handler() + # row = BeamRow(id=12, value="l") + # self.assertEqual(handler.get_cache_key(row), str((("id", 12), ))) + # rows = [BeamRow(id=13, value="m"), BeamRow(id=14, value="n")] + # self.assertEqual( + # handler.get_cache_key(rows), + # [str((("id", 13), )), str((("id", 14), ))]) + + def test_batch_elements_kwargs(self): + handler = self.make_handler( + min_batch_size=2, max_batch_size=5, max_batch_duration_secs=10) + self.assertEqual( + handler.batch_elements_kwargs(), + { + "min_batch_size": 2, + "max_batch_size": 5, + "max_batch_duration_secs": 10 + }, + ) + + def test_max_stream_count_default(self): + """Test that max_stream_count defaults to 100.""" + handler = self.make_handler() + self.assertEqual(handler.max_stream_count, 100) + + def test_max_stream_count_custom(self): + """Test that max_stream_count can be set to a custom value.""" + handler = self.make_handler(max_stream_count=50) + self.assertEqual(handler.max_stream_count, 50) + + def test_max_stream_count_zero(self): + """Test that max_stream_count can be set to 0.""" + handler = self.make_handler(max_stream_count=0) + self.assertEqual(handler.max_stream_count, 0) + + @mock.patch( + "apache_beam.transforms.enrichment_handlers." 
+ "bigquery_storage_read.BigQueryReadClient") + def test_max_stream_count_passed_to_bq_api(self, mock_client_class): + """Test that max_stream_count is passed to BigQuery API request.""" + handler = self.make_handler(max_stream_count=25) + + # Mock the BigQuery client instance and session + mock_client_instance = mock.MagicMock() + mock_client_class.return_value = mock_client_instance + + mock_session = mock.MagicMock() + mock_session.streams = [] + mock_session.arrow_schema = None + mock_client_instance.create_read_session.return_value = mock_session + + # Initialize the client through __enter__ and call _execute_storage_read + handler.__enter__() + handler._execute_storage_read("id = 1") + + # Verify that create_read_session was called with the correct + # max_stream_count + mock_client_instance.create_read_session.assert_called_once() + call_args = mock_client_instance.create_read_session.call_args + request = call_args.kwargs["request"] + self.assertEqual(request["max_stream_count"], 25) + + @mock.patch.object( + bigquery_storage_read.BigQueryStorageEnrichmentHandler, + "_execute_storage_read") + def test_call_single_match_all_columns_aliased(self, mock_exec): + """Test end-to-end enrichment flow when all columns are aliased.""" + # Create handler with all columns aliased + handler = self.make_handler( + column_names=["id as user_id", "value as user_value"], + fields=["id"]) # Note: fields still uses original column name + + row = BeamRow(id=6, value="f") + # Mock returns data with original column names (as BigQuery would) + mock_exec.return_value = [{"id": 6, "value": "fetched_value"}] + + req, resp = handler(row) + + # Verify request is unchanged + self.assertEqual(req, row) + + # Verify response has aliased column names + self.assertEqual(resp.user_id, 6) + self.assertEqual(resp.user_value, "fetched_value") + + # Verify original column names are not present in response + with self.assertRaises(AttributeError): + _ = resp.id # Should not exist + with 
self.assertRaises(AttributeError): + _ = resp.value # Should not exist + + @mock.patch.object( + bigquery_storage_read.BigQueryStorageEnrichmentHandler, + "_execute_storage_read") + def test_call_batch_all_columns_aliased(self, mock_exec): + """Test batch enrichment flow when all columns are aliased.""" + # Create handler with all columns aliased + handler = self.make_handler( + column_names=["id as customer_id", "value as customer_name"], + fields=["id"]) + + rows = [BeamRow(id=100, value="john"), BeamRow(id=200, value="jane")] + # Mock returns data with original column names + mock_exec.return_value = [ + { + "id": 100, "value": "John Doe" + }, + { + "id": 200, "value": "Jane Smith" + }, + ] + + result = handler(rows) + + print(result) + # Verify we get correct number of results + self.assertEqual(len(result), 2) + + # Verify first result + self.assertEqual(result[0][0], rows[0]) # Original request unchanged + self.assertEqual(result[0][1].customer_id, 100) # Aliased column name + self.assertEqual( + result[0][1].customer_name, "John Doe") # Aliased column name + + # Verify second result + self.assertEqual(result[1][0], rows[1]) # Original request unchanged + self.assertEqual(result[1][1].customer_id, 200) # Aliased column name + self.assertEqual( + result[1][1].customer_name, "Jane Smith") # Aliased column name + + # Verify original column names are not present in responses + for _, response in result: + with self.assertRaises(AttributeError): + _ = response.id # Should not exist + with self.assertRaises(AttributeError): + _ = response.value # Should not exist + + +if __name__ == "__main__": + unittest.main() diff --git a/start-build-env.sh b/start-build-env.sh index 0f23f32a269c..204c17a0354d 100755 --- a/start-build-env.sh +++ b/start-build-env.sh @@ -85,7 +85,7 @@ docker build -t "beam-build-${USER_ID}" - < /dev/null 2>&1; then groupmod -g ${DOCKER_GROUP_ID} docker; fi RUN useradd -g ${GROUP_ID} -G docker -u ${USER_ID} -k /root -m ${USER_NAME} -d 
"${DOCKER_HOME_DIR}" RUN echo "${USER_NAME} ALL=NOPASSWD: ALL" > "/etc/sudoers.d/beam-build-${USER_ID}" ENV HOME "${DOCKER_HOME_DIR}" diff --git a/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment-bigquery-storage.md b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment-bigquery-storage.md new file mode 100644 index 000000000000..43a9931a1382 --- /dev/null +++ b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment-bigquery-storage.md @@ -0,0 +1,123 @@ +--- +title: "Enrichment with BigQuery Storage Read API" +--- + + +## Use BigQuery Storage API to enrich data + +{{< localstorage language language-py >}} + + + + + +
+ + {{< button-pydoc path="apache_beam.transforms.enrichment_handlers.bigquery_storage_read" class="BigQueryStorageEnrichmentHandler" >}} + +
+ +In recent versions of Apache Beam, the enrichment transform includes a built-in enrichment handler for [BigQuery](https://cloud.google.com/bigquery/docs/overview) using the [BigQuery Storage Read API](https://cloud.google.com/bigquery/docs/reference/storage?hl=en). +The following examples demonstrate how to create pipelines that use the enrichment transform with the [`BigQueryStorageEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigquery_storage_read.html#apache_beam.transforms.enrichment_handlers.bigquery_storage_read.BigQueryStorageEnrichmentHandler) handler, showcasing its flexibility and various use cases. + +## Field Matching Requirements + +When using BigQuery Storage enrichment, it's important to ensure that field names match between your input data and the enriched output. The `fields` parameter specifies columns from your input data used for matching, while `column_names` specifies which columns to retrieve from BigQuery. + +If BigQuery column names differ from your input field names, use aliases in `column_names` (e.g., `'bq_column_name as input_field_name'`) to ensure proper field matching. + +## Basic Enrichment Example + +This example shows basic product information enrichment for sales data: + +{{< table >}} +| sale_id | product_id | customer_id | quantity | +|:-------:|:----------:|:-----------:|:--------:| +| 1001 | 101 | 501 | 2 | +| 1002 | 102 | 502 | 1 | +| 1003 | 103 | 503 | 5 | +{{< /table >}} + +Enriched with product table data: + +{{< table >}} +| id | product_name | category | unit_price | +|:--:|:--------------:|:-----------:|:----------:| +| 101| Laptop Pro | Electronics | 999.99 | +| 102| Wireless Mouse | Electronics | 29.99 | +| 103| Office Chair | Furniture | 199.99 | +{{< /table >}} + +Note: The BigQuery table uses `id` as the column name, but our input data has `product_id`. We use `'id as product_id'` in `column_names` to ensure proper field matching. 
+ +{{< highlight language="py" >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py" enrichment_with_bigquery_storage_basic >}} +{{< /highlight >}} + +{{< paragraph class="notebook-skip" >}} +Output: +{{< /paragraph >}} +{{< highlight class="notebook-skip" >}} +Row(sale_id=1001, product_id=101, customer_id=501, quantity=2, product_id=101, product_name='Laptop Pro', category='Electronics', unit_price=999.99) +Row(sale_id=1002, product_id=102, customer_id=502, quantity=1, product_id=102, product_name='Wireless Mouse', category='Electronics', unit_price=29.99) +Row(sale_id=1003, product_id=103, customer_id=503, quantity=5, product_id=103, product_name='Office Chair', category='Furniture', unit_price=199.99) +{{< /highlight >}} + +## Advanced: Custom Filtering Logic + +This example demonstrates advanced enrichment features including: +- **Conditional enrichment**: Only enrich sales where `quantity > 2` and `category == "Electronics"` +- **Multiple key matching**: Match on both `product_id` and `category` fields +- **Custom field mapping**: Use aliases to rename BigQuery columns in the output (`id as prod_id`, `product_name as name`, etc.) 
+ +Input data includes sales with different quantities and categories, but only Electronics products with quantity > 2 will be enriched: + +{{< highlight language="py" >}} +{{< code_sample "sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment.py" enrichment_with_bigquery_storage_custom_function >}} +{{< /highlight >}} + +{{< paragraph class="notebook-skip" >}} +Output: +{{< /paragraph >}} +{{< highlight class="notebook-skip" >}} +Row(sale_id=1002, product_id=102, category='Electronics', customer_id=502, quantity=4, \ + prod_id=102, name='Wireless Mouse', category='Electronics', price=29.99) +Row(sale_id=1004, product_id=101, category='Electronics', customer_id=504, quantity=6, \ + prod_id=101, name='Laptop Pro', category='Electronics', price=999.99) +{{< /highlight >}} + +## FAQ: Advanced Options + +**Q: How do I enable batching for large datasets?** + +A: Use the `min_batch_size`, `max_batch_size`, and `max_batch_duration_secs` parameters in `BigQueryStorageEnrichmentHandler` to control batch size and timing. + +**Q: How do I use custom filtering logic?** + +A: Provide a `row_restriction_template_fn` and `condition_value_fn` to the handler. See the advanced example above. + +**Q: How do I tune performance for high-throughput scenarios?** + +A: Use `max_parallel_streams` and `max_stream_count` in the handler for parallel BigQuery reads. Increase batch sizes for efficiency. + +**Q: How do I use column aliasing?** + +A: Use the `as` keyword in `column_names` (e.g., `'bq_column as my_field'`) to rename columns in the output. + +## Related transforms + +Not applicable. 
+ +{{< button-pydoc path="apache_beam.transforms.enrichment_handlers.bigquery_storage_read" class="BigQueryStorageEnrichmentHandler" >}} diff --git a/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md index 4b352d0447ad..92937d7b2294 100644 --- a/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md +++ b/website/www/site/content/en/documentation/transforms/python/elementwise/enrichment.md @@ -39,8 +39,9 @@ This transform is available in Apache Beam 2.54.0 and later versions. The following examples demonstrate how to create a pipeline that use the enrichment transform to enrich data from external services. {{< table >}} -| Service | Example | -|:-----------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| Service | Example | +|:-----------------------------------|:-------------------------------------------------------------------------------------------------------------------------------| +| BigQuery Storage Read API | [Enrichment with BigQuery Storage Read API](/documentation/transforms/python/elementwise/enrichment-bigquery-storage/#example) | | Cloud Bigtable | [Enrichment with Bigtable](/documentation/transforms/python/elementwise/enrichment-bigtable/#example) | | Cloud SQL (PostgreSQL, MySQL, SQLServer) | [Enrichment with CloudSQL](/documentation/transforms/python/elementwise/enrichment-cloudsql) | | Vertex AI Feature Store | [Enrichment with Vertex AI Feature Store](/documentation/transforms/python/elementwise/enrichment-vertexai/#example-1-enrichment-with-vertex-ai-feature-store) | @@ -94,7 +95,12 @@ from apache_beam.transforms.enrichment import Enrichment # Enrichment pipeline with Redis cache enriched_data = (input_data | 'Enrich with Cache' >> 
Enrichment(my_enrichment_transform).with_redis_cache(host, port)) +``` +## FAQ: Advanced Options + +- **How do I use custom filtering logic or tune performance for BigQuery enrichment?** + See the [advanced example and FAQ](/documentation/transforms/python/elementwise/enrichment-bigquery-storage/#faq-advanced-options) in the BigQuery Storage Read API documentation for details on batching, custom filtering, performance tuning, and column aliasing. ## Related transforms diff --git a/website/www/site/layouts/partials/section-menu/en/documentation.html b/website/www/site/layouts/partials/section-menu/en/documentation.html index 1a60cfbdd9f1..539b2a0ce803 100755 --- a/website/www/site/layouts/partials/section-menu/en/documentation.html +++ b/website/www/site/layouts/partials/section-menu/en/documentation.html @@ -297,6 +297,7 @@