diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index a7db5bfb0e71..131949ecb4f1 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1455,6 +1455,14 @@ def _add_argparse_args(cls, parser): 'responsible for executing the user code and communicating with ' 'the runner. Depending on the runner, there may be more than one ' 'SDK Harness process running on the same worker node.')) + parser.add_argument( + '--element_processing_timeout_minutes', + type=int, + default=None, + help=( + 'The time limit (in minutes) that an SDK worker allows for a' + ' PTransform operation to process one element before signaling' + ' the runner harness to restart the SDK worker.')) def validate(self, validator): errors = [] diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py index 06270d4cd310..cd6cce204b78 100644 --- a/sdks/python/apache_beam/options/pipeline_options_test.py +++ b/sdks/python/apache_beam/options/pipeline_options_test.py @@ -405,10 +405,18 @@ def test_experiments(self): self.assertEqual(options.get_all_options()['experiments'], None) def test_worker_options(self): - options = PipelineOptions(['--machine_type', 'abc', '--disk_type', 'def']) + options = PipelineOptions([ + '--machine_type', + 'abc', + '--disk_type', + 'def', + '--element_processing_timeout_minutes', + '10', + ]) worker_options = options.view_as(WorkerOptions) self.assertEqual(worker_options.machine_type, 'abc') self.assertEqual(worker_options.disk_type, 'def') + self.assertEqual(worker_options.element_processing_timeout_minutes, 10) options = PipelineOptions( ['--worker_machine_type', 'abc', '--worker_disk_type', 'def']) diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 5240674c7009..0b4c236d6b37 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -176,6 +176,7 @@ def __init__( # that should be reported to the runner when proocessing the first bundle. deferred_exception=None, # type: Optional[Exception] runner_capabilities=frozenset(), # type: FrozenSet[str] + element_processing_timeout_minutes=None, # type: Optional[int] ): # type: (...) -> None self._alive = True @@ -207,6 +208,8 @@ def __init__( self._profiler_factory = profiler_factory self.data_sampler = data_sampler self.runner_capabilities = runner_capabilities + self._element_processing_timeout_minutes = ( + element_processing_timeout_minutes) def default_factory(id): # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor @@ -223,21 +226,21 @@ def default_factory(id): fns=self._fns, data_sampler=self.data_sampler, ) - + self._status_handler = None # type: Optional[FnApiWorkerStatusHandler] if status_address: try: self._status_handler = FnApiWorkerStatusHandler( status_address, self._bundle_processor_cache, self._state_cache, - enable_heap_dump) # type: Optional[FnApiWorkerStatusHandler] + enable_heap_dump, + element_processing_timeout_minutes=self. + _element_processing_timeout_minutes) except Exception: traceback_string = traceback.format_exc() _LOGGER.warning( 'Error creating worker status request handler, ' 'skipping status report. Trace back: %s' % traceback_string) - else: - self._status_handler = None # TODO(BEAM-8998) use common # thread_pool_executor.shared_unbounded_instance() to process bundle diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index b3c81fd93467..7ea0e0eb1099 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -24,7 +24,9 @@ import logging import os import re +import signal import sys +import time import traceback from google.protobuf import text_format @@ -47,6 +49,7 @@ _LOGGER = logging.getLogger(__name__) _ENABLE_GOOGLE_CLOUD_PROFILER = 'enable_google_cloud_profiler' +_FN_LOG_HANDLER = None def _import_beam_plugins(plugins): @@ -167,7 +170,9 @@ def create_harness(environment, dry_run=False): enable_heap_dump=enable_heap_dump, data_sampler=data_sampler, deferred_exception=deferred_exception, - runner_capabilities=runner_capabilities) + runner_capabilities=runner_capabilities, + element_processing_timeout_minutes=sdk_pipeline_options.view_as( + WorkerOptions).element_processing_timeout_minutes) return fn_log_handler, sdk_harness, sdk_pipeline_options @@ -202,7 +207,9 @@ def main(unused_argv): """Main entry point for SDK Fn Harness.""" (fn_log_handler, sdk_harness, sdk_pipeline_options) = create_harness(os.environ) - + global _FN_LOG_HANDLER + if fn_log_handler: + _FN_LOG_HANDLER = fn_log_handler gcp_profiler_name = _get_gcp_profiler_name_if_enabled(sdk_pipeline_options) if gcp_profiler_name: _start_profiler(gcp_profiler_name, os.environ["JOB_ID"]) @@ -219,6 +226,15 @@ def main(unused_argv): fn_log_handler.close() +def terminate_sdk_harness(): + """Flushes the FnApiLogRecordHandler if it exists.""" + _LOGGER.error('The SDK harness will be terminated in 5 seconds.') + time.sleep(5) + if _FN_LOG_HANDLER: + _FN_LOG_HANDLER.close() + os.kill(os.getpid(), signal.SIGINT) + + def _load_pipeline_options(options_json): if options_json is None: return {} diff --git a/sdks/python/apache_beam/runners/worker/worker_status.py b/sdks/python/apache_beam/runners/worker/worker_status.py index d67bd4437fbb..5a3e9852a54d 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status.py +++ b/sdks/python/apache_beam/runners/worker/worker_status.py @@ -165,7 +165,8 @@ def __init__( state_cache=None, enable_heap_dump=False, worker_id=None, - log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS): + log_lull_timeout_ns=DEFAULT_LOG_LULL_TIMEOUT_NS, + element_processing_timeout_minutes=None): """Initialize FnApiWorkerStatusHandler. Args: @@ -184,6 +185,11 @@ def __init__( self._status_channel) self._responses = queue.Queue() self.log_lull_timeout_ns = log_lull_timeout_ns + if element_processing_timeout_minutes: + self._element_processing_timeout_ns = ( + element_processing_timeout_minutes * 60 * 1e9) + else: + self._element_processing_timeout_ns = None self._last_full_thread_dump_secs = 0.0 self._last_lull_logged_secs = 0.0 self._server = threading.Thread( @@ -252,22 +258,45 @@ def _log_lull_in_bundle_processor(self, bundle_process_cache): self._log_lull_sampler_info(info, instruction) def _log_lull_sampler_info(self, sampler_info, instruction): - if not self._passed_lull_timeout_since_last_log(): + if (not sampler_info or not sampler_info.time_since_transition): return - if (sampler_info and sampler_info.time_since_transition and - sampler_info.time_since_transition > self.log_lull_timeout_ns): - lull_seconds = sampler_info.time_since_transition / 1e9 - step_name = sampler_info.state_name.step_name - state_name = sampler_info.state_name.name - if step_name and state_name: - step_name_log = ( - ' for PTransform{name=%s, state=%s}' % (step_name, state_name)) - else: - step_name_log = '' + log_lull = ( + self._passed_lull_timeout_since_last_log() and + sampler_info.time_since_transition > self.log_lull_timeout_ns) + timeout_exceeded = ( + self._element_processing_timeout_ns and + sampler_info.time_since_transition + > self._element_processing_timeout_ns) + + if not (log_lull or timeout_exceeded): + return - stack_trace = self._get_stack_trace(sampler_info) + lull_seconds = sampler_info.time_since_transition / 1e9 + step_name = sampler_info.state_name.step_name + state_name = sampler_info.state_name.name + if step_name and state_name: + step_name_log = ( + ' for PTransform{name=%s, state=%s}' % (step_name, state_name)) + else: + step_name_log = '' + stack_trace = self._get_stack_trace(sampler_info) + + if timeout_exceeded: + _LOGGER.error( + ( + 'Operation ongoing in bundle %s%s for at least %.2f seconds' + ' without outputting or completing.\n' + 'Current Traceback:\n%s'), + instruction, + step_name_log, + lull_seconds, + stack_trace, + ) + from apache_beam.runners.worker.sdk_worker_main import terminate_sdk_harness + terminate_sdk_harness() + if log_lull: _LOGGER.warning( ( 'Operation ongoing in bundle %s%s for at least %.2f seconds' diff --git a/sdks/python/apache_beam/runners/worker/worker_status_test.py b/sdks/python/apache_beam/runners/worker/worker_status_test.py index 1004d21e7fd3..67df1a324d9e 100644 --- a/sdks/python/apache_beam/runners/worker/worker_status_test.py +++ b/sdks/python/apache_beam/runners/worker/worker_status_test.py @@ -59,7 +59,8 @@ def setUp(self): self.test_port = self.server.add_insecure_port('[::]:0') self.server.start() self.url = 'localhost:%s' % self.test_port - self.fn_status_handler = FnApiWorkerStatusHandler(self.url) + self.fn_status_handler = FnApiWorkerStatusHandler( + self.url, element_processing_timeout_minutes=10) def tearDown(self): self.server.stop(5) @@ -89,42 +90,48 @@ def test_generate_error(self, mock_method): def test_log_lull_in_bundle_processor(self): def get_state_sampler_info_for_lull(lull_duration_s): return "bundle-id", statesampler.StateSamplerInfo( - CounterName('progress-msecs', 'stage_name', 'step_name'), - 1, - lull_duration_s * 1e9, - threading.current_thread()) + CounterName('progress-msecs', 'stage_name', 'step_name'), + 1, + lull_duration_s * 1e9, + threading.current_thread()) now = time.time() with mock.patch('logging.Logger.warning') as warn_mock: - with mock.patch('time.time') as time_mock: - time_mock.return_value = now - bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60) - self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) - - bundle_id_template = warn_mock.call_args[0][1] - step_name_template = warn_mock.call_args[0][2] - processing_template = warn_mock.call_args[0][3] - traceback = warn_mock.call_args = warn_mock.call_args[0][4] - - self.assertIn('bundle-id', bundle_id_template) - self.assertIn('step_name', step_name_template) - self.assertEqual(21 * 60, processing_template) - self.assertIn('test_log_lull_in_bundle_processor', traceback) - - with mock.patch('time.time') as time_mock: - time_mock.return_value = now + 6 * 60 # 6 minutes - bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60) - self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) - - with mock.patch('time.time') as time_mock: - time_mock.return_value = now + 21 * 60 # 21 minutes - bundle_id, sampler_info = get_state_sampler_info_for_lull(10 * 60) - self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) - - with mock.patch('time.time') as time_mock: - time_mock.return_value = now + 42 * 60 # 21 minutes after previous one - bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60) - self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) + with mock.patch( + 'apache_beam.runners.worker.sdk_worker_main.terminate_sdk_harness' + ) as flush_mock: + with mock.patch('time.time') as time_mock: + time_mock.return_value = now + bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60) + self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) + bundle_id_template = warn_mock.call_args[0][1] + step_name_template = warn_mock.call_args[0][2] + processing_template = warn_mock.call_args[0][3] + traceback = warn_mock.call_args = warn_mock.call_args[0][4] + + self.assertIn('bundle-id', bundle_id_template) + self.assertIn('step_name', step_name_template) + self.assertEqual(21 * 60, processing_template) + self.assertIn('test_log_lull_in_bundle_processor', traceback) + flush_mock.assert_called_once() + + with mock.patch('time.time') as time_mock: + time_mock.return_value = now + 6 * 60 # 6 minutes + bundle_id, sampler_info = get_state_sampler_info_for_lull(21 * 60) + self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) + self.assertEqual(flush_mock.call_count, 2) + + with mock.patch('time.time') as time_mock: + time_mock.return_value = now + 21 * 60 # 21 minutes + bundle_id, sampler_info = get_state_sampler_info_for_lull(10 * 60) + self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) + self.assertEqual(flush_mock.call_count, 2) + + with mock.patch('time.time') as time_mock: + time_mock.return_value = now + 42 * 60 # 42 minutes + bundle_id, sampler_info = get_state_sampler_info_for_lull(11 * 60) + self.fn_status_handler._log_lull_sampler_info(sampler_info, bundle_id) + self.assertEqual(flush_mock.call_count, 3) class HeapDumpTest(unittest.TestCase):