Skip to content

Commit d41f76b

Browse files
committed
Add different tests, remove no-op scoped state, and address formatting/lint issues
1 parent bf55c2d commit d41f76b

File tree

3 files changed

+138
-70
lines changed

3 files changed

+138
-70
lines changed

sdks/python/apache_beam/runners/worker/operations.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
from apache_beam.runners.worker import opcounters
5050
from apache_beam.runners.worker import operation_specs
5151
from apache_beam.runners.worker import sideinputs
52-
from apache_beam.runners.worker import statesampler
5352
from apache_beam.runners.worker.data_sampler import DataSampler
5453
from apache_beam.transforms import sideinputs as apache_sideinputs
5554
from apache_beam.transforms import combiners
@@ -445,19 +444,12 @@ def __init__(
445444
self.metrics_container = MetricsContainer(self.name_context.metrics_name())
446445

447446
self.state_sampler = state_sampler
448-
if self.state_sampler:
449-
self.scoped_start_state = self.state_sampler.scoped_state(
450-
self.name_context, 'start', metrics_container=self.metrics_container)
451-
self.scoped_process_state = self.state_sampler.scoped_state(
452-
self.name_context,
453-
'process',
454-
metrics_container=self.metrics_container)
455-
self.scoped_finish_state = self.state_sampler.scoped_state(
456-
self.name_context, 'finish', metrics_container=self.metrics_container)
457-
else:
458-
self.scoped_start_state = statesampler.NOOP_SCOPED_STATE
459-
self.scoped_process_state = statesampler.NOOP_SCOPED_STATE
460-
self.scoped_finish_state = statesampler.NOOP_SCOPED_STATE
447+
self.scoped_start_state = self.state_sampler.scoped_state(
448+
self.name_context, 'start', metrics_container=self.metrics_container)
449+
self.scoped_process_state = self.state_sampler.scoped_state(
450+
self.name_context, 'process', metrics_container=self.metrics_container)
451+
self.scoped_finish_state = self.state_sampler.scoped_state(
452+
self.name_context, 'finish', metrics_container=self.metrics_container)
461453
# TODO(ccy): the '-abort' state can be added when the abort is supported in
462454
# Operations.
463455
self.receivers = [] # type: List[ConsumerSet]
@@ -817,12 +809,10 @@ def __init__(
817809
self.tagged_receivers = None # type: Optional[_TaggedReceivers]
818810
# A mapping of timer tags to the input "PCollections" they come in on.
819811
self.input_info = None # type: Optional[OpInputInfo]
820-
self.scoped_timer_processing_state = statesampler.NOOP_SCOPED_STATE
821-
if self.state_sampler:
822-
self.scoped_timer_processing_state = self.state_sampler.scoped_state(
823-
self.name_context,
824-
'process-timers',
825-
metrics_container=self.metrics_container)
812+
self.scoped_timer_processing_state = self.state_sampler.scoped_state(
813+
self.name_context,
814+
'process-timers',
815+
metrics_container=self.metrics_container)
826816
# See fn_data in dataflow_runner.py
827817
# TODO: Store all the items from spec?
828818
self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
@@ -984,7 +974,7 @@ def add_timer_info(self, timer_family_id, timer_info):
984974
self.user_state_context.add_timer_info(timer_family_id, timer_info)
985975

986976
def process_timer(self, tag, timer_data):
987-
def process_timer_logic():
977+
with self.scoped_timer_processing_state:
988978
timer_spec = self.timer_specs[tag]
989979
self.dofn_runner.process_user_timer(
990980
timer_spec,
@@ -994,12 +984,6 @@ def process_timer_logic():
994984
timer_data.paneinfo,
995985
timer_data.dynamic_timer_tag)
996986

997-
if self.scoped_timer_processing_state:
998-
with self.scoped_timer_processing_state:
999-
process_timer_logic()
1000-
else:
1001-
process_timer_logic()
1002-
1003987
def finish(self):
1004988
# type: () -> None
1005989
super(DoOperation, self).finish()

sdks/python/apache_beam/runners/worker/statesampler.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,17 +170,3 @@ def commit_counters(self) -> None:
170170
for state in self._states_by_name.values():
171171
state_msecs = int(1e-6 * state.nsecs)
172172
state.counter.update(state_msecs - state.counter.value())
173-
174-
175-
class NoOpScopedState:
176-
def __enter__(self):
177-
pass
178-
179-
def __exit__(self, exc_type, exc_val, exc_tb):
180-
pass
181-
182-
def sampled_msecs_int(self):
183-
return 0
184-
185-
186-
NOOP_SCOPED_STATE = NoOpScopedState()

sdks/python/apache_beam/runners/worker/statesampler_test.py

Lines changed: 127 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626
from tenacity import retry
2727
from tenacity import stop_after_attempt
2828

29-
from apache_beam.runners.worker import statesampler
30-
from apache_beam.utils.counters import CounterFactory
31-
from apache_beam.utils.counters import CounterName
29+
from apache_beam.internal import pickler
30+
from apache_beam.runners import common
3231
from apache_beam.runners.worker import operation_specs
3332
from apache_beam.runners.worker import operations
34-
from apache_beam.internal import pickler
33+
from apache_beam.runners.worker import statesampler
3534
from apache_beam.transforms import core
35+
from apache_beam.utils.counters import CounterFactory
36+
from apache_beam.utils.counters import CounterName
3637

3738
_LOGGER = logging.getLogger(__name__)
3839

@@ -143,7 +144,7 @@ def test_process_timers_metric_is_recorded(self):
143144
sampler = statesampler.StateSampler(
144145
'test_stage', counter_factory, sampling_period_ms=1)
145146

146-
# Keeps range between 50-350 ms, which is fair.
147+
# Keeps range between 50-350 ms, which is fair.
147148
state_duration_ms = 200
148149
margin_of_error = 0.75
149150

@@ -184,52 +185,149 @@ def test_process_timers_metric_is_recorded(self):
184185
expected_value * (1.0 + margin_of_error),
185186
"The timer metric was higher than expected.")
186187

187-
def test_do_operation_with_sampler(self):
188+
def test_do_operation_process_timer_metric(self):
189+
"""
190+
Tests that the 'process-timers-msecs' metric is correctly recorded
191+
when a timer is processed within a DoOperation.
188192
"""
189-
Tests that a DoOperation with an active state_sampler correctly
190-
creates a real ScopedState object for timer processing.
191-
"""
193+
counter_factory = CounterFactory()
194+
sampler = statesampler.StateSampler(
195+
'test_stage', counter_factory, sampling_period_ms=1)
196+
192197
mock_spec = operation_specs.WorkerDoFn(
193198
serialized_fn=pickler.dumps((core.DoFn(), None, None, None, None)),
194199
output_tags=[],
195200
input=None,
196201
side_inputs=[],
197202
output_coders=[])
198203

199-
sampler = statesampler.StateSampler(
200-
'test_stage', CounterFactory(), sampling_period_ms=1)
201-
202-
# 1. Create the operation WITHOUT the unexpected keyword argument
203-
op = operations.create_operation(
204-
name_context='test_op',
204+
op = operations.DoOperation(
205+
name=common.NameContext('test_op'),
205206
spec=mock_spec,
206-
counter_factory=CounterFactory(),
207-
state_sampler=sampler)
207+
counter_factory=counter_factory,
208+
sampler=sampler)
209+
210+
op.dofn_runner = Mock()
211+
op.timer_specs = {'timer_id': Mock()}
212+
state_duration_ms = 200
213+
margin_of_error = 0.75
214+
215+
def mock_process_user_timer(*args, **kwargs):
216+
time.sleep(state_duration_ms / 1000.0)
217+
218+
op.dofn_runner.process_user_timer = mock_process_user_timer
219+
220+
mock_timer_data = Mock()
221+
mock_timer_data.windows = [Mock()]
222+
mock_timer_data.user_key = Mock()
223+
mock_timer_data.fire_timestamp = Mock()
224+
mock_timer_data.paneinfo = Mock()
225+
mock_timer_data.dynamic_timer_tag = Mock()
208226

209-
self.assertIsNot(
210-
op.scoped_timer_processing_state, statesampler.NOOP_SCOPED_STATE)
227+
sampler.start()
228+
op.process_timer('timer_id', mock_timer_data)
229+
sampler.stop()
230+
sampler.commit_counters()
231+
232+
if not statesampler.FAST_SAMPLER:
233+
return
234+
235+
expected_counter_name = CounterName(
236+
'process-timers-msecs', step_name='test_op', stage_name='test_stage')
237+
238+
found_counter = None
239+
for counter in counter_factory.get_counters():
240+
if counter.name == expected_counter_name:
241+
found_counter = counter
242+
break
243+
244+
self.assertIsNotNone(
245+
found_counter,
246+
f"The expected counter '{expected_counter_name}' was not created.")
247+
248+
actual_value = found_counter.value()
249+
expected_value = state_duration_ms
250+
self.assertGreater(
251+
actual_value,
252+
expected_value * (1.0 - margin_of_error),
253+
"The timer metric was lower than expected.")
254+
self.assertLess(
255+
actual_value,
256+
expected_value * (1.0 + margin_of_error),
257+
"The timer metric was higher than expected.")
211258

212-
def test_do_operation_without_sampler(self):
259+
def test_do_operation_process_timer_metric_with_exception(self):
213260
"""
214-
Tests that a DoOperation without a state_sampler correctly uses the
215-
NOOP_SCOPED_STATE for timer processing.
261+
Tests that the 'process-timers-msecs' metric is still recorded
262+
when a timer callback in a DoOperation raises an exception.
216263
"""
264+
counter_factory = CounterFactory()
265+
sampler = statesampler.StateSampler(
266+
'test_stage', counter_factory, sampling_period_ms=1)
267+
217268
mock_spec = operation_specs.WorkerDoFn(
218269
serialized_fn=pickler.dumps((core.DoFn(), None, None, None, None)),
219270
output_tags=[],
220271
input=None,
221272
side_inputs=[],
222273
output_coders=[])
223274

224-
# 1. Create the operation WITHOUT the unexpected keyword argument
225-
op = operations.create_operation(
226-
name_context='test_op',
275+
op = operations.DoOperation(
276+
name=common.NameContext('test_op'),
227277
spec=mock_spec,
228-
counter_factory=CounterFactory(),
229-
state_sampler=None)
278+
counter_factory=counter_factory,
279+
sampler=sampler)
280+
281+
op.dofn_runner = Mock()
282+
op.timer_specs = {'timer_id': Mock()}
283+
state_duration_ms = 200
284+
margin_of_error = 0.75
285+
286+
def mock_process_user_timer(*args, **kwargs):
287+
time.sleep(state_duration_ms / 1000.0)
288+
raise ValueError("Test Exception")
289+
290+
op.dofn_runner.process_user_timer = mock_process_user_timer
291+
292+
mock_timer_data = Mock()
293+
mock_timer_data.windows = [Mock()]
294+
mock_timer_data.user_key = Mock()
295+
mock_timer_data.fire_timestamp = Mock()
296+
mock_timer_data.paneinfo = Mock()
297+
mock_timer_data.dynamic_timer_tag = Mock()
298+
299+
sampler.start()
300+
with self.assertRaises(ValueError):
301+
op.process_timer('timer_id', mock_timer_data)
302+
sampler.stop()
303+
sampler.commit_counters()
304+
305+
if not statesampler.FAST_SAMPLER:
306+
return
307+
308+
expected_counter_name = CounterName(
309+
'process-timers-msecs', step_name='test_op', stage_name='test_stage')
310+
311+
found_counter = None
312+
for counter in counter_factory.get_counters():
313+
if counter.name == expected_counter_name:
314+
found_counter = counter
315+
break
230316

231-
self.assertIs(
232-
op.scoped_timer_processing_state, statesampler.NOOP_SCOPED_STATE)
317+
self.assertIsNotNone(
318+
found_counter,
319+
f"The expected counter '{expected_counter_name}' was not created.")
320+
321+
actual_value = found_counter.value()
322+
expected_value = state_duration_ms
323+
self.assertGreater(
324+
actual_value,
325+
expected_value * (1.0 - margin_of_error),
326+
"The timer metric was lower than expected.")
327+
self.assertLess(
328+
actual_value,
329+
expected_value * (1.0 + margin_of_error),
330+
"The timer metric was higher than expected.")
233331

234332

235333
if __name__ == '__main__':

0 commit comments

Comments
 (0)