|
21 | 21 | import logging |
22 | 22 | import time |
23 | 23 | import unittest |
| 24 | +from unittest.mock import Mock |
24 | 25 |
|
25 | 26 | from tenacity import retry |
26 | 27 | from tenacity import stop_after_attempt |
27 | 28 |
|
28 | 29 | from apache_beam.runners.worker import statesampler |
29 | 30 | from apache_beam.utils.counters import CounterFactory |
30 | 31 | from apache_beam.utils.counters import CounterName |
| 32 | +from apache_beam.runners.worker import operation_specs |
| 33 | +from apache_beam.runners.worker import operations |
| 34 | +from apache_beam.internal import pickler |
| 35 | +from apache_beam.transforms import core |
31 | 36 |
|
32 | 37 | _LOGGER = logging.getLogger(__name__) |
33 | 38 |
|
@@ -213,6 +218,59 @@ def test_process_timers_metric_is_recorded(self): |
213 | 218 | expected_value * (1.0 + margin_of_error), |
214 | 219 | "The timer metric was higher than expected.") |
215 | 220 |
|
| 221 | + def test_do_operation_with_sampler(self): |
| 222 | + """ |
| 223 | + Tests that a DoOperation with an active state_sampler correctly |
| 224 | + creates a real ScopedState object for timer processing. |
| 225 | + """ |
| 226 | + mock_spec = operation_specs.WorkerDoFn( |
| 227 | + serialized_fn=pickler.dumps((core.DoFn(), None, None, None, None)), |
| 228 | + output_tags=[], |
| 229 | + input=None, |
| 230 | + side_inputs=[], |
| 231 | + output_coders=[]) |
| 232 | + |
| 233 | + sampler = statesampler.StateSampler( |
| 234 | + 'test_stage', CounterFactory(), sampling_period_ms=1) |
| 235 | + |
| 236 | + # 1. Create the operation WITHOUT the unexpected keyword argument |
| 237 | + op = operations.create_operation( |
| 238 | + name_context='test_op', |
| 239 | + spec=mock_spec, |
| 240 | + counter_factory=CounterFactory(), |
| 241 | + state_sampler=sampler) |
| 242 | + |
| 243 | + # 2. Set the user_state_context attribute AFTER creation |
| 244 | + op.user_state_context = Mock() |
| 245 | + |
| 246 | + self.assertIsNot( |
| 247 | + op.scoped_timer_processing_state, statesampler.NOOP_SCOPED_STATE) |
| 248 | + |
| 249 | + def test_do_operation_without_sampler(self): |
| 250 | + """ |
| 251 | + Tests that a DoOperation without a state_sampler correctly uses the |
| 252 | + NOOP_SCOPED_STATE for timer processing. |
| 253 | + """ |
| 254 | + mock_spec = operation_specs.WorkerDoFn( |
| 255 | + serialized_fn=pickler.dumps((core.DoFn(), None, None, None, None)), |
| 256 | + output_tags=[], |
| 257 | + input=None, |
| 258 | + side_inputs=[], |
| 259 | + output_coders=[]) |
| 260 | + |
| 261 | + # 1. Create the operation WITHOUT the unexpected keyword argument |
| 262 | + op = operations.create_operation( |
| 263 | + name_context='test_op', |
| 264 | + spec=mock_spec, |
| 265 | + counter_factory=CounterFactory(), |
| 266 | + state_sampler=None) |
| 267 | + |
| 268 | + # 2. Set the user_state_context attribute AFTER creation |
| 269 | + op.user_state_context = Mock() |
| 270 | + |
| 271 | + self.assertIs( |
| 272 | + op.scoped_timer_processing_state, statesampler.NOOP_SCOPED_STATE) |
| 273 | + |
216 | 274 |
|
217 | 275 | if __name__ == '__main__': |
218 | 276 | logging.getLogger().setLevel(logging.INFO) |
|
0 commit comments