|
21 | 21 | import logging |
22 | 22 | import time |
23 | 23 | import unittest |
| 24 | +from unittest import mock |
| 25 | +from unittest.mock import Mock |
| 26 | +from unittest.mock import patch |
24 | 27 |
|
25 | 28 | from tenacity import retry |
26 | 29 | from tenacity import stop_after_attempt |
27 | 30 |
|
| 31 | +from apache_beam.internal import pickler |
| 32 | +from apache_beam.runners import common |
| 33 | +from apache_beam.runners.worker import operation_specs |
| 34 | +from apache_beam.runners.worker import operations |
28 | 35 | from apache_beam.runners.worker import statesampler |
| 36 | +from apache_beam.transforms import core |
| 37 | +from apache_beam.transforms import userstate |
| 38 | +from apache_beam.transforms.core import GlobalWindows |
| 39 | +from apache_beam.transforms.core import Windowing |
| 40 | +from apache_beam.transforms.window import GlobalWindow |
29 | 41 | from apache_beam.utils.counters import CounterFactory |
30 | 42 | from apache_beam.utils.counters import CounterName |
| 43 | +from apache_beam.utils.windowed_value import PaneInfo |
31 | 44 |
|
32 | 45 | _LOGGER = logging.getLogger(__name__) |
33 | 46 |
|
34 | 47 |
|
class TimerDoFn(core.DoFn):
  """A DoFn whose watermark timer callback optionally sleeps.

  Args:
    sleep_duration_s: Seconds to block inside the timer callback. A falsy
      value (the default, 0) makes the callback return immediately.
  """
  TIMER_SPEC = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)

  def __init__(self, sleep_duration_s=0):
    self._sleep_duration_s = sleep_duration_s

  @userstate.on_timer(TIMER_SPEC)
  def on_timer_f(self):
    """Timer callback for TIMER_SPEC; sleeps for the configured duration."""
    pause_s = self._sleep_duration_s
    if not pause_s:
      # Nothing to simulate: return without sleeping.
      return
    time.sleep(pause_s)
| 58 | + |
| 59 | + |
class ExceptionTimerDoFn(core.DoFn):
  """A DoFn whose timer callback always fails with a RuntimeError.

  Args:
    sleep_duration_s: Seconds to block inside the timer callback before the
      exception is raised. A falsy value (the default, 0) skips the sleep.
  """
  TIMER_SPEC = userstate.TimerSpec('ts-timer', userstate.TimeDomain.WATERMARK)

  def __init__(self, sleep_duration_s=0):
    self._sleep_duration_s = sleep_duration_s

  @userstate.on_timer(TIMER_SPEC)
  def on_timer_f(self):
    """Timer callback for TIMER_SPEC; optionally sleeps, then raises.

    Raises:
      RuntimeError: Always, after the optional sleep has elapsed.
    """
    delay_s = self._sleep_duration_s
    if delay_s:
      time.sleep(delay_s)
    raise RuntimeError("Test exception from timer")
