Skip to content

Commit a0831e0

Browse files
authored
Add AftersynchronizedProcessing Time as continuation trigger (#36285)
* Add AftersynchronizedProcessing Time as continuation trigger * fix trailing space * fix trailing space * fix formatting
1 parent 340d420 commit a0831e0

File tree

4 files changed

+196
-4
lines changed

4 files changed

+196
-4
lines changed

sdks/python/apache_beam/transforms/core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3341,6 +3341,18 @@ def infer_output_type(self, input_type):
33413341
return typehints.KV[
33423342
key_type, typehints.WindowedValue[value_type]] # type: ignore[misc]
33433343

3344+
def get_windowing(self, inputs):
3345+
# Switch to the continuation trigger associated with the current trigger.
3346+
windowing = inputs[0].windowing
3347+
triggerfn = windowing.triggerfn.get_continuation_trigger()
3348+
return Windowing(
3349+
windowfn=windowing.windowfn,
3350+
triggerfn=triggerfn,
3351+
accumulation_mode=windowing.accumulation_mode,
3352+
timestamp_combiner=windowing.timestamp_combiner,
3353+
allowed_lateness=windowing.allowed_lateness,
3354+
environment_id=windowing.environment_id)
3355+
33443356
def expand(self, pcoll):
33453357
from apache_beam.transforms.trigger import DataLossReason
33463358
from apache_beam.transforms.trigger import DefaultTrigger

sdks/python/apache_beam/transforms/ptransform_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from apache_beam.metrics import Metrics
4848
from apache_beam.metrics.metric import MetricsFilter
4949
from apache_beam.options.pipeline_options import PipelineOptions
50+
from apache_beam.options.pipeline_options import StandardOptions
5051
from apache_beam.options.pipeline_options import StreamingOptions
5152
from apache_beam.options.pipeline_options import TypeOptions
5253
from apache_beam.portability import common_urns
@@ -61,6 +62,9 @@
6162
from apache_beam.transforms.display import DisplayData
6263
from apache_beam.transforms.display import DisplayDataItem
6364
from apache_beam.transforms.ptransform import PTransform
65+
from apache_beam.transforms.trigger import AccumulationMode
66+
from apache_beam.transforms.trigger import AfterProcessingTime
67+
from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime
6468
from apache_beam.transforms.window import TimestampedValue
6569
from apache_beam.typehints import with_input_types
6670
from apache_beam.typehints import with_output_types
@@ -510,6 +514,21 @@ def test_group_by_key_unbounded_global_default_trigger(self):
510514
with TestPipeline(options=test_options) as pipeline:
511515
pipeline | TestStream() | beam.GroupByKey()
512516

517+
def test_group_by_key_trigger(self):
518+
options = PipelineOptions(['--allow_unsafe_triggers'])
519+
options.view_as(StandardOptions).streaming = True
520+
with TestPipeline(runner='BundleBasedDirectRunner',
521+
options=options) as pipeline:
522+
pcoll = pipeline | 'Start' >> beam.Create([(0, 0)])
523+
triggered = pcoll | 'Trigger' >> beam.WindowInto(
524+
window.GlobalWindows(),
525+
trigger=AfterProcessingTime(1),
526+
accumulation_mode=AccumulationMode.DISCARDING)
527+
output = triggered | 'Gbk' >> beam.GroupByKey()
528+
self.assertTrue(
529+
isinstance(
530+
output.windowing.triggerfn, _AfterSynchronizedProcessingTime))
531+
513532
def test_group_by_key_unsafe_trigger(self):
514533
test_options = PipelineOptions()
515534
test_options.view_as(TypeOptions).allow_unsafe_triggers = False

sdks/python/apache_beam/transforms/trigger.py

Lines changed: 115 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def from_runner_api(proto, context):
304304
'after_each': AfterEach,
305305
'after_end_of_window': AfterWatermark,
306306
'after_processing_time': AfterProcessingTime,
307-
# after_processing_time, after_synchronized_processing_time
307+
'after_synchronized_processing_time': _AfterSynchronizedProcessingTime,
308308
'always': Always,
309309
'default': DefaultTrigger,
310310
'element_count': AfterCount,
@@ -317,6 +317,17 @@ def from_runner_api(proto, context):
317317
def to_runner_api(self, unused_context):
318318
pass
319319

320+
@abstractmethod
321+
def get_continuation_trigger(self):
322+
"""Returns:
323+
Trigger to use after a GroupBy to preserve the intention of this
324+
trigger. Specifically, triggers that are time based and intended
325+
to provide speculative results should continue providing speculative
326+
results. Triggers that fire once (or multiple times) should
327+
continue firing once (or multiple times).
328+
"""
329+
pass
330+
320331

321332
class DefaultTrigger(TriggerFn):
322333
"""Semantically Repeatedly(AfterWatermark()), but more optimized."""
@@ -366,6 +377,9 @@ def to_runner_api(self, unused_context):
366377
def has_ontime_pane(self):
367378
return True
368379

380+
def get_continuation_trigger(self):
381+
return self
382+
369383

370384
class AfterProcessingTime(TriggerFn):
371385
"""Fire exactly once after a specified delay from processing time."""
@@ -421,6 +435,11 @@ def to_runner_api(self, context):
421435
def has_ontime_pane(self):
422436
return False
423437

438+
def get_continuation_trigger(self):
439+
# The continuation of an AfterProcessingTime trigger is an
440+
# _AfterSynchronizedProcessingTime trigger.
441+
return _AfterSynchronizedProcessingTime()
442+
424443

425444
class Always(TriggerFn):
426445
"""Repeatedly invoke the given trigger, never finishing."""
@@ -466,6 +485,9 @@ def to_runner_api(self, context):
466485
return beam_runner_api_pb2.Trigger(
467486
always=beam_runner_api_pb2.Trigger.Always())
468487

488+
def get_continuation_trigger(self):
489+
return self
490+
469491

470492
class _Never(TriggerFn):
471493
"""A trigger that never fires.
@@ -518,6 +540,9 @@ def to_runner_api(self, context):
518540
return beam_runner_api_pb2.Trigger(
519541
never=beam_runner_api_pb2.Trigger.Never())
520542

543+
def get_continuation_trigger(self):
544+
return self
545+
521546

522547
class AfterWatermark(TriggerFn):
523548
"""Fire exactly once when the watermark passes the end of the window.
@@ -531,9 +556,19 @@ class AfterWatermark(TriggerFn):
531556
LATE_TAG = _CombiningValueStateTag('is_late', any)
532557

533558
def __init__(self, early=None, late=None):
534-
# TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly
535-
self.early = Repeatedly(early) if early else None
536-
self.late = Repeatedly(late) if late else None
559+
self.early = self._wrap_if_not_repeatedly(early)
560+
self.late = self._wrap_if_not_repeatedly(late)
561+
562+
@staticmethod
563+
def _wrap_if_not_repeatedly(trigger):
564+
if trigger and not isinstance(trigger, Repeatedly):
565+
return Repeatedly(trigger)
566+
return trigger
567+
568+
def get_continuation_trigger(self):
569+
return AfterWatermark(
570+
self.early.get_continuation_trigger() if self.early else None,
571+
self.late.get_continuation_trigger() if self.late else None)
537572

538573
def __repr__(self):
539574
qualifiers = []
@@ -692,6 +727,9 @@ def to_runner_api(self, unused_context):
692727
def has_ontime_pane(self):
693728
return False
694729

730+
def get_continuation_trigger(self):
731+
return AfterCount(1)
732+
695733

696734
class Repeatedly(TriggerFn):
697735
"""Repeatedly invoke the given trigger, never finishing."""
@@ -741,6 +779,9 @@ def to_runner_api(self, context):
741779
def has_ontime_pane(self):
742780
return self.underlying.has_ontime_pane()
743781

782+
def get_continuation_trigger(self):
783+
return Repeatedly(self.underlying.get_continuation_trigger())
784+
744785

745786
class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta):
746787
def __init__(self, *triggers):
@@ -831,6 +872,12 @@ def to_runner_api(self, context):
831872
def has_ontime_pane(self):
832873
return any(t.has_ontime_pane() for t in self.triggers)
833874

