|
26 | 26 | from tenacity import retry |
27 | 27 | from tenacity import stop_after_attempt |
28 | 28 |
|
29 | | -from apache_beam.runners.worker import statesampler |
30 | | -from apache_beam.utils.counters import CounterFactory |
31 | | -from apache_beam.utils.counters import CounterName |
| 29 | +from apache_beam.internal import pickler |
| 30 | +from apache_beam.runners import common |
32 | 31 | from apache_beam.runners.worker import operation_specs |
33 | 32 | from apache_beam.runners.worker import operations |
34 | | -from apache_beam.internal import pickler |
| 33 | +from apache_beam.runners.worker import statesampler |
35 | 34 | from apache_beam.transforms import core |
| 35 | +from apache_beam.utils.counters import CounterFactory |
| 36 | +from apache_beam.utils.counters import CounterName |
36 | 37 |
|
37 | 38 | _LOGGER = logging.getLogger(__name__) |
38 | 39 |
|
@@ -143,7 +144,7 @@ def test_process_timers_metric_is_recorded(self): |
143 | 144 | sampler = statesampler.StateSampler( |
144 | 145 | 'test_stage', counter_factory, sampling_period_ms=1) |
145 | 146 |
|
146 | | - # Keeps range between 50-350 ms, which is fair. |
| 147 | + # Keeps range between 50-350 ms, which is fair. |
147 | 148 | state_duration_ms = 200 |
148 | 149 | margin_of_error = 0.75 |
149 | 150 |
|
@@ -184,52 +185,149 @@ def test_process_timers_metric_is_recorded(self): |
184 | 185 | expected_value * (1.0 + margin_of_error), |
185 | 186 | "The timer metric was higher than expected.") |
186 | 187 |
|
187 | | - def test_do_operation_with_sampler(self): |
| 188 | + def test_do_operation_process_timer_metric(self): |
| 189 | + """ |
| 190 | + Tests that the 'process-timers-msecs' metric is correctly recorded |
| 191 | + when a timer is processed within a DoOperation. |
188 | 192 | """ |
189 | | - Tests that a DoOperation with an active state_sampler correctly |
190 | | - creates a real ScopedState object for timer processing. |
191 | | - """ |
| 193 | + counter_factory = CounterFactory() |
| 194 | + sampler = statesampler.StateSampler( |
| 195 | + 'test_stage', counter_factory, sampling_period_ms=1) |
| 196 | + |
192 | 197 | mock_spec = operation_specs.WorkerDoFn( |
193 | 198 | serialized_fn=pickler.dumps((core.DoFn(), None, None, None, None)), |
194 | 199 | output_tags=[], |
195 | 200 | input=None, |
196 | 201 | side_inputs=[], |
197 | 202 | output_coders=[]) |
198 | 203 |
|
199 | | - sampler = statesampler.StateSampler( |
200 | | - 'test_stage', CounterFactory(), sampling_period_ms=1) |
201 | | - |
202 | | - # 1. Create the operation WITHOUT the unexpected keyword argument |
203 | | - op = operations.create_operation( |
204 | | - name_context='test_op', |
| 204 | + op = operations.DoOperation( |
| 205 | + name=common.NameContext('test_op'), |
205 | 206 | spec=mock_spec, |
206 | | - counter_factory=CounterFactory(), |
207 | | - state_sampler=sampler) |
| 207 | + counter_factory=counter_factory, |
| 208 | + sampler=sampler) |
| 209 | + |
| 210 | + op.dofn_runner = Mock() |
| 211 | + op.timer_specs = {'timer_id': Mock()} |
| 212 | + state_duration_ms = 200 |
| 213 | + margin_of_error = 0.75 |
| 214 | + |
| 215 | + def mock_process_user_timer(*args, **kwargs): |
| 216 | + time.sleep(state_duration_ms / 1000.0) |
| 217 | + |
| 218 | + op.dofn_runner.process_user_timer = mock_process_user_timer |
| 219 | + |
| 220 | + mock_timer_data = Mock() |
| 221 | + mock_timer_data.windows = [Mock()] |
| 222 | + mock_timer_data.user_key = Mock() |
| 223 | + mock_timer_data.fire_timestamp = Mock() |
| 224 | + mock_timer_data.paneinfo = Mock() |
| 225 | + mock_timer_data.dynamic_timer_tag = Mock() |
208 | 226 |
|
209 | | - self.assertIsNot( |
210 | | - op.scoped_timer_processing_state, statesampler.NOOP_SCOPED_STATE) |
| 227 | + sampler.start() |
| 228 | + op.process_timer('timer_id', mock_timer_data) |
| 229 | + sampler.stop() |
| 230 | + sampler.commit_counters() |
| 231 | + |
| 232 | + if not statesampler.FAST_SAMPLER: |
| 233 | + return |
| 234 | + |
| 235 | + expected_counter_name = CounterName( |
| 236 | + 'process-timers-msecs', step_name='test_op', stage_name='test_stage') |
| 237 | + |
| 238 | + found_counter = None |
| 239 | + for counter in counter_factory.get_counters(): |
| 240 | + if counter.name == expected_counter_name: |
| 241 | + found_counter = counter |
| 242 | + break |
| 243 | + |
| 244 | + self.assertIsNotNone( |
| 245 | + found_counter, |
| 246 | + f"The expected counter '{expected_counter_name}' was not created.") |
| 247 | + |
| 248 | + actual_value = found_counter.value() |
| 249 | + expected_value = state_duration_ms |
| 250 | + self.assertGreater( |
| 251 | + actual_value, |
| 252 | + expected_value * (1.0 - margin_of_error), |
| 253 | + "The timer metric was lower than expected.") |
| 254 | + self.assertLess( |
| 255 | + actual_value, |
| 256 | + expected_value * (1.0 + margin_of_error), |
| 257 | + "The timer metric was higher than expected.") |
211 | 258 |
|
212 | | - def test_do_operation_without_sampler(self): |
| 259 | + def test_do_operation_process_timer_metric_with_exception(self): |
213 | 260 | """ |
214 | | - Tests that a DoOperation without a state_sampler correctly uses the |
215 | | - NOOP_SCOPED_STATE for timer processing. |
| 261 | + Tests that the 'process-timers-msecs' metric is still recorded |
| 262 | + when a timer callback in a DoOperation raises an exception. |
216 | 263 | """ |
| 264 | + counter_factory = CounterFactory() |
| 265 | + sampler = statesampler.StateSampler( |
| 266 | + 'test_stage', counter_factory, sampling_period_ms=1) |
| 267 | + |
217 | 268 | mock_spec = operation_specs.WorkerDoFn( |
218 | 269 | serialized_fn=pickler.dumps((core.DoFn(), None, None, None, None)), |
219 | 270 | output_tags=[], |
220 | 271 | input=None, |
221 | 272 | side_inputs=[], |
222 | 273 | output_coders=[]) |
223 | 274 |
|
224 | | - # 1. Create the operation WITHOUT the unexpected keyword argument |
225 | | - op = operations.create_operation( |
226 | | - name_context='test_op', |
| 275 | + op = operations.DoOperation( |
| 276 | + name=common.NameContext('test_op'), |
227 | 277 | spec=mock_spec, |
228 | | - counter_factory=CounterFactory(), |
229 | | - state_sampler=None) |
| 278 | + counter_factory=counter_factory, |
| 279 | + sampler=sampler) |
| 280 | + |
| 281 | + op.dofn_runner = Mock() |
| 282 | + op.timer_specs = {'timer_id': Mock()} |
| 283 | + state_duration_ms = 200 |
| 284 | + margin_of_error = 0.75 |
| 285 | + |
| 286 | + def mock_process_user_timer(*args, **kwargs): |
| 287 | + time.sleep(state_duration_ms / 1000.0) |
| 288 | + raise ValueError("Test Exception") |
| 289 | + |
| 290 | + op.dofn_runner.process_user_timer = mock_process_user_timer |
| 291 | + |
| 292 | + mock_timer_data = Mock() |
| 293 | + mock_timer_data.windows = [Mock()] |
| 294 | + mock_timer_data.user_key = Mock() |
| 295 | + mock_timer_data.fire_timestamp = Mock() |
| 296 | + mock_timer_data.paneinfo = Mock() |
| 297 | + mock_timer_data.dynamic_timer_tag = Mock() |
| 298 | + |
| 299 | + sampler.start() |
| 300 | + with self.assertRaises(ValueError): |
| 301 | + op.process_timer('timer_id', mock_timer_data) |
| 302 | + sampler.stop() |
| 303 | + sampler.commit_counters() |
| 304 | + |
| 305 | + if not statesampler.FAST_SAMPLER: |
| 306 | + return |
| 307 | + |
| 308 | + expected_counter_name = CounterName( |
| 309 | + 'process-timers-msecs', step_name='test_op', stage_name='test_stage') |
| 310 | + |
| 311 | + found_counter = None |
| 312 | + for counter in counter_factory.get_counters(): |
| 313 | + if counter.name == expected_counter_name: |
| 314 | + found_counter = counter |
| 315 | + break |
230 | 316 |
|
231 | | - self.assertIs( |
232 | | - op.scoped_timer_processing_state, statesampler.NOOP_SCOPED_STATE) |
| 317 | + self.assertIsNotNone( |
| 318 | + found_counter, |
| 319 | + f"The expected counter '{expected_counter_name}' was not created.") |
| 320 | + |
| 321 | + actual_value = found_counter.value() |
| 322 | + expected_value = state_duration_ms |
| 323 | + self.assertGreater( |
| 324 | + actual_value, |
| 325 | + expected_value * (1.0 - margin_of_error), |
| 326 | + "The timer metric was lower than expected.") |
| 327 | + self.assertLess( |
| 328 | + actual_value, |
| 329 | + expected_value * (1.0 + margin_of_error), |
| 330 | + "The timer metric was higher than expected.") |
233 | 331 |
|
234 | 332 |
|
235 | 333 | if __name__ == '__main__': |
|
0 commit comments