Skip to content
Closed
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,26 @@ def enrichment_with_bigtable():
# [START enrichment_with_bigtable]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.bigtable import (
BigTableEnrichmentHandler, )

project_id = 'apache-beam-testing'
instance_id = 'beam-test'
table_id = 'bigtable-enrichment-test'
row_key = 'product_id'
project_id = "apache-beam-testing"
instance_id = "beam-test"
table_id = "bigtable-enrichment-test"
row_key = "product_id"

data = [
beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2)
beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2),
]

bigtable_handler = BigTableEnrichmentHandler(
project_id=project_id,
instance_id=instance_id,
table_id=table_id,
row_key=row_key)
row_key=row_key,
)
with beam.Pipeline() as p:
_ = (
p
Expand All @@ -55,16 +57,16 @@ def enrichment_with_vertex_ai():
# [START enrichment_with_vertex_ai]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \
import VertexAIFeatureStoreEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import (
VertexAIFeatureStoreEnrichmentHandler, )

project_id = 'apache-beam-testing'
location = 'us-central1'
project_id = "apache-beam-testing"
location = "us-central1"
api_endpoint = f"{location}-aiplatform.googleapis.com"
data = [
beam.Row(user_id='2963', product_id=14235, sale_price=15.0),
beam.Row(user_id='21422', product_id=11203, sale_price=12.0),
beam.Row(user_id='20592', product_id=8579, sale_price=9.0),
beam.Row(user_id="2963", product_id=14235, sale_price=15.0),
beam.Row(user_id="21422", product_id=11203, sale_price=12.0),
beam.Row(user_id="20592", product_id=8579, sale_price=9.0),
]

vertex_ai_handler = VertexAIFeatureStoreEnrichmentHandler(
Expand All @@ -88,23 +90,23 @@ def enrichment_with_vertex_ai_legacy():
# [START enrichment_with_vertex_ai_legacy]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \
import VertexAIFeatureStoreLegacyEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import (
VertexAIFeatureStoreLegacyEnrichmentHandler, )

project_id = 'apache-beam-testing'
location = 'us-central1'
project_id = "apache-beam-testing"
location = "us-central1"
api_endpoint = f"{location}-aiplatform.googleapis.com"
data = [
beam.Row(entity_id="movie_01", title='The Shawshank Redemption'),
beam.Row(entity_id="movie_01", title="The Shawshank Redemption"),
beam.Row(entity_id="movie_02", title="The Shining"),
beam.Row(entity_id="movie_04", title='The Dark Knight'),
beam.Row(entity_id="movie_04", title="The Dark Knight"),
]

vertex_ai_handler = VertexAIFeatureStoreLegacyEnrichmentHandler(
project=project_id,
location=location,
api_endpoint=api_endpoint,
entity_type_id='movies',
entity_type_id="movies",
feature_store_id="movie_prediction_unique",
feature_ids=["title", "genres"],
row_key="entity_id",
Expand All @@ -118,6 +120,119 @@ def enrichment_with_vertex_ai_legacy():
# [END enrichment_with_vertex_ai_legacy]


def enrichment_with_bigquery_storage_basic():
  """Docs snippet: enrich sales rows with product data from BigQuery Storage."""
  # [START enrichment_with_bigquery_storage_basic]
  import apache_beam as beam
  from apache_beam.transforms.enrichment import Enrichment
  from apache_beam.transforms.enrichment_handlers.bigquery_storage_read import (
      BigQueryStorageEnrichmentHandler, )

  project_id = "apache-beam-testing"
  dataset = "beam-test"
  table_name = "bigquery-enrichment-test-products"

  # Sample sales records; each row's product_id is the lookup key.
  unenriched_sales = [
      beam.Row(sale_id=1001, product_id=101, customer_id=501, quantity=2),
      beam.Row(sale_id=1002, product_id=102, customer_id=502, quantity=1),
      beam.Row(sale_id=1003, product_id=103, customer_id=503, quantity=5),
  ]

  # Basic enrichment: fetch product details by id and attach them to each sale.
  product_lookup_handler = BigQueryStorageEnrichmentHandler(
      project=project_id,
      table_name=f"{project_id}.{dataset}.{table_name}",
      row_restriction_template="id = {product_id}",
      fields=["product_id"],
      column_names=[
          "id as product_id", "product_name", "category", "unit_price"
      ],
  )

  with beam.Pipeline() as p:
    _ = (
        p
        | "Create Sales Data" >> beam.Create(unenriched_sales)
        | "Enrich with Product Info" >> Enrichment(product_lookup_handler)
        | "Print Results" >> beam.Map(print))
  # [END enrichment_with_bigquery_storage_basic]


def enrichment_with_bigquery_storage_custom_function():
  """Docs snippet: conditional enrichment using custom restriction callbacks."""
  # [START enrichment_with_bigquery_storage_custom_function]
  import apache_beam as beam
  from apache_beam.transforms.enrichment import Enrichment
  from apache_beam.transforms.enrichment_handlers.bigquery_storage_read import (
      BigQueryStorageEnrichmentHandler, )

  project_id = "apache-beam-testing"
  dataset = "beam-test"
  table_name = "bigquery-enrichment-test-products"

  # Sales records carrying category and quantity, used below to decide
  # per-row whether enrichment should happen at all.
  advanced_sales = [
      beam.Row(
          sale_id=1001,
          product_id=101,
          category="Electronics",
          customer_id=501,
          quantity=2,
      ),
      beam.Row(
          sale_id=1002,
          product_id=102,
          category="Electronics",
          customer_id=502,
          quantity=4,
      ),
      beam.Row(
          sale_id=1003,
          product_id=103,
          category="Furniture",
          customer_id=503,
          quantity=5,
      ),
      beam.Row(
          sale_id=1004,
          product_id=101,
          category="Electronics",
          customer_id=504,
          quantity=6,
      ),
  ]

  def build_row_restriction(condition_values, primary_keys, req_row):
    # Enrich only Electronics sales with quantity above 2; returning None
    # skips enrichment for the row. (condition_values / primary_keys are
    # part of the callback signature but unused here.)
    if req_row.quantity <= 2 or req_row.category != "Electronics":
      return None
    return f'id = {req_row.product_id} AND category = "{req_row.category}"'

  def extract_condition_values(req_row):
    # Values the handler associates with the request row — presumably used
    # as the cache/join key; confirm against the handler's docs.
    return {
        "product_id": req_row.product_id,
        "category": req_row.category,
        "quantity": req_row.quantity,
    }

  handler = BigQueryStorageEnrichmentHandler(
      project=project_id,
      table_name=f"{project_id}.{dataset}.{table_name}",
      row_restriction_template_fn=build_row_restriction,
      condition_value_fn=extract_condition_values,
      column_names=[
          "id as prod_id",
          "product_name as name",
          "category",
          "unit_price as price",
      ],
  )

  with beam.Pipeline() as p:
    _ = (
        p
        | "Create Sales Data" >> beam.Create(advanced_sales)
        | "Enrich with Product Info (Advanced)" >> Enrichment(handler)
        | "Print Results" >> beam.Map(print))
  # [END enrichment_with_bigquery_storage_custom_function]
def enrichment_with_google_cloudsql_pg():
# [START enrichment_with_google_cloudsql_pg]
import apache_beam as beam
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#
# pytype: skip-file
# pylint: disable=line-too-long
# ruff: noqa: E501

import os
import unittest
Expand All @@ -33,6 +34,18 @@

# pylint: disable=unused-import
try:
from apache_beam.examples.snippets.transforms.elementwise.enrichment import (
enrichment_with_bigtable,
enrichment_with_vertex_ai_legacy,
)
from apache_beam.examples.snippets.transforms.elementwise.enrichment import (
enrichment_with_vertex_ai, )
from apache_beam.examples.snippets.transforms.elementwise.enrichment import (
enrichment_with_bigquery_storage_basic,
enrichment_with_bigquery_storage_custom_function,
)
except ImportError:
raise unittest.SkipTest("RequestResponseIO dependencies are not installed")
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import statements are incorrectly placed inside a try block. The imports on lines 37-46 should be moved after the existing imports that start at line 49, as they are currently inside a try block that will cause syntax issues.

Copilot uses AI. Check for mistakes.
from sqlalchemy import (
Column, Integer, VARCHAR, Engine, MetaData, create_engine)
from apache_beam.examples.snippets.transforms.elementwise.enrichment import (
Expand All @@ -57,32 +70,65 @@


def validate_enrichment_with_bigtable():
  """Return the expected stdout lines for the Bigtable enrichment snippet.

  splitlines()[1:-1] strips the [START]/[END] marker lines, leaving only the
  printed Row(...) lines. Adjacent string literals concatenate, keeping each
  long expected line on one logical output line while staying under the
  source line-length limit.
  """
  # NOTE(review): the diff artifact duplicating the pre-change '''…''' version
  # of this literal was removed; only the current version is kept.
  expected = (
      """[START enrichment_with_bigtable]
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, """
      """product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
Row(sale_id=3, customer_id=3, product_id=2, quantity=3, """
      """product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'})
Row(sale_id=5, customer_id=5, product_id=4, quantity=2, """
      """product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'})
[END enrichment_with_bigtable]""").splitlines()[1:-1]
  return expected


def validate_enrichment_with_vertex_ai():
  """Return the expected stdout lines for the Vertex AI enrichment snippet.

  splitlines()[1:-1] strips the [START]/[END] marker lines; adjacent string
  literals concatenate so each expected line stays one logical output line.
  """
  # NOTE(review): the diff artifact duplicating the pre-change '''…''' version
  # of this literal was removed; only the current version is kept.
  expected = (
      """[START enrichment_with_vertex_ai]
Row(user_id='2963', product_id=14235, sale_price=15.0, """
      """age=12.0, state='1', gender='1', country='1')
Row(user_id='21422', product_id=11203, sale_price=12.0, """
      """age=12.0, state='0', gender='0', country='0')
Row(user_id='20592', product_id=8579, sale_price=9.0, """
      """age=12.0, state='2', gender='1', country='2')
[END enrichment_with_vertex_ai]""").splitlines()[1:-1]
  return expected


def validate_enrichment_with_vertex_ai_legacy():
  """Return the expected stdout lines for the legacy Vertex AI snippet.

  splitlines()[1:-1] strips the [START]/[END] marker lines.
  """
  # NOTE(review): the diff artifact duplicating both the old '''…''' and new
  # """…""" delimiter lines was removed; only the current version is kept.
  expected = """[START enrichment_with_vertex_ai_legacy]
Row(entity_id='movie_01', title='The Shawshank Redemption', genres='Drama')
Row(entity_id='movie_02', title='The Shining', genres='Horror')
Row(entity_id='movie_04', title='The Dark Knight', genres='Action')
[END enrichment_with_vertex_ai_legacy]""".splitlines()[1:-1]
  return expected


def validate_enrichment_with_bigquery_storage_basic():
  """Return the expected stdout lines for the basic BigQuery Storage snippet.

  One string per printed Row(...); the [START]/[END] marker lines are not
  included. Adjacent string literals concatenate into a single logical line.
  """
  return [
      "Row(sale_id=1001, product_id=101, customer_id=501, quantity=2, "
      "product_id=101, product_name='Laptop Pro', category='Electronics', "
      "unit_price=999.99)",
      "Row(sale_id=1002, product_id=102, customer_id=502, quantity=1, "
      "product_id=102, product_name='Wireless Mouse', category='Electronics', "
      "unit_price=29.99)",
      "Row(sale_id=1003, product_id=103, customer_id=503, quantity=5, "
      "product_id=103, product_name='Office Chair', category='Furniture', "
      "unit_price=199.99)",
  ]


def validate_enrichment_with_bigquery_storage_custom_function():
  """Return the expected stdout lines for the custom-function snippet.

  Only sales 1002 and 1004 appear: the snippet's restriction callback skips
  enrichment for the other rows. The [START]/[END] marker lines are not
  included; adjacent string literals concatenate into one logical line each.
  """
  return [
      "Row(sale_id=1002, product_id=102, category='Electronics', "
      "customer_id=502, quantity=4, prod_id=102, name='Wireless Mouse', "
      "category='Electronics', price=29.99)",
      "Row(sale_id=1004, product_id=101, category='Electronics', "
      "customer_id=504, quantity=6, prod_id=101, name='Laptop Pro', "
      "category='Electronics', price=999.99)",
  ]


@mock.patch("sys.stdout", new_callable=StringIO)
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function definition is missing proper indentation and parameter. The function should be a method with self parameter and proper indentation within the test class, not a standalone function with a decorator.

Suggested change
@mock.patch("sys.stdout", new_callable=StringIO)

Copilot uses AI. Check for mistakes.
def validate_enrichment_with_google_cloudsql_pg():
expected = '''[START enrichment_with_google_cloudsql_pg]
Row(product_id=1, name='A', quantity=2, region_id=3)
Expand Down Expand Up @@ -134,7 +180,7 @@ def test_enrichment_with_vertex_ai(self, mock_stdout):
expected = sorted(validate_enrichment_with_vertex_ai())

for i in range(len(expected)):
self.assertEqual(set(output[i].split(',')), set(expected[i].split(',')))
self.assertEqual(set(output[i].split(",")), set(expected[i].split(",")))

def test_enrichment_with_vertex_ai_legacy(self, mock_stdout):
enrichment_with_vertex_ai_legacy()
Expand Down Expand Up @@ -310,6 +356,19 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct):
os.environ.pop('GOOGLE_CLOUD_SQL_DB_PASSWORD', None)
os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None)

def test_enrichment_with_bigquery_storage_basic(self, mock_stdout):
  """Run the basic snippet and compare its printed output line-by-line."""
  enrichment_with_bigquery_storage_basic()
  printed = mock_stdout.getvalue().splitlines()
  self.maxDiff = None  # show the full diff on mismatch
  self.assertEqual(printed, validate_enrichment_with_bigquery_storage_basic())

def test_enrichment_with_bigquery_storage_custom_function(self, mock_stdout):
  """Run the custom-function snippet and check its printed output."""
  enrichment_with_bigquery_storage_custom_function()
  printed = mock_stdout.getvalue().splitlines()
  expected_lines = validate_enrichment_with_bigquery_storage_custom_function()
  self.assertEqual(printed, expected_lines)


# NOTE(review): the diff artifact left both the old and new guard lines in
# place, which is a syntax error; only the current double-quoted form is kept.
if __name__ == "__main__":
  unittest.main()
Loading
Loading