Skip to content

Commit 5b8743b

Browse files
[BEAM-36736] Add state sampling for timer processing in the Python SDK (#36737)
* [BEAM-36736] Add state sampling for timer processing * Force CI to rebuild * Fix error with no state found * Fix error for Regex test * Resolve linting error * Add test case to test full functionality * Fix suffix issue * Fix formatting issues using tox -e yapf-check * Add test cases to test code paths * Address comments and remove extra test case * Remove user state context variable * Adjust state duration for test to avoid flakiness * Add different tests, remove no op scoped state, and address formatting/lint issues * Add patch to deal with CI presubmit errors * Adjust test case to not use dofn_runner * Test case failing presubmits, attempting to fix * Fix mocking for tests and ensure all pass * Remove extra test and increase retries on the process timer tests to avoid flakiness * Remove upper bound restriction and reduce retries * Remove unused suffix param. --------- Co-authored-by: tvalentyn <[email protected]>
1 parent d421c98 commit 5b8743b

File tree

3 files changed

+199
-9
lines changed

3 files changed

+199
-9
lines changed

sdks/python/apache_beam/runners/worker/operations.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ cdef class DoOperation(Operation):
117117
cdef dict timer_specs
118118
cdef public object input_info
119119
cdef object fn
120+
cdef object scoped_timer_processing_state
120121

121122

122123
cdef class SdfProcessSizedElements(DoOperation):

sdks/python/apache_beam/runners/worker/operations.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,10 @@ def __init__(
809809
self.tagged_receivers = None # type: Optional[_TaggedReceivers]
810810
# A mapping of timer tags to the input "PCollections" they come in on.
811811
self.input_info = None # type: Optional[OpInputInfo]
812-
812+
self.scoped_timer_processing_state = self.state_sampler.scoped_state(
813+
self.name_context,
814+
'process-timers',
815+
metrics_container=self.metrics_container)
813816
# See fn_data in dataflow_runner.py
814817
# TODO: Store all the items from spec?
815818
self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
@@ -971,14 +974,15 @@ def add_timer_info(self, timer_family_id, timer_info):
971974
self.user_state_context.add_timer_info(timer_family_id, timer_info)
972975

973976
def process_timer(self, tag, timer_data):
974-
timer_spec = self.timer_specs[tag]
975-
self.dofn_runner.process_user_timer(
976-
timer_spec,
977-
timer_data.user_key,
978-
timer_data.windows[0],
979-
timer_data.fire_timestamp,
980-
timer_data.paneinfo,
981-
timer_data.dynamic_timer_tag)
977+
with self.scoped_timer_processing_state:
978+
timer_spec = self.timer_specs[tag]
979+
self.dofn_runner.process_user_timer(
980+
timer_spec,
981+
timer_data.user_key,
982+
timer_data.windows[0],
983+
timer_data.fire_timestamp,
984+
timer_data.paneinfo,
985+
timer_data.dynamic_timer_tag)
982986

983987
def finish(self):
984988
# type: () -> None

sdks/python/apache_beam/runners/worker/statesampler_test.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,56 @@
2121
import logging
2222
import time
2323
import unittest
24+
from unittest import mock
25+
from unittest.mock import Mock
26+
from unittest.mock import patch
2427

2528
from tenacity import retry
2629
from tenacity import stop_after_attempt
2730

31+
from apache_beam.internal import pickler
32+
from apache_beam.runners import common
33+
from apache_beam.runners.worker import operation_specs
34+
from apache_beam.runners.worker import operations
2835
from apache_beam.runners.worker import statesampler
36+
from apache_beam.transforms import core
37+
from apache_beam.transforms import userstate
38+
from apache_beam.transforms.core import GlobalWindows
39+
from apache_beam.transforms.core import Windowing
40+
from apache_beam.transforms.window import GlobalWindow
2941
from apache_beam.utils.counters import CounterFactory
3042
from apache_beam.utils.counters import CounterName
43+
from apache_beam.utils.windowed_value import PaneInfo
3144

3245
_LOGGER = logging.getLogger(__name__)
3346

3447

48+
class TimerDoFn(core.DoFn):
49+
TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
50+
51+
def __init__(self, sleep_duration_s=0):
52+
self._sleep_duration_s = sleep_duration_s
53+
54+
@userstate.on_timer(TIMER_SPEC)
55+
def on_timer_f(self):
56+
if self._sleep_duration_s:
57+
time.sleep(self._sleep_duration_s)
58+
59+
60+
class ExceptionTimerDoFn(core.DoFn):
61+
"""A DoFn that raises an exception when its timer fires."""
62+
TIMER_SPEC = userstate.TimerSpec('ts-timer', userstate.TimeDomain.WATERMARK)
63+
64+
def __init__(self, sleep_duration_s=0):
65+
self._sleep_duration_s = sleep_duration_s
66+
67+
@userstate.on_timer(TIMER_SPEC)
68+
def on_timer_f(self):
69+
if self._sleep_duration_s:
70+
time.sleep(self._sleep_duration_s)
71+
raise RuntimeError("Test exception from timer")
72+
73+
3574
class StateSamplerTest(unittest.TestCase):
3675

3776
# Due to somewhat non-deterministic nature of state sampling and sleep,
@@ -127,6 +166,152 @@ def test_sampler_transition_overhead(self):
127166
# debug mode).
128167
self.assertLess(overhead_us, 20.0)
129168

169+
@retry(reraise=True, stop=stop_after_attempt(3))
170+
# Patch the problematic function to return the correct timer spec
171+
@patch('apache_beam.transforms.userstate.get_dofn_specs')
172+
def test_do_operation_process_timer(self, mock_get_dofn_specs):
173+
fn = TimerDoFn()
174+
mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC])
175+
176+
if not statesampler.FAST_SAMPLER:
177+
self.skipTest('DoOperation test requires FAST_SAMPLER')
178+
179+
state_duration_ms = 200
180+
margin_of_error = 0.75
181+
182+
counter_factory = CounterFactory()
183+
sampler = statesampler.StateSampler(
184+
'test_do_op', counter_factory, sampling_period_ms=1)
185+
186+
fn_for_spec = TimerDoFn(sleep_duration_s=state_duration_ms / 1000.0)
187+
188+
spec = operation_specs.WorkerDoFn(
189+
serialized_fn=pickler.dumps(
190+
(fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
191+
output_tags=[],
192+
input=None,
193+
side_inputs=[],
194+
output_coders=[])
195+
196+
mock_user_state_context = mock.MagicMock()
197+
op = operations.DoOperation(
198+
common.NameContext('step1'),
199+
spec,
200+
counter_factory,
201+
sampler,
202+
user_state_context=mock_user_state_context)
203+
204+
op.setup()
205+
206+
timer_data = Mock()
207+
timer_data.user_key = None
208+
timer_data.windows = [GlobalWindow()]
209+
timer_data.fire_timestamp = 0
210+
timer_data.paneinfo = PaneInfo(
211+
is_first=False,
212+
is_last=False,
213+
timing=0,
214+
index=0,
215+
nonspeculative_index=0)
216+
timer_data.dynamic_timer_tag = ''
217+
218+
sampler.start()
219+
op.process_timer('ts-timer', timer_data=timer_data)
220+
sampler.stop()
221+
sampler.commit_counters()
222+
223+
expected_name = CounterName(
224+
'process-timers-msecs', step_name='step1', stage_name='test_do_op')
225+
226+
found_counter = None
227+
for counter in counter_factory.get_counters():
228+
if counter.name == expected_name:
229+
found_counter = counter
230+
break
231+
232+
self.assertIsNotNone(
233+
found_counter, f"Expected counter '{expected_name}' to be created.")
234+
235+
actual_value = found_counter.value()
236+
logging.info("Actual value %d", actual_value)
237+
self.assertGreater(
238+
actual_value, state_duration_ms * (1.0 - margin_of_error))
239+
240+
@retry(reraise=True, stop=stop_after_attempt(3))
241+
@patch('apache_beam.runners.worker.operations.userstate.get_dofn_specs')
242+
def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs):
243+
fn = ExceptionTimerDoFn()
244+
mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC])
245+
246+
if not statesampler.FAST_SAMPLER:
247+
self.skipTest('DoOperation test requires FAST_SAMPLER')
248+
249+
state_duration_ms = 200
250+
margin_of_error = 0.50
251+
252+
counter_factory = CounterFactory()
253+
sampler = statesampler.StateSampler(
254+
'test_do_op_exception', counter_factory, sampling_period_ms=1)
255+
256+
fn_for_spec = ExceptionTimerDoFn(
257+
sleep_duration_s=state_duration_ms / 1000.0)
258+
259+
spec = operation_specs.WorkerDoFn(
260+
serialized_fn=pickler.dumps(
261+
(fn_for_spec, [], {}, [], Windowing(GlobalWindows()))),
262+
output_tags=[],
263+
input=None,
264+
side_inputs=[],
265+
output_coders=[])
266+
267+
mock_user_state_context = mock.MagicMock()
268+
op = operations.DoOperation(
269+
common.NameContext('step1'),
270+
spec,
271+
counter_factory,
272+
sampler,
273+
user_state_context=mock_user_state_context)
274+
275+
op.setup()
276+
277+
timer_data = Mock()
278+
timer_data.user_key = None
279+
timer_data.windows = [GlobalWindow()]
280+
timer_data.fire_timestamp = 0
281+
timer_data.paneinfo = PaneInfo(
282+
is_first=False,
283+
is_last=False,
284+
timing=0,
285+
index=0,
286+
nonspeculative_index=0)
287+
timer_data.dynamic_timer_tag = ''
288+
289+
sampler.start()
290+
# Assert that the expected exception is raised
291+
with self.assertRaises(RuntimeError):
292+
op.process_timer('ts-ts-timer', timer_data=timer_data)
293+
sampler.stop()
294+
sampler.commit_counters()
295+
296+
expected_name = CounterName(
297+
'process-timers-msecs',
298+
step_name='step1',
299+
stage_name='test_do_op_exception')
300+
301+
found_counter = None
302+
for counter in counter_factory.get_counters():
303+
if counter.name == expected_name:
304+
found_counter = counter
305+
break
306+
307+
self.assertIsNotNone(
308+
found_counter, f"Expected counter '{expected_name}' to be created.")
309+
310+
actual_value = found_counter.value()
311+
self.assertGreater(
312+
actual_value, state_duration_ms * (1.0 - margin_of_error))
313+
_LOGGER.info("Exception test finished successfully.")
314+
130315

131316
if __name__ == '__main__':
132317
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)