Skip to content

Commit 8e0f35f

Browse files
LEEKYE, ryan-mbuashundip, and tvalentyn
authored
Adds a pipeline option to terminate processing of an element after a timeout (python) (#35391)
* Add ptransform_timeout_duration option to pipeline options in python sdk * Propagate ptransform_timeout_duration Worker option to FnApiWorkerStatusHandler * fix import * Refactor: ptransform_timeout_duration -> element_processing_timeout * Propagate the value of element_processing_timeout * Update * restart lull unit test * wrap sys.exit call * fix test * fix test * fix * fixes * fix lint issues * Update sdks/python/apache_beam/runners/worker/sdk_worker.py Co-authored-by: Ryan Mbuashu-Ndip <mbuashundip@google.com> * fix formatter/lint issues * fix formatter/lint issues * fix formatter/lint issues * remove lower bound of timeout * format fix * Update sdks/python/apache_beam/options/pipeline_options.py Co-authored-by: Ryan Mbuashu-Ndip <mbuashundip@google.com> * minor update * minor fix * fix formatter issue * style fix * fix lint issues * fix lint issues * wording fix * combine 2 log lull methods * unit test * Remove redundant error message * call flush log handler method from main thread in worker status * remove * update * flush logger and shut down process if lull time is too long * initialized status handler with None * fix re-def issue * formatter issue fix * update * update * formatter issue fix * formatter issue fix * update * Use os._exit() to terminate the program * minor fix * minor fix * formatter issue fix * minor fix * minor fix * formatter issue fix * fix lint error * Switch branches to reduce logging --------- Co-authored-by: Ryan Mbuashu-Ndip <mbuashundip@google.com> Co-authored-by: tvalentyn <tvalentyn@users.noreply.github.com>
1 parent f3be4c4 commit 8e0f35f

File tree

6 files changed

+125
-54
lines changed

6 files changed

+125
-54
lines changed

sdks/python/apache_beam/options/pipeline_options.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,14 @@ def _add_argparse_args(cls, parser):
14561456
'responsible for executing the user code and communicating with '
14571457
'the runner. Depending on the runner, there may be more than one '
14581458
'SDK Harness process running on the same worker node.'))
1459+
parser.add_argument(
1460+
'--element_processing_timeout_minutes',
1461+
type=int,
1462+
default=None,
1463+
help=(
1464+
'The time limit (in minutes) that an SDK worker allows for a'
1465+
' PTransform operation to process one element before signaling'
1466+
' the runner harness to restart the SDK worker.'))
14591467

14601468
def validate(self, validator):
14611469
errors = []

sdks/python/apache_beam/options/pipeline_options_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,18 @@ def test_experiments(self):
405405
self.assertEqual(options.get_all_options()['experiments'], None)
406406

407407
def test_worker_options(self):
408-
options = PipelineOptions(['--machine_type', 'abc', '--disk_type', 'def'])
408+
options = PipelineOptions([
409+
'--machine_type',
410+
'abc',
411+
'--disk_type',
412+
'def',
413+
'--element_processing_timeout_minutes',
414+
'10',
415+
])
409416
worker_options = options.view_as(WorkerOptions)
410417
self.assertEqual(worker_options.machine_type, 'abc')
411418
self.assertEqual(worker_options.disk_type, 'def')
419+
self.assertEqual(worker_options.element_processing_timeout_minutes, 10)
412420

413421
options = PipelineOptions(
414422
['--worker_machine_type', 'abc', '--worker_disk_type', 'def'])

sdks/python/apache_beam/runners/worker/sdk_worker.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def __init__(
176176
# that should be reported to the runner when proocessing the first bundle.
177177
deferred_exception=None, # type: Optional[Exception]
178178
runner_capabilities=frozenset(), # type: FrozenSet[str]
179+
element_processing_timeout_minutes=None, # type: Optional[int]
179180
):
180181
# type: (...) -> None
181182
self._alive = True
@@ -207,6 +208,8 @@ def __init__(
207208
self._profiler_factory = profiler_factory
208209
self.data_sampler = data_sampler
209210
self.runner_capabilities = runner_capabilities
211+
self._element_processing_timeout_minutes = (
212+
element_processing_timeout_minutes)
210213

211214
def default_factory(id):
212215
# type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
@@ -223,21 +226,21 @@ def default_factory(id):
223226
fns=self._fns,
224227
data_sampler=self.data_sampler,
225228
)
226-
229+
self._status_handler = None # type: Optional[FnApiWorkerStatusHandler]
227230
if status_address:
228231
try:
229232
self._status_handler = FnApiWorkerStatusHandler(
230233
status_address,
231234
self._bundle_processor_cache,
232235
self._state_cache,
233-
enable_heap_dump) # type: Optional[FnApiWorkerStatusHandler]
236+
enable_heap_dump,
237+
element_processing_timeout_minutes=self.
238+
_element_processing_timeout_minutes)
234239
except Exception:
235240
traceback_string = traceback.format_exc()
236241
_LOGGER.warning(
237242
'Error creating worker status request handler, '
238243
'skipping status report. Trace back: %s' % traceback_string)
239-
else:
240-
self._status_handler = None
241244