875+
def get_continuation_trigger(self):
876+
return self.__class__(
877+
*(
878+
subtrigger.get_continuation_trigger()
879+
for subtrigger in self.triggers))
880+
834881

835882
class AfterAny(_ParallelTriggerFn):
836883
"""Fires when any subtrigger fires.
@@ -933,6 +980,13 @@ def to_runner_api(self, context):
933980
def has_ontime_pane(self):
934981
return any(t.has_ontime_pane() for t in self.triggers)
935982

983+
def get_continuation_trigger(self):
984+
return Repeatedly(
985+
AfterAny(
986+
*(
987+
subtrigger.get_continuation_trigger()
988+
for subtrigger in self.triggers)))
989+
936990

937991
class OrFinally(AfterAny):
938992
@staticmethod
@@ -1643,3 +1697,60 @@ def __repr__(self):
16431697
state_str = '\n'.join(
16441698
'%s: %s' % (key, dict(state)) for key, state in self.state.items())
16451699
return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)
1700+
1701+
1702+
class _AfterSynchronizedProcessingTime(TriggerFn):
1703+
"""A "runner's-discretion" trigger downstream of a GroupByKey
1704+
with AfterProcessingTime trigger.
1705+
1706+
In runners that directly execute this
1707+
Python code, the trigger currently always fires,
1708+
but this behavior is neither guaranteed nor
1709+
required by runners, regardless of whether they
1710+
execute triggers via Python.
1711+
1712+
_AfterSynchronizedProcessingTime is experimental
1713+
and internal-only. No backwards compatibility
1714+
guarantees.
1715+
"""
1716+
def __init__(self):
1717+
pass
1718+
1719+
def __repr__(self):
1720+
return '_AfterSynchronizedProcessingTime()'
1721+
1722+
def __eq__(self, other):
1723+
return type(self) == type(other)
1724+
1725+
def __hash__(self):
1726+
return hash(type(self))
1727+
1728+
def on_element(self, _element, _window, _context):
1729+
pass
1730+
1731+
def on_merge(self, _to_be_merged, _merge_result, _context):
1732+
pass
1733+
1734+
def should_fire(self, _time_domain, _timestamp, _window, _context):
1735+
return True
1736+
1737+
def on_fire(self, _timestamp, _window, _context):
1738+
return False
1739+
1740+
def reset(self, _window, _context):
1741+
pass
1742+
1743+
@staticmethod
1744+
def from_runner_api(_proto, _context):
1745+
return _AfterSynchronizedProcessingTime()
1746+
1747+
def to_runner_api(self, _context):
1748+
return beam_runner_api_pb2.Trigger(
1749+
after_synchronized_processing_time=beam_runner_api_pb2.Trigger.
1750+
AfterSynchronizedProcessingTime())
1751+
1752+
def has_ontime_pane(self):
1753+
return False
1754+
1755+
def get_continuation_trigger(self):
1756+
return self

sdks/python/apache_beam/transforms/trigger_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,56 @@ def test_trigger_encoding(self):
554554
TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context))
555555

556556

557+
class ContinuationTriggerTest(unittest.TestCase):
558+
def test_after_all(self):
559+
self.assertEqual(
560+
AfterAll(AfterCount(2), AfterCount(5)).get_continuation_trigger(),
561+
AfterAll(AfterCount(1), AfterCount(1)))
562+
563+
def test_after_any(self):
564+
self.assertEqual(
565+
AfterAny(AfterCount(2), AfterCount(5)).get_continuation_trigger(),
566+
AfterAny(AfterCount(1), AfterCount(1)))
567+
568+
def test_after_count(self):
569+
self.assertEqual(AfterCount(1).get_continuation_trigger(), AfterCount(1))
570+
self.assertEqual(AfterCount(100).get_continuation_trigger(), AfterCount(1))
571+
572+
def test_after_each(self):
573+
self.assertEqual(
574+
AfterEach(AfterCount(2), AfterCount(5)).get_continuation_trigger(),
575+
Repeatedly(AfterAny(AfterCount(1), AfterCount(1))))
576+
577+
def test_after_processing_time(self):
578+
from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime
579+
self.assertEqual(
580+
AfterProcessingTime(10).get_continuation_trigger(),
581+
_AfterSynchronizedProcessingTime())
582+
583+
def test_after_watermark(self):
584+
self.assertEqual(
585+
AfterWatermark().get_continuation_trigger(), AfterWatermark())
586+
self.assertEqual(
587+
AfterWatermark(early=AfterCount(10),
588+
late=AfterCount(20)).get_continuation_trigger(),
589+
AfterWatermark(early=AfterCount(1), late=AfterCount(1)))
590+
591+
def test_always(self):
592+
self.assertEqual(Always().get_continuation_trigger(), Always())
593+
594+
def test_default(self):
595+
self.assertEqual(
596+
DefaultTrigger().get_continuation_trigger(), DefaultTrigger())
597+
598+
def test_never(self):
599+
self.assertEqual(_Never().get_continuation_trigger(), _Never())
600+
601+
def test_repeatedly(self):
602+
self.assertEqual(
603+
Repeatedly(AfterCount(10)).get_continuation_trigger(),
604+
Repeatedly(AfterCount(1)))
605+
606+
557607
class TriggerPipelineTest(unittest.TestCase):
558608
def test_after_processing_time(self):
559609
test_options = PipelineOptions(

0 commit comments

Comments
 (0)