|
| 1 | +# coding=utf-8 |
1 | 2 | # |
2 | 3 | # Licensed to the Apache Software Foundation (ASF) under one or more |
3 | 4 | # contributor license agreements. See the NOTICE file distributed with |
|
14 | 15 | # See the License for the specific language governing permissions and |
15 | 16 | # limitations under the License. |
16 | 17 | # |
| 18 | +# pytype: skip-file |
| 19 | +# pylint: disable=line-too-long |
17 | 20 |
|
18 | | -import logging |
19 | 21 | import unittest |
| 22 | +from io import StringIO |
20 | 23 |
|
21 | | -import apache_beam as beam |
| 24 | +import mock |
22 | 25 |
|
23 | | -# pylint: disable=ungrouped-imports |
| 26 | +# pylint: disable=unused-import |
24 | 27 | try: |
25 | | - from apache_beam.transforms.enrichment import cross_join |
| 28 | + from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_bigtable, \ |
| 29 | + enrichment_with_vertex_ai_legacy |
| 30 | + from apache_beam.examples.snippets.transforms.elementwise.enrichment import enrichment_with_vertex_ai |
| 31 | + from apache_beam.io.requestresponse import RequestResponseIO |
26 | 32 | except ImportError: |
27 | | - raise unittest.SkipTest('RequestResponseIO dependencies are not installed.') |
| 33 | + raise unittest.SkipTest('RequestResponseIO dependencies are not installed') |
28 | 34 |
|
29 | 35 |
|
30 | | -class TestEnrichmentTransform(unittest.TestCase): |
31 | | - def test_cross_join(self): |
32 | | - left = {'id': 1, 'key': 'city'} |
33 | | - right = {'id': 1, 'value': 'durham'} |
34 | | - expected = beam.Row(id=1, key='city', value='durham') |
35 | | - output = cross_join(left, right) |
36 | | - self.assertEqual(expected, output) |
| 36 | +def validate_enrichment_with_bigtable(): |
| 37 | + expected = '''[START enrichment_with_bigtable] |
| 38 | +Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'}) |
| 39 | +Row(sale_id=3, customer_id=3, product_id=2, quantity=3, product={'product_id': '2', 'product_name': 'pixel 6', 'product_stock': '4'}) |
| 40 | +Row(sale_id=5, customer_id=5, product_id=4, quantity=2, product={'product_id': '4', 'product_name': 'pixel 8', 'product_stock': '10'}) |
| 41 | + [END enrichment_with_bigtable]'''.splitlines()[1:-1] |
| 42 | + return expected |
| 43 | + |
| 44 | + |
| 45 | +def validate_enrichment_with_vertex_ai(): |
| 46 | + expected = '''[START enrichment_with_vertex_ai] |
| 47 | +Row(user_id='2963', product_id=14235, sale_price=15.0, age=12.0, state='1', gender='1', country='1') |
| 48 | +Row(user_id='21422', product_id=11203, sale_price=12.0, age=12.0, state='0', gender='0', country='0') |
| 49 | +Row(user_id='20592', product_id=8579, sale_price=9.0, age=12.0, state='2', gender='1', country='2') |
| 50 | + [END enrichment_with_vertex_ai]'''.splitlines()[1:-1] |
| 51 | + return expected |
| 52 | + |
| 53 | + |
| 54 | +def validate_enrichment_with_vertex_ai_legacy(): |
| 55 | + expected = '''[START enrichment_with_vertex_ai_legacy] |
| 56 | +Row(entity_id='movie_01', title='The Shawshank Redemption', genres='Drama') |
| 57 | +Row(entity_id='movie_02', title='The Shining', genres='Horror') |
| 58 | +Row(entity_id='movie_04', title='The Dark Knight', genres='Action') |
| 59 | + [END enrichment_with_vertex_ai_legacy]'''.splitlines()[1:-1] |
| 60 | + return expected |
| 61 | + |
| 62 | + |
| 63 | +@mock.patch('sys.stdout', new_callable=StringIO) |
| 64 | +class EnrichmentTest(unittest.TestCase): |
| 65 | + def test_enrichment_with_bigtable(self, mock_stdout): |
| 66 | + enrichment_with_bigtable() |
| 67 | + output = mock_stdout.getvalue().splitlines() |
| 68 | + expected = validate_enrichment_with_bigtable() |
| 69 | + self.assertEqual(output, expected) |
| 70 | + |
| 71 | + def test_enrichment_with_vertex_ai(self, mock_stdout): |
| 72 | + enrichment_with_vertex_ai() |
| 73 | + output = mock_stdout.getvalue().splitlines() |
| 74 | + expected = validate_enrichment_with_vertex_ai() |
| 75 | + |
| 76 | + for i in range(len(expected)): |
| 77 | + self.assertEqual(set(output[i].split(',')), set(expected[i].split(','))) |
| 78 | + |
| 79 | + def test_enrichment_with_vertex_ai_legacy(self, mock_stdout): |
| 80 | + enrichment_with_vertex_ai_legacy() |
| 81 | + output = mock_stdout.getvalue().splitlines() |
| 82 | + expected = validate_enrichment_with_vertex_ai_legacy() |
| 83 | + self.maxDiff = None |
| 84 | + self.assertEqual(output, expected) |
37 | 85 |
|
38 | 86 |
|
39 | 87 | if __name__ == '__main__': |
40 | | - logging.getLogger().setLevel(logging.INFO) |
41 | 88 | unittest.main() |
0 commit comments