242245
# TODO(BEAM-8998) use common
243246
# thread_pool_executor.shared_unbounded_instance() to process bundle

sdks/python/apache_beam/runners/worker/sdk_worker_main.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
import logging
2525
import os
2626
import re
27+
import signal
2728
import sys
29+
import time
2830
import traceback
2931

3032
from google.protobuf import text_format
@@ -47,6 +49,7 @@
4749

4850
_LOGGER = logging.getLogger(__name__)
4951
_ENABLE_GOOGLE_CLOUD_PROFILER = 'enable_google_cloud_profiler'
52+
_FN_LOG_HANDLER = None
5053

5154

5255
def _import_beam_plugins(plugins):
@@ -167,7 +170,9 @@ def create_harness(environment, dry_run=False):
167170
enable_heap_dump=enable_heap_dump,
168171
data_sampler=data_sampler,
169172
deferred_exception=deferred_exception,
170-
runner_capabilities=runner_capabilities)
173+
runner_capabilities=runner_capabilities,
174+
element_processing_timeout_minutes=sdk_pipeline_options.view_as(
175+
WorkerOptions).element_processing_timeout_minutes)
171176
return fn_log_handler, sdk_harness, sdk_pipeline_options
172177

173178

@@ -202,7 +207,9 @@ def main(unused_argv):
202207
"""Main entry point for SDK Fn Harness."""
203208
(fn_log_handler, sdk_harness,
204209
sdk_pipeline_options) = create_harness(os.environ)
205-
210+
global _FN_LOG_HANDLER
211+
if fn_log_handler:
212+
_FN_LOG_HANDLER = fn_log_handler
206213
gcp_profiler_name = _get_gcp_profiler_name_if_enabled(sdk_pipeline_options)
207214
if gcp_profiler_name:
208215
_start_profiler(gcp_profiler_name, os.environ["JOB_ID"])
@@ -219,6 +226,15 @@ def main(unused_argv):
219226
fn_log_handler.close()
220227

221228

229+
def terminate_sdk_harness():
230+
"""Flushes the FnApiLogRecordHandler if it exists."""
231+
_LOGGER.error('The SDK harness will be terminated in 5 seconds.')
232+
time.sleep(5)
233+
if _FN_LOG_HANDLER:
234+
_FN_LOG_HANDLER.close()
235+
os.kill(os.getpid(), signal.SIGINT)
236+
237+
222238
def _load_pipeline_options(options_json):
223239
if options_json is None:
224240
return {}

