Skip to content

Commit 18373bf

Browse files
authored
Preserve pane index through reshuffle. (#34348)
* Preserve pane index through reshuffle. * Fix coders. * Add coder test. * Fix lint error. * Change compat version. * refactor. * Fix typehint issue. * Use fn api runner. * Fix lint error. * Refactor. * Remove strange duplciation. * Revert prism coder changes.
1 parent bd2891d commit 18373bf

File tree

3 files changed

+309
-9
lines changed

3 files changed

+309
-9
lines changed

sdks/python/apache_beam/testing/util.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
import glob
2424
import io
2525
import tempfile
26+
from typing import Any
2627
from typing import Iterable
28+
from typing import List
29+
from typing import NamedTuple
2730

2831
from apache_beam import pvalue
2932
from apache_beam.transforms import window
@@ -35,6 +38,8 @@
3538
from apache_beam.transforms.ptransform import PTransform
3639
from apache_beam.transforms.ptransform import ptransform_fn
3740
from apache_beam.transforms.util import CoGroupByKey
41+
from apache_beam.utils.windowed_value import PANE_INFO_UNKNOWN
42+
from apache_beam.utils.windowed_value import PaneInfo
3843

3944
__all__ = [
4045
'assert_that',
@@ -56,8 +61,11 @@ class BeamAssertException(Exception):
5661

5762

5863
# Used for reifying timestamps and windows for assert_that matchers.
59-
TestWindowedValue = collections.namedtuple(
60-
'TestWindowedValue', 'value timestamp windows')
64+
class TestWindowedValue(NamedTuple):
65+
value: Any
66+
timestamp: Any
67+
windows: List
68+
pane_info: PaneInfo = PANE_INFO_UNKNOWN
6169

6270

6371
def contains_in_any_order(iterable):
@@ -290,11 +298,15 @@ def assert_that(
290298

291299
class ReifyTimestampWindow(DoFn):
292300
def process(
293-
self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
301+
self,
302+
element,
303+
timestamp=DoFn.TimestampParam,
304+
window=DoFn.WindowParam,
305+
pane_info=DoFn.PaneInfoParam):
294306
# This returns TestWindowedValue instead of
295307
# beam.utils.windowed_value.WindowedValue because ParDo will extract
296308
# the timestamp and window out of the latter.
297-
return [TestWindowedValue(element, timestamp, [window])]
309+
return [TestWindowedValue(element, timestamp, [window], pane_info)]
298310

299311
class AddWindow(DoFn):
300312
def process(self, element, window=DoFn.WindowParam):

sdks/python/apache_beam/transforms/util.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -933,8 +933,6 @@ def is_compat_version_prior_to(options, breaking_change_version):
933933
# keep the old behavior prior to a breaking change or use the new behavior.
934934
# - If update_compatibility_version < breaking_change_version, we will return
935935
# True and keep the old behavior.
936-
# - If update_compatibility_version is None or >= breaking_change_version, we
937-
# will return False and use the behavior from the breaking change.
938936
update_compatibility_version = options.view_as(
939937
pipeline_options.StreamingOptions).update_compatibility_version
940938

@@ -949,6 +947,53 @@ def is_compat_version_prior_to(options, breaking_change_version):
949947
return False
950948

951949

950+
def reify_metadata_default_window(
951+
element, timestamp=DoFn.TimestampParam, pane_info=DoFn.PaneInfoParam):
952+
key, value = element
953+
if timestamp == window.MIN_TIMESTAMP:
954+
timestamp = None
955+
return key, (value, timestamp, pane_info)
956+
957+
958+
def restore_metadata_default_window(element):
959+
key, values = element
960+
return [
961+
window.GlobalWindows.windowed_value(None).with_value((key, value))
962+
if timestamp is None else window.GlobalWindows.windowed_value(
963+
value=(key, value), timestamp=timestamp, pane_info=pane_info)
964+
for (value, timestamp, pane_info) in values
965+
]
966+
967+
968+
def reify_metadata_custom_window(
969+
element,
970+
timestamp=DoFn.TimestampParam,
971+
window=DoFn.WindowParam,
972+
pane_info=DoFn.PaneInfoParam):
973+
key, value = element
974+
return key, windowed_value.WindowedValue(
975+
value, timestamp, [window], pane_info)
976+
977+
978+
def restore_metadata_custom_window(element):
979+
key, windowed_values = element
980+
return [wv.with_value((key, wv.value)) for wv in windowed_values]
981+
982+
983+
def _reify_restore_metadata(is_default_windowing):
984+
if is_default_windowing:
985+
return reify_metadata_default_window, restore_metadata_default_window
986+
return reify_metadata_custom_window, restore_metadata_custom_window
987+
988+
989+
def _add_pre_map_gkb_types(pre_gbk_map, is_default_windowing):
990+
if is_default_windowing:
991+
return pre_gbk_map.with_input_types(tuple[K, V]).with_output_types(
992+
tuple[K, tuple[V, Optional[Timestamp], windowed_value.PaneInfo]])
993+
return pre_gbk_map.with_input_types(tuple[K, V]).with_output_types(
994+
tuple[K, TypedWindowedValue[V]])
995+
996+
952997
@typehints.with_input_types(tuple[K, V])
953998
@typehints.with_output_types(tuple[K, V])
954999
class ReshufflePerKey(PTransform):
@@ -957,7 +1002,7 @@ class ReshufflePerKey(PTransform):
9571002
in particular checkpointing, and preventing fusion of the surrounding
9581003
transforms.
9591004
"""
960-
def expand(self, pcoll):
1005+
def expand_2_64_0(self, pcoll):
9611006
windowing_saved = pcoll.windowing
9621007
if windowing_saved.is_default():
9631008
# In this (common) case we can use a trivial trigger driver
@@ -1023,6 +1068,33 @@ def restore_timestamps(element):
10231068
result._windowing = windowing_saved
10241069
return result
10251070

1071+
def expand(self, pcoll):
1072+
if is_compat_version_prior_to(pcoll.pipeline.options, "2.65.0"):
1073+
return self.expand_2_64_0(pcoll)
1074+
1075+
windowing_saved = pcoll.windowing
1076+
is_default_windowing = windowing_saved.is_default()
1077+
reify_fn, restore_fn = _reify_restore_metadata(is_default_windowing)
1078+
1079+
pre_gbk_map = _add_pre_map_gkb_types(Map(reify_fn), is_default_windowing)
1080+
1081+
ungrouped = pcoll | pre_gbk_map
1082+
1083+
# TODO(https://github.com/apache/beam/issues/19785) Using global window as
1084+
# one of the standard window. This is to mitigate the Dataflow Java Runner
1085+
# Harness limitation to accept only standard coders.
1086+
ungrouped._windowing = Windowing(
1087+
window.GlobalWindows(),
1088+
triggerfn=Always(),
1089+
accumulation_mode=AccumulationMode.DISCARDING,
1090+
timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
1091+
result = (
1092+
ungrouped
1093+
| GroupByKey()
1094+
| FlatMap(restore_fn).with_output_types(Any))
1095+
result._windowing = windowing_saved
1096+
return result
1097+
10261098

10271099
@typehints.with_input_types(T)
10281100
@typehints.with_output_types(T)

0 commit comments

Comments
 (0)