Skip to content

Commit 5a65f6e

Browse files
authored
Exposed source window to AssignContext in window_mapping_fn for side inputs (#36722)
* added source window to assign context window mapping fn
* yapf
* added default window mapping test
* side input test
* added window and windowing fn required methods
* added window and windowing fn required methods
* updated typehint for assign in sideinputs_test.py
* added none assert
* removed unused mock import in sideinputs_test.py
* removed extraneous pdb file
* removed 1igm.pdb
* fixed import order
1 parent d2aed60 commit 5a65f6e

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

sdks/python/apache_beam/transforms/sideinputs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def default_window_mapping_fn(
6060
def map_via_end(source_window: window.BoundedWindow) -> window.BoundedWindow:
  """Maps ``source_window`` to a window produced by ``target_window_fn``.

  An AssignContext is built from the source window's maximum timestamp and
  also carries the source window itself, so that ``target_window_fn.assign``
  can inspect it. The last window yielded by ``assign`` is returned.
  """
  context = window.WindowFn.AssignContext(
      source_window.max_timestamp(), window=source_window)
  assigned_windows = list(target_window_fn.assign(context))
  return assigned_windows[-1]
6465

6566
return map_via_end
6667

sdks/python/apache_beam/transforms/sideinputs_test.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from apache_beam.testing.util import equal_to
4040
from apache_beam.testing.util import equal_to_per_window
4141
from apache_beam.transforms import Map
42+
from apache_beam.transforms import sideinputs
4243
from apache_beam.transforms import trigger
4344
from apache_beam.transforms import window
4445
from apache_beam.utils.timestamp import Timestamp
@@ -489,6 +490,40 @@ def process(
489490
assert_that(results, equal_to([(num_records, expected_fingerprint)]))
490491
pipeline.run()
491492

493+
def test_default_window_mapping_fn_source_window(self):
  """Checks that the default window mapping function passes the source
  window to the target WindowFn through the AssignContext.
  """
  class StringIDWindow(window.BoundedWindow):
    """A window identified by an arbitrary string ID."""
    def __init__(self, window_id: str):
      # Every instance uses the same fixed end timestamp; only the ID
      # distinguishes one window from another.
      super().__init__(Timestamp(micros=0))
      self.id = window_id

  class StringIDWindows(window.NonMergingWindowFn):
    """Windowing function that puts each element in a window named after it."""
    def assign(
        self, assign_context: window.WindowFn.AssignContext
    ) -> Iterable[window.BoundedWindow]:
      if assign_context.element is not None:
        return [StringIDWindow(str(assign_context.element))]
      # No element in the context: fall back to the window it carries,
      # which must have been supplied by the caller.
      assert assign_context.window is not None
      return [assign_context.window]

    def get_window_coder(self):
      return None

  map_window = sideinputs.default_window_mapping_fn(StringIDWindows())
  element_context = window.WindowFn.AssignContext(
      Timestamp(10), element='element')
  source_window = StringIDWindows().assign(element_context)[0]
  mapped_window = map_window(source_window)
  assert mapped_window is not None
  assert mapped_window.id == 'element'
526+
492527

493528
if __name__ == '__main__':
494529
logging.getLogger().setLevel(logging.DEBUG)

0 commit comments

Comments
 (0)