| 72 | + |
| 73 | + |
35 | 74 | class StateSamplerTest(unittest.TestCase): |
36 | 75 |
|
37 | 76 | # Due to somewhat non-deterministic nature of state sampling and sleep, |
@@ -127,6 +166,152 @@ def test_sampler_transition_overhead(self): |
127 | 166 | # debug mode). |
128 | 167 | self.assertLess(overhead_us, 20.0) |
129 | 168 |
|
| 169 | + @retry(reraise=True, stop=stop_after_attempt(3)) |
| 170 | + # Patch the problematic function to return the correct timer spec |
| 171 | + @patch('apache_beam.transforms.userstate.get_dofn_specs') |
| 172 | + def test_do_operation_process_timer(self, mock_get_dofn_specs): |
| 173 | + fn = TimerDoFn() |
| 174 | + mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC]) |
| 175 | + |
| 176 | + if not statesampler.FAST_SAMPLER: |
| 177 | + self.skipTest('DoOperation test requires FAST_SAMPLER') |
| 178 | + |
| 179 | + state_duration_ms = 200 |
| 180 | + margin_of_error = 0.75 |
| 181 | + |
| 182 | + counter_factory = CounterFactory() |
| 183 | + sampler = statesampler.StateSampler( |
| 184 | + 'test_do_op', counter_factory, sampling_period_ms=1) |
| 185 | + |
| 186 | + fn_for_spec = TimerDoFn(sleep_duration_s=state_duration_ms / 1000.0) |
| 187 | + |
| 188 | + spec = operation_specs.WorkerDoFn( |
| 189 | + serialized_fn=pickler.dumps( |
| 190 | + (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))), |
| 191 | + output_tags=[], |
| 192 | + input=None, |
| 193 | + side_inputs=[], |
| 194 | + output_coders=[]) |
| 195 | + |
| 196 | + mock_user_state_context = mock.MagicMock() |
| 197 | + op = operations.DoOperation( |
| 198 | + common.NameContext('step1'), |
| 199 | + spec, |
| 200 | + counter_factory, |
| 201 | + sampler, |
| 202 | + user_state_context=mock_user_state_context) |
| 203 | + |
| 204 | + op.setup() |
| 205 | + |
| 206 | + timer_data = Mock() |
| 207 | + timer_data.user_key = None |
| 208 | + timer_data.windows = [GlobalWindow()] |
| 209 | + timer_data.fire_timestamp = 0 |
| 210 | + timer_data.paneinfo = PaneInfo( |
| 211 | + is_first=False, |
| 212 | + is_last=False, |
| 213 | + timing=0, |
| 214 | + index=0, |
| 215 | + nonspeculative_index=0) |
| 216 | + timer_data.dynamic_timer_tag = '' |
| 217 | + |
| 218 | + sampler.start() |
| 219 | + op.process_timer('ts-timer', timer_data=timer_data) |
| 220 | + sampler.stop() |
| 221 | + sampler.commit_counters() |
| 222 | + |
| 223 | + expected_name = CounterName( |
| 224 | + 'process-timers-msecs', step_name='step1', stage_name='test_do_op') |
| 225 | + |
| 226 | + found_counter = None |
| 227 | + for counter in counter_factory.get_counters(): |
| 228 | + if counter.name == expected_name: |
| 229 | + found_counter = counter |
| 230 | + break |
| 231 | + |
| 232 | + self.assertIsNotNone( |
| 233 | + found_counter, f"Expected counter '{expected_name}' to be created.") |
| 234 | + |
| 235 | + actual_value = found_counter.value() |
| 236 | + logging.info("Actual value %d", actual_value) |
| 237 | + self.assertGreater( |
| 238 | + actual_value, state_duration_ms * (1.0 - margin_of_error)) |
| 239 | + |
| 240 | + @retry(reraise=True, stop=stop_after_attempt(3)) |
| 241 | + @patch('apache_beam.runners.worker.operations.userstate.get_dofn_specs') |
| 242 | + def test_do_operation_process_timer_with_exception(self, mock_get_dofn_specs): |
| 243 | + fn = ExceptionTimerDoFn() |
| 244 | + mock_get_dofn_specs.return_value = ([], [fn.TIMER_SPEC]) |
| 245 | + |
| 246 | + if not statesampler.FAST_SAMPLER: |
| 247 | + self.skipTest('DoOperation test requires FAST_SAMPLER') |
| 248 | + |
| 249 | + state_duration_ms = 200 |
| 250 | + margin_of_error = 0.50 |
| 251 | + |
| 252 | + counter_factory = CounterFactory() |
| 253 | + sampler = statesampler.StateSampler( |
| 254 | + 'test_do_op_exception', counter_factory, sampling_period_ms=1) |
| 255 | + |
| 256 | + fn_for_spec = ExceptionTimerDoFn( |
| 257 | + sleep_duration_s=state_duration_ms / 1000.0) |
| 258 | + |
| 259 | + spec = operation_specs.WorkerDoFn( |
| 260 | + serialized_fn=pickler.dumps( |
| 261 | + (fn_for_spec, [], {}, [], Windowing(GlobalWindows()))), |
| 262 | + output_tags=[], |
| 263 | + input=None, |
| 264 | + side_inputs=[], |
| 265 | + output_coders=[]) |
| 266 | + |
| 267 | + mock_user_state_context = mock.MagicMock() |
| 268 | + op = operations.DoOperation( |
| 269 | + common.NameContext('step1'), |
| 270 | + spec, |
| 271 | + counter_factory, |
| 272 | + sampler, |
| 273 | + user_state_context=mock_user_state_context) |
| 274 | + |
| 275 | + op.setup() |
| 276 | + |
| 277 | + timer_data = Mock() |
| 278 | + timer_data.user_key = None |
| 279 | + timer_data.windows = [GlobalWindow()] |
| 280 | + timer_data.fire_timestamp = 0 |
| 281 | + timer_data.paneinfo = PaneInfo( |
| 282 | + is_first=False, |
| 283 | + is_last=False, |
| 284 | + timing=0, |
| 285 | + index=0, |
| 286 | + nonspeculative_index=0) |
| 287 | + timer_data.dynamic_timer_tag = '' |
| 288 | + |
| 289 | + sampler.start() |
| 290 | + # Assert that the expected exception is raised |
| 291 | + with self.assertRaises(RuntimeError): |
| 292 | + op.process_timer('ts-ts-timer', timer_data=timer_data) |
| 293 | + sampler.stop() |
| 294 | + sampler.commit_counters() |
| 295 | + |
| 296 | + expected_name = CounterName( |
| 297 | + 'process-timers-msecs', |
| 298 | + step_name='step1', |
| 299 | + stage_name='test_do_op_exception') |
| 300 | + |
| 301 | + found_counter = None |
| 302 | + for counter in counter_factory.get_counters(): |
| 303 | + if counter.name == expected_name: |
| 304 | + found_counter = counter |
| 305 | + break |
| 306 | + |
| 307 | + self.assertIsNotNone( |
| 308 | + found_counter, f"Expected counter '{expected_name}' to be created.") |
| 309 | + |
| 310 | + actual_value = found_counter.value() |
| 311 | + self.assertGreater( |
| 312 | + actual_value, state_duration_ms * (1.0 - margin_of_error)) |
| 313 | + _LOGGER.info("Exception test finished successfully.") |
| 314 | + |
130 | 315 |
|
131 | 316 | if __name__ == '__main__': |
132 | 317 | logging.getLogger().setLevel(logging.INFO) |
|
0 commit comments