sdks/python/apache_beam/runners/worker/worker_status.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ def __init__(
165165
state_cache=None,
166166
enable_heap_dump=False,
167167
worker_id=None,
168-
log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS):
168+
log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS,
169+
element_processing_timeout_minutes=None):
169170
"""Initialize FnApiWorkerStatusHandler.
170171
171172
Args:
@@ -184,6 +185,11 @@ def __init__(
184185
self._status_channel)
185186
self._responses = queue.Queue()
186187
self.log_lull_timeout_ns = log_lull_timeout_ns
188+
if element_processing_timeout_minutes:
189+
self._element_processing_timeout_ns = (
190+
element_processing_timeout_minutes * 60 * 1e9)
191+
else:
192+
self._element_processing_timeout_ns = None
187193
self._last_full_thread_dump_secs = 0.0
188194
self._last_lull_logged_secs = 0.0
189195
self._server = threading.Thread(
@@ -252,22 +258,45 @@ def _log_lull_in_bundle_processor(self, bundle_process_cache):
252258
self._log_lull_sampler_info(info, instruction)
253259

254260
def _log_lull_sampler_info(self, sampler_info, instruction):
255-
if not self._passed_lull_timeout_since_last_log():
261+
if (not sampler_info or not sampler_info.time_since_transition):
256262
return
257-
if (sampler_info and sampler_info.time_since_transition and
258-
sampler_info.time_since_transition > self.log_lull_timeout_ns):
259-
lull_seconds = sampler_info.time_since_transition / 1e9
260263

261-
step_name = sampler_info.state_name.step_name
262-
state_name = sampler_info.state_name.name
263-
if step_name and state_name:
264-
step_name_log = (
265-
' for PTransform{name=%s, state=%s}' % (step_name, state_name))
266-
else:
267-
step_name_log = ''
264+
log_lull = (
265+
self._passed_lull_timeout_since_last_log() and
266+
sampler_info.time_since_transition > self.log_lull_timeout_ns)
267+
timeout_exceeded = (
268+
self._element_processing_timeout_ns and
269+
sampler_info.time_since_transition
270+
> self._element_processing_timeout_ns)
271+
272+
if not (log_lull or timeout_exceeded):
273+
return
268274

269-
stack_trace = self._get_stack_trace(sampler_info)
275+
lull_seconds = sampler_info.time_since_transition / 1e9
276+
step_name = sampler_info.state_name.step_name
277+
state_name = sampler_info.state_name.name
278+
if step_name and state_name:
279+
step_name_log = (
280+
' for PTransform{name=%s, state=%s}' % (step_name, state_name))
281+
else:
282+
step_name_log = ''
283+
stack_trace = self._get_stack_trace(sampler_info)
284+
285+
if timeout_exceeded:
286+
_LOGGER.error(
287+
(
288+
'Operation ongoing in bundle %s%s for at least %.2f seconds'
289+
' without outputting or completing.\n'
290+
'Current Traceback:\n%s'),
291+
instruction,
292+
step_name_log,
293+
lull_seconds,
294+
stack_trace,
295+
)
296+
from apache_beam.runners.worker.sdk_worker_main import terminate_sdk_harness
297+
terminate_sdk_harness()
270298

299+
if log_lull:
271300
_LOGGER.warning(
272301
(
273302
'Operation ongoing in bundle %s%s for at least %.2f seconds'

sdks/python/apache_beam/runners/worker/worker_status_test.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def setUp(self):
5959
self.test_port = self.server.add_insecure_port('[::]:0')
6060
self.server.start()
6161
self.url = 'localhost:%s' % self.test_port
62-
self.fn_status_handler = FnApiWorkerStatusHandler(self.url)
62+
self.fn_status_handler = FnApiWorkerStatusHandler(
63+
self.url, element_processing_timeout_minutes=10)
6364

6465
def tearDown(self):
6566
self.server.stop(5)
@@ -89,42 +90,48 @@ def test_generate_error(self, mock_method):
8990
def test_log_lull_in_bundle_processor(self):
9091
def get_state_sampler_info_for_lull(lull_duration_s):
9192
return "bundle-id", statesampler.StateSamplerInfo(
92-
CounterName('progress-msecs', 'stage_name', 'step_name'),
93-
1,
94-
lull_duration_s * 1e9,
95-
threading.current_thread())
93+
CounterName('progress-msecs', 'stage_name', 'step_name'),
94+
1,
95+
lull_duration_s * 1e9,
96+
threading.current_thread())
9697

9798
now = time.time()
9899
with mock.patch('logging.Logger.warning') as warn_mock:
99-
with mock.patch('time.time') as time_mock:
100-
time_mock.return_value = now
101-
bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60)
102-
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
103-
104-
bundle_id_template = warn_mock.call_args[0][1]
105-
step_name_template = warn_mock.call_args[0][2]
106-
processing_template = warn_mock.call_args[0][3]
107-
traceback = warn_mock.call_args = warn_mock.call_args[0][4]
108-
109-
self.assertIn('bundle-id', bundle_id_template)
110-
self.assertIn('step_name', step_name_template)
111-
self.assertEqual(21 * 60, processing_template)
112-
self.assertIn('test_log_lull_in_bundle_processor', traceback)
113-
114-
with mock.patch('time.time') as time_mock:
115-
time_mock.return_value = now + 6 * 60 # 6 minutes
116-
bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60)
117-
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
118-
119-
with mock.patch('time.time') as time_mock:
120-
time_mock.return_value = now + 21 * 60 # 21 minutes
121-
bundle_id, sampler_info = get_state_sampler_info_for_lull(10 * 60)
122-
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
123-
124-
with mock.patch('time.time') as time_mock:
125-
time_mock.return_value = now + 42 * 60 # 21 minutes after previous one
126-
bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60)
127-
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
100+
with mock.patch(
101+
'apache_beam.runners.worker.sdk_worker_main.terminate_sdk_harness'
102+
) as flush_mock:
103+
with mock.patch('time.time') as time_mock:
104+
time_mock.return_value = now
105+
bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60)
106+
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
107+
bundle_id_template = warn_mock.call_args[0][1]
108+
step_name_template = warn_mock.call_args[0][2]
109+
processing_template = warn_mock.call_args[0][3]
110+
traceback = warn_mock.call_args = warn_mock.call_args[0][4]
111+
112+
self.assertIn('bundle-id', bundle_id_template)
113+
self.assertIn('step_name', step_name_template)
114+
self.assertEqual(21 * 60, processing_template)
115+
self.assertIn('test_log_lull_in_bundle_processor', traceback)
116+
flush_mock.assert_called_once()
117+
118+
with mock.patch('time.time') as time_mock:
119+
time_mock.return_value = now + 6 * 60 # 6 minutes
120+
bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60)
121+
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
122+
self.assertEqual(flush_mock.call_count, 2)
123+
124+
with mock.patch('time.time') as time_mock:
125+
time_mock.return_value = now + 21 * 60 # 21 minutes
126+
bundle_id, sampler_info = get_state_sampler_info_for_lull(10 * 60)
127+
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
128+
self.assertEqual(flush_mock.call_count, 2)
129+
130+
with mock.patch('time.time') as time_mock:
131+
time_mock.return_value = now + 42 * 60 # 42 minutes
132+
bundle_id, sampler_info = get_state_sampler_info_for_lull(11 * 60)
133+
self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id)
134+
self.assertEqual(flush_mock.call_count, 3)
128135

129136

130137
class HeapDumpTest(unittest.TestCase):

0 commit comments

Comments (0)