Skip to content
Closed
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,26 @@ def enrichment_with_bigtable():
# [START enrichment_with_bigtable]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigtable import BigTableEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.bigtable import (
BigTableEnrichmentHandler, )

project_id = 'apache-beam-testing'
instance_id = 'beam-test'
table_id = 'bigtable-enrichment-test'
row_key = 'product_id'
project_id = "apache-beam-testing"
instance_id = "beam-test"
table_id = "bigtable-enrichment-test"
row_key = "product_id"

data = [
beam.Row(sale_id=1, customer_id=1, product_id=1, quantity=1),
beam.Row(sale_id=3, customer_id=3, product_id=2, quantity=3),
beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2)
beam.Row(sale_id=5, customer_id=5, product_id=4, quantity=2),
]

bigtable_handler = BigTableEnrichmentHandler(
project_id=project_id,
instance_id=instance_id,
table_id=table_id,
row_key=row_key)
row_key=row_key,
)
with beam.Pipeline() as p:
_ = (
p
Expand All @@ -55,16 +57,16 @@ def enrichment_with_vertex_ai():
# [START enrichment_with_vertex_ai]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \
import VertexAIFeatureStoreEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import (
VertexAIFeatureStoreEnrichmentHandler, )

project_id = 'apache-beam-testing'
location = 'us-central1'
project_id = "apache-beam-testing"
location = "us-central1"
api_endpoint = f"{location}-aiplatform.googleapis.com"
data = [
beam.Row(user_id='2963', product_id=14235, sale_price=15.0),
beam.Row(user_id='21422', product_id=11203, sale_price=12.0),
beam.Row(user_id='20592', product_id=8579, sale_price=9.0),
beam.Row(user_id="2963", product_id=14235, sale_price=15.0),
beam.Row(user_id="21422", product_id=11203, sale_price=12.0),
beam.Row(user_id="20592", product_id=8579, sale_price=9.0),
]

vertex_ai_handler = VertexAIFeatureStoreEnrichmentHandler(
Expand All @@ -88,23 +90,23 @@ def enrichment_with_vertex_ai_legacy():
# [START enrichment_with_vertex_ai_legacy]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \
import VertexAIFeatureStoreLegacyEnrichmentHandler
from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import (
VertexAIFeatureStoreLegacyEnrichmentHandler, )

project_id = 'apache-beam-testing'
location = 'us-central1'
project_id = "apache-beam-testing"
location = "us-central1"
api_endpoint = f"{location}-aiplatform.googleapis.com"
data = [
beam.Row(entity_id="movie_01", title='The Shawshank Redemption'),
beam.Row(entity_id="movie_01", title="The Shawshank Redemption"),
beam.Row(entity_id="movie_02", title="The Shining"),
beam.Row(entity_id="movie_04", title='The Dark Knight'),
beam.Row(entity_id="movie_04", title="The Dark Knight"),
]

vertex_ai_handler = VertexAIFeatureStoreLegacyEnrichmentHandler(
project=project_id,
location=location,
api_endpoint=api_endpoint,
entity_type_id='movies',
entity_type_id="movies",
feature_store_id="movie_prediction_unique",
feature_ids=["title", "genres"],
row_key="entity_id",
Expand All @@ -116,3 +118,118 @@ def enrichment_with_vertex_ai_legacy():
| "Enrich W/ Vertex AI" >> Enrichment(vertex_ai_handler)
| "Print" >> beam.Map(print))
# [END enrichment_with_vertex_ai_legacy]


def enrichment_with_bigquery_storage_basic():
# [START enrichment_with_bigquery_storage_basic]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigquery_storage_read import (
BigQueryStorageEnrichmentHandler, )

project_id = "apache-beam-testing"
dataset = "beam-test"
table_name = "bigquery-enrichment-test-products"
# Sample sales data to enrich
sales_data = [
beam.Row(sale_id=1001, product_id=101, customer_id=501, quantity=2),
beam.Row(sale_id=1002, product_id=102, customer_id=502, quantity=1),
beam.Row(sale_id=1003, product_id=103, customer_id=503, quantity=5),
]

# Basic enrichment - enrich sales data with product information
handler = BigQueryStorageEnrichmentHandler(
project=project_id,
table_name=f"{project_id}.{dataset}.{table_name}",
row_restriction_template="id = {product_id}",
fields=["product_id"],
column_names=[
"id as product_id", "product_name", "category", "unit_price"
],
)

with beam.Pipeline() as p:
_ = (
p
| "Create Sales Data" >> beam.Create(sales_data)
| "Enrich with Product Info" >> Enrichment(handler)
| "Print Results" >> beam.Map(print))
# [END enrichment_with_bigquery_storage_basic]


def enrichment_with_bigquery_storage_custom_function():
# [START enrichment_with_bigquery_storage_custom_function]
import apache_beam as beam
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment_handlers.bigquery_storage_read import (
BigQueryStorageEnrichmentHandler, )

project_id = "apache-beam-testing"
dataset = "beam-test"
table_name = "bigquery-enrichment-test-products"
# Advanced sales data with category and quantity
sales_data = [
beam.Row(
sale_id=1001,
product_id=101,
category="Electronics",
customer_id=501,
quantity=2,
),
beam.Row(
sale_id=1002,
product_id=102,
category="Electronics",
customer_id=502,
quantity=4,
),
beam.Row(
sale_id=1003,
product_id=103,
category="Furniture",
customer_id=503,
quantity=5,
),
beam.Row(
sale_id=1004,
product_id=101,
category="Electronics",
customer_id=504,
quantity=6,
),
]

def build_row_restriction(condition_values, primary_keys, req_row):
# Only enrich if quantity > 2 and category is Electronics
if req_row.quantity > 2 and req_row.category == "Electronics":
return f'id = {req_row.product_id} AND category = "{req_row.category}"'
else:
return None # Skip enrichment for this row

def extract_condition_values(req_row):
return {
"product_id": req_row.product_id,
"category": req_row.category,
"quantity": req_row.quantity,
}

handler = BigQueryStorageEnrichmentHandler(
project=project_id,
table_name=f"{project_id}.{dataset}.{table_name}",
row_restriction_template_fn=build_row_restriction,
condition_value_fn=extract_condition_values,
column_names=[
"id as prod_id",
"product_name as name",
"category",
"unit_price as price",
],
)

with beam.Pipeline() as p:
_ = (
p
| "Create Sales Data" >> beam.Create(sales_data)
| "Enrich with Product Info (Advanced)" >> Enrichment(handler)
| "Print Results" >> beam.Map(print))
# [END enrichment_with_bigquery_storage_custom_function]
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#
# pytype: skip-file
# pylint: disable=line-too-long
# ruff: noqa: E501

import unittest
from io import StringIO
Expand All @@ -25,42 +26,80 @@

# pylint: disable=unused-import
try:
from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_bigtable, \
enrichment_with_vertex_ai_legacy
from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_vertex_ai
from apache_beam.io.requestresponse import RequestResponseIO
from apache_beam.examples.snippets.transforms.elementwise.enrichment import (
enrichment_with_bigtable,
enrichment_with_vertex_ai_legacy,
)
from apache_beam.examples.snippets.transforms.elementwise.enrichment import (
enrichment_with_vertex_ai, )
from apache_beam.examples.snippets.transforms.elementwise.enrichment import (
enrichment_with_bigquery_storage_basic,
enrichment_with_bigquery_storage_custom_function,
)
except ImportError:
raise unittest.SkipTest('RequestResponseIO dependencies are not installed')
raise unittest.SkipTest("RequestResponseIO dependencies are not installed")


def validate_enrichment_with_bigtable():
expected = '''[START enrichment_with_bigtable]
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
Row(sale_id=3, customer_id=3, product_id=2, quantity=3, product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'})
Row(sale_id=5, customer_id=5, product_id=4, quantity=2, product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'})
[END enrichment_with_bigtable]'''.splitlines()[1:-1]
expected = (
"""[START enrichment_with_bigtable]
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, """
"""product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
Row(sale_id=3, customer_id=3, product_id=2, quantity=3, """
"""product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'})
Row(sale_id=5, customer_id=5, product_id=4, quantity=2, """
"""product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'})
[END enrichment_with_bigtable]""").splitlines()[1:-1]
return expected


def validate_enrichment_with_vertex_ai():
expected = '''[START enrichment_with_vertex_ai]
Row(user_id='2963', product_id=14235, sale_price=15.0, age=12.0, state='1', gender='1', country='1')
Row(user_id='21422', product_id=11203, sale_price=12.0, age=12.0, state='0', gender='0', country='0')
Row(user_id='20592', product_id=8579, sale_price=9.0, age=12.0, state='2', gender='1', country='2')
[END enrichment_with_vertex_ai]'''.splitlines()[1:-1]
expected = (
"""[START enrichment_with_vertex_ai]
Row(user_id='2963', product_id=14235, sale_price=15.0, """
"""age=12.0, state='1', gender='1', country='1')
Row(user_id='21422', product_id=11203, sale_price=12.0, """
"""age=12.0, state='0', gender='0', country='0')
Row(user_id='20592', product_id=8579, sale_price=9.0, """
"""age=12.0, state='2', gender='1', country='2')
[END enrichment_with_vertex_ai]""").splitlines()[1:-1]
return expected


def validate_enrichment_with_vertex_ai_legacy():
expected = '''[START enrichment_with_vertex_ai_legacy]
expected = """[START enrichment_with_vertex_ai_legacy]
Row(entity_id='movie_01', title='The Shawshank Redemption', genres='Drama')
Row(entity_id='movie_02', title='The Shining', genres='Horror')
Row(entity_id='movie_04', title='The Dark Knight', genres='Action')
[END enrichment_with_vertex_ai_legacy]'''.splitlines()[1:-1]
[END enrichment_with_vertex_ai_legacy]""".splitlines()[1:-1]
return expected


@mock.patch('sys.stdout', new_callable=StringIO)
def validate_enrichment_with_bigquery_storage_basic():
expected = (
"""[START enrichment_with_bigquery_storage_basic]
Row(sale_id=1001, product_id=101, customer_id=501, quantity=2, """
"""product_id=101, product_name='Laptop Pro', category='Electronics', unit_price=999.99)
Row(sale_id=1002, product_id=102, customer_id=502, quantity=1, """
"""product_id=102, product_name='Wireless Mouse', category='Electronics', unit_price=29.99)
Row(sale_id=1003, product_id=103, customer_id=503, quantity=5, """
"""product_id=103, product_name='Office Chair', category='Furniture', unit_price=199.99)
[END enrichment_with_bigquery_storage_basic]""").splitlines()[1:-1]
return expected


def validate_enrichment_with_bigquery_storage_custom_function():
expected = (
"""[START enrichment_with_bigquery_storage_custom_function]
Row(sale_id=1002, product_id=102, category='Electronics', customer_id=502, """
"""quantity=4, prod_id=102, name='Wireless Mouse', category='Electronics', price=29.99)
Row(sale_id=1004, product_id=101, category='Electronics', customer_id=504, """
"""quantity=6, prod_id=101, name='Laptop Pro', category='Electronics', price=999.99)
[END enrichment_with_bigquery_storage_custom_function]""").splitlines()[1:-1]
return expected


@mock.patch("sys.stdout", new_callable=StringIO)
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function definition is missing proper indentation and parameter. The function should be a method with self parameter and proper indentation within the test class, not a standalone function with a decorator.

Suggested change
@mock.patch("sys.stdout", new_callable=StringIO)

Copilot uses AI. Check for mistakes.
class EnrichmentTest(unittest.TestCase):
def test_enrichment_with_bigtable(self, mock_stdout):
enrichment_with_bigtable()
Expand All @@ -74,7 +113,7 @@ def test_enrichment_with_vertex_ai(self, mock_stdout):
expected = validate_enrichment_with_vertex_ai()

for i in range(len(expected)):
self.assertEqual(set(output[i].split(',')), set(expected[i].split(',')))
self.assertEqual(set(output[i].split(",")), set(expected[i].split(",")))

def test_enrichment_with_vertex_ai_legacy(self, mock_stdout):
enrichment_with_vertex_ai_legacy()
Expand All @@ -83,6 +122,19 @@ def test_enrichment_with_vertex_ai_legacy(self, mock_stdout):
self.maxDiff = None
self.assertEqual(output, expected)

def test_enrichment_with_bigquery_storage_basic(self, mock_stdout):
enrichment_with_bigquery_storage_basic()
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_bigquery_storage_basic()
self.maxDiff = None
self.assertEqual(output, expected)

def test_enrichment_with_bigquery_storage_custom_function(self, mock_stdout):
enrichment_with_bigquery_storage_custom_function()
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_bigquery_storage_custom_function()
self.assertEqual(output, expected)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
Loading
Loading