diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py
index 6dc43dbf8da9..d2fa90c85085 100644
--- a/sdks/python/apache_beam/transforms/async_dofn.py
+++ b/sdks/python/apache_beam/transforms/async_dofn.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import

 import logging
+import random
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from math import floor
@@ -55,9 +56,8 @@ class AsyncWrapper(beam.DoFn):
   TIMER_SET = ReadModifyWriteStateSpec('timer_set', coders.BooleanCoder())
   TO_PROCESS = BagStateSpec(
       'to_process',
-      coders.TupleCoder([coders.StrUtf8Coder(), coders.StrUtf8Coder()]),
-  )
-  _timer_frequency = 20
+      coders.TupleCoder(
+          [coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()]))
   # The below items are one per dofn (not instance) so are maps of UUID to
   # value.
   _processing_elements = {}
@@ -75,7 +75,8 @@ def __init__(
       parallelism=1,
       callback_frequency=5,
       max_items_to_buffer=None,
-      max_wait_time=120,
+      timeout=1,
+      max_wait_time=0.5,
   ):
     """Wraps the sync_fn to create an asynchronous version.

@@ -96,14 +97,17 @@ def __init__(
       max_items_to_buffer: We should ideally buffer enough to always be busy
         but not so much that the worker ooms. By default will be 2x the
         parallelism which should be good for most pipelines.
-      max_wait_time: The maximum amount of time an item should wait to be added
-        to the buffer. Used for testing to ensure timeouts are met.
+      timeout: The maximum amount of time an item should try to be scheduled
+        locally before it goes in the queue of waiting work.
+      max_wait_time: The maximum amount of sleep time while attempting to
+        schedule an item. Used in testing to ensure timeouts are met.
     """
     self._sync_fn = sync_fn
     self._uuid = uuid.uuid4().hex
     self._parallelism = parallelism
+    self._timeout = timeout
     self._max_wait_time = max_wait_time
-    self._timer_frequency = 20
+    self._timer_frequency = callback_frequency
     if max_items_to_buffer is None:
       self._max_items_to_buffer = max(parallelism * 2, 10)
     else:
@@ -112,9 +116,6 @@ def __init__(
     AsyncWrapper._processing_elements[self._uuid] = {}
     AsyncWrapper._items_in_buffer[self._uuid] = 0
     self.max_wait_time = max_wait_time
-    self.timer_frequency_ = callback_frequency
-    self.parallelism_ = parallelism
-    self._next_time_to_fire = Timestamp.now() + Duration(seconds=5)
     self._shared_handle = Shared()

   @staticmethod
@@ -238,9 +239,9 @@ def schedule_item(self, element, ignore_buffer=False, *args, **kwargs):
       **kwargs: keyword arguments that the wrapped dofn requires.
     """
     done = False
-    sleep_time = 1
+    sleep_time = 0.01
     total_sleep = 0
-    while not done:
+    while not done and total_sleep < self._timeout:
       done = self.schedule_if_room(element, ignore_buffer, *args, **kwargs)
       if not done:
         sleep_time = min(self.max_wait_time, sleep_time * 2)
@@ -256,10 +257,12 @@ def schedule_item(self, element, ignore_buffer=False, *args, **kwargs):
         total_sleep += sleep_time
         sleep(sleep_time)

-  def next_time_to_fire(self):
+  def next_time_to_fire(self, key):
+    random.seed(key)
     return (
         floor((time() + self._timer_frequency) / self._timer_frequency) *
-        self._timer_frequency)
+        self._timer_frequency) + (
+            random.random() * self._timer_frequency)

   def accepting_items(self):
     with AsyncWrapper._lock:
@@ -301,7 +304,7 @@ def process(
     # Set a timer to fire on the next round increment of timer_frequency_. Note
     # we do this so that each messages timer doesn't get overwritten by the
     # next.
-    time_to_fire = self.next_time_to_fire()
+    time_to_fire = self.next_time_to_fire(element[0])
     timer.set(time_to_fire)

     # Don't output any elements. This will be done in commit_finished_items.
@@ -346,6 +349,7 @@ def commit_finished_items(
     # from local state and cancel their futures.
     to_remove = []
     key = None
+    to_reschedule = []
     if to_process_local:
       key = str(to_process_local[0][0])
     else:
@@ -387,9 +391,13 @@ def commit_finished_items(
               'item %s found in processing state but not local state,'
               ' scheduling now',
               x)
-          self.schedule_item(x, ignore_buffer=True)
+          to_reschedule.append(x)
           items_rescheduled += 1

+    # Reschedule the items not under a lock
+    for x in to_reschedule:
+      self.schedule_item(x, ignore_buffer=False)
+
     # Update processing state to remove elements we've finished
     to_process.clear()
     for x in to_process_local:
@@ -408,8 +416,8 @@ def commit_finished_items(
     # If there are items not yet finished then set a timer to fire in the
     # future.
     self._next_time_to_fire = Timestamp.now() + Duration(seconds=5)
-    if items_not_yet_finished > 0:
-      time_to_fire = self.next_time_to_fire()
+    if items_in_processing_state > 0:
+      time_to_fire = self.next_time_to_fire(key)
       timer.set(time_to_fire)

     # Each result is a list. We want to combine them into a single
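The schedule_item change above turns the previously unbounded retry loop into a bounded one: scheduling now gives up once the total sleep reaches the new timeout argument, leaving the element for the queue of waiting work instead of blocking forever. A minimal standalone sketch of the new loop, with a try_schedule callable standing in for schedule_if_room (the free function and its names are illustrative, not part of the patch):

from time import sleep


def schedule_with_timeout(try_schedule, timeout=1, max_wait_time=0.5):
  """Retries try_schedule with exponential backoff until it succeeds or
  the total time slept reaches timeout. Returns whether it succeeded."""
  done = False
  sleep_time = 0.01
  total_sleep = 0
  while not done and total_sleep < timeout:
    done = try_schedule()
    if not done:
      # Double the sleep on each failed attempt, capped at max_wait_time.
      sleep_time = min(max_wait_time, sleep_time * 2)
      total_sleep += sleep_time
      sleep(sleep_time)
  return done

Starting at 0.01s and doubling keeps the first retries cheap, while the max_wait_time cap bounds the polling interval and, together with timeout, the worst-case blocking time.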
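next_time_to_fire now takes the element's key and adds a deterministic per-key jitter on top of the round increment of the timer frequency. Re-seeding with the key means every call for the same key produces the same offset, so a key's timer keeps landing on the same phase and repeated sets don't scatter, while different keys are spread across the window instead of all firing on the aligned instant. A self-contained sketch of the computation, with timer_frequency passed explicitly rather than read from self:

import random
from math import floor
from time import time


def next_time_to_fire(key, timer_frequency):
  # Deterministic per-key jitter: seeding with the key fixes the value
  # that random.random() returns below.
  random.seed(key)
  # Next round increment of timer_frequency, plus an offset in
  # [0, timer_frequency) that is stable for this key.
  return (
      floor((time() + timer_frequency) / timer_frequency) *
      timer_frequency) + (random.random() * timer_frequency)

Only the floor(...) term advances with wall-clock time; the jitter term is a pure function of the key.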
diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py
index ecc730a66f91..7577e215d1c7 100644
--- a/sdks/python/apache_beam/transforms/async_dofn_test.py
+++ b/sdks/python/apache_beam/transforms/async_dofn_test.py
@@ -343,10 +343,15 @@ def add_item(i):
     self.assertEqual(async_dofn._max_items_to_buffer, 5)
     self.check_items_in_buffer(async_dofn, 5)

-    # After 55 seconds all items should be finished (including those which were
-    # waiting on the buffer).
+    # Wait for all buffered items to finish.
     self.wait_for_empty(async_dofn, 100)

+    # This will commit buffered items and add new items which didn't fit in the
+    # buffer.
     result = async_dofn.commit_finished_items(fake_bag_state, fake_timer)
+
+    # Wait for the new buffered items to finish.
+    self.wait_for_empty(async_dofn, 100)
+    result.extend(async_dofn.commit_finished_items(fake_bag_state, fake_timer))
     self.check_output(result, expected_output)
     self.check_items_in_buffer(async_dofn, 0)
@@ -414,33 +419,23 @@ def add_item(i):
     # Run for a while. Should be enough to start all items but not finish them
     # all.
     time.sleep(random.randint(30, 50))
-    # Commit some stuff
-    pre_crash_results = []
-    for i in range(0, 10):
-      pre_crash_results.append(
-          async_dofn.commit_finished_items(
-              bag_states['key' + str(i)], timers['key' + str(i)]))
-
     # Wait for all items to at least make it into the buffer.
     done = False
+    results = [[] for _ in range(0, 10)]
     while not done:
-      time.sleep(10)
       done = True
-      for future in futures:
-        if not future.done():
+      for i in range(0, 10):
+        results[i].extend(
+            async_dofn.commit_finished_items(
+                bag_states['key' + str(i)], timers['key' + str(i)]))
+        if not bag_states['key' + str(i)].items:
+          self.check_output(results[i], expected_outputs['key' + str(i)])
+        else:
           done = False
-          break
-
-    # Wait for all items to finish.
-    self.wait_for_empty(async_dofn)
+      time.sleep(random.randint(10, 30))

     for i in range(0, 10):
-      result = async_dofn.commit_finished_items(
-          bag_states['key' + str(i)], timers['key' + str(i)])
-      logging.info('pre_crash_results %s', pre_crash_results[i])
-      logging.info('result %s', result)
-      self.check_output(
-          pre_crash_results[i] + result, expected_outputs['key' + str(i)])
+      self.check_output(results[i], expected_outputs['key' + str(i)])
       self.assertEqual(bag_states['key' + str(i)].items, [])
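One more note on the commit_finished_items change: items found in processing state but not local state are now collected into to_reschedule and scheduled only after the state bookkeeping, per the added comment "Reschedule the items not under a lock". Since schedule_item can now sleep during its backoff while waiting for buffer room, doing that with a lock held would block other work on the lock. A minimal sketch of this collect-then-act pattern, using hypothetical names (lock, processing, local, schedule_item) rather than the module's real state handling:

import threading

lock = threading.Lock()


def reconcile(processing, local, schedule_item):
  # Decide what needs rescheduling while holding the lock, but do no
  # blocking work yet.
  to_reschedule = []
  with lock:
    for item in processing:
      if item not in local:
        to_reschedule.append(item)
  # Act outside the lock: schedule_item may sleep in its backoff loop,
  # and holding the lock through those sleeps would serialize all keys.
  for item in to_reschedule:
    schedule_item(item)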