|
19 | 19 |
|
20 | 20 | # pytype: skip-file |
21 | 21 |
|
| 22 | +import hashlib |
22 | 23 | import itertools |
23 | 24 | import logging |
24 | 25 | import unittest |
| 26 | +from typing import Any |
| 27 | +from typing import Dict |
| 28 | +from typing import Iterable |
| 29 | +from typing import Tuple |
| 30 | +from typing import Union |
25 | 31 |
|
26 | 32 | import pytest |
27 | 33 |
|
28 | 34 | import apache_beam as beam |
| 35 | +from apache_beam.testing.synthetic_pipeline import SyntheticSDFAsSource |
29 | 36 | from apache_beam.testing.test_pipeline import TestPipeline |
30 | 37 | from apache_beam.testing.test_stream import TestStream |
31 | 38 | from apache_beam.testing.util import assert_that |
@@ -417,6 +424,71 @@ def process( |
417 | 424 | use_global_window=False, |
418 | 425 | label='assert per window') |
419 | 426 |
|
| 427 | + @pytest.mark.it_validatesrunner |
| 428 | + def test_side_input_with_sdf(self): |
| 429 | + """Test a side input with SDF. |
| 430 | +
|
| 431 | + This test verifies consisency of side input when it is split due to |
| 432 | + SDF (Splittable DoFns). The consistency is verified by checking the size |
| 433 | + and fingerprint of the side input. |
| 434 | +
|
| 435 | + This test needs to run with at least 2 workers (--num_workers=2) and |
| 436 | + autoscaling disabled (--autoscaling_algorithm=NONE). Otherwise it might |
| 437 | + provide false positives (i.e. not fail on bad state). |
| 438 | + """ |
| 439 | + initial_elements = 1000 |
| 440 | + num_records = 10000 |
| 441 | + key_size = 10 |
| 442 | + value_size = 100 |
| 443 | + expected_fingerprint = '00f7eeac8514721e2683d14a504b33d1' |
| 444 | + |
| 445 | + class GetSyntheticSDFOptions(beam.DoFn): |
| 446 | + """A DoFn that emits elements for genenrating SDF.""" |
| 447 | + def process(self, element: Any) -> Iterable[Dict[str, Union[int, str]]]: |
| 448 | + yield { |
| 449 | + 'num_records': num_records // initial_elements, |
| 450 | + 'key_size': key_size, |
| 451 | + 'value_size': value_size, |
| 452 | + 'initial_splitting_num_bundles': 0, |
| 453 | + 'initial_splitting_desired_bundle_size': 0, |
| 454 | + 'sleep_per_input_record_sec': 0, |
| 455 | + 'initial_splitting': 'const', |
| 456 | + } |
| 457 | + |
| 458 | + class SideInputTrackingDoFn(beam.DoFn): |
| 459 | + """A DoFn that emits the size and fingerprint of the side input. |
| 460 | +
|
| 461 | + In this context, the size is the number of elements and the fingerprint |
| 462 | + is the hash of the sorted serialized elements. |
| 463 | + """ |
| 464 | + def process( |
| 465 | + self, element: Any, |
| 466 | + side_input: Iterable[Tuple[bytes, |
| 467 | + bytes]]) -> Iterable[Tuple[int, str]]: |
| 468 | + |
| 469 | + # Sort for consistent hashing. |
| 470 | + sorted_side_input = sorted(side_input) |
| 471 | + size = len(sorted_side_input) |
| 472 | + m = hashlib.md5() |
| 473 | + for key, value in sorted_side_input: |
| 474 | + m.update(key) |
| 475 | + m.update(value) |
| 476 | + yield (size, m.hexdigest()) |
| 477 | + |
| 478 | + pipeline = self.create_pipeline() |
| 479 | + main_input = pipeline | 'Main input: Create' >> beam.Create([0]) |
| 480 | + side_input = pipeline | 'Side input: Create' >> beam.Create( |
| 481 | + range(initial_elements)) |
| 482 | + side_input |= 'Side input: Get synthetic SDF options' >> beam.ParDo( |
| 483 | + GetSyntheticSDFOptions()) |
| 484 | + side_input |= 'Side input: Process and split' >> beam.ParDo( |
| 485 | + SyntheticSDFAsSource()) |
| 486 | + results = main_input | 'Emit side input' >> beam.ParDo( |
| 487 | + SideInputTrackingDoFn(), beam.pvalue.AsIter(side_input)) |
| 488 | + |
| 489 | + assert_that(results, equal_to([(num_records, expected_fingerprint)])) |
| 490 | + pipeline.run() |
| 491 | + |
420 | 492 |
|
421 | 493 | if __name__ == '__main__': |
422 | 494 | logging.getLogger().setLevel(logging.DEBUG) |
|
0 commit comments