Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions sdks/python/apache_beam/transforms/async_dofn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import absolute_import

import logging
import random
import uuid
from concurrent.futures import ThreadPoolExecutor
from math import floor
Expand Down Expand Up @@ -55,9 +56,8 @@ class AsyncWrapper(beam.DoFn):
TIMER_SET = ReadModifyWriteStateSpec('timer_set', coders.BooleanCoder())
TO_PROCESS = BagStateSpec(
'to_process',
coders.TupleCoder([coders.StrUtf8Coder(), coders.StrUtf8Coder()]),
)
_timer_frequency = 20
coders.TupleCoder(
[coders.FastPrimitivesCoder(), coders.FastPrimitivesCoder()]))
# The below items are one per dofn (not instance) so are maps of UUID to
# value.
_processing_elements = {}
Expand All @@ -75,7 +75,8 @@ def __init__(
parallelism=1,
callback_frequency=5,
max_items_to_buffer=None,
max_wait_time=120,
timeout=1,
max_wait_time=0.5,
):
"""Wraps the sync_fn to create an asynchronous version.

Expand All @@ -96,14 +97,17 @@ def __init__(
max_items_to_buffer: We should ideally buffer enough to always be busy but
not so much that the worker ooms. By default will be 2x the parallelism
which should be good for most pipelines.
max_wait_time: The maximum amount of time an item should wait to be added
to the buffer. Used for testing to ensure timeouts are met.
timeout: The maximum amount of time an item should try to be scheduled
locally before it goes in the queue of waiting work.
max_wait_time: The maximum amount of sleep time while attempting to
schedule an item. Used in testing to ensure timeouts are met.
"""
self._sync_fn = sync_fn
self._uuid = uuid.uuid4().hex
self._parallelism = parallelism
self._timeout = timeout
self._max_wait_time = max_wait_time
self._timer_frequency = 20
self._timer_frequency = callback_frequency
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As best I can tell, self._timer_frequency is used but self.timer_frequency_ is not. Is there any reason to have both? Same goes for all of these duped fields

if max_items_to_buffer is None:
self._max_items_to_buffer = max(parallelism * 2, 10)
else:
Expand All @@ -112,9 +116,6 @@ def __init__(
AsyncWrapper._processing_elements[self._uuid] = {}
AsyncWrapper._items_in_buffer[self._uuid] = 0
self.max_wait_time = max_wait_time
self.timer_frequency_ = callback_frequency
self.parallelism_ = parallelism
self._next_time_to_fire = Timestamp.now() + Duration(seconds=5)
self._shared_handle = Shared()

@staticmethod
Expand Down Expand Up @@ -238,9 +239,9 @@ def schedule_item(self, element, ignore_buffer=False, *args, **kwargs):
**kwargs: keyword arguments that the wrapped dofn requires.
"""
done = False
sleep_time = 1
sleep_time = 0.01
total_sleep = 0
while not done:
while not done and total_sleep < self._timeout:
done = self.schedule_if_room(element, ignore_buffer, *args, **kwargs)
if not done:
sleep_time = min(self.max_wait_time, sleep_time * 2)
Expand All @@ -256,10 +257,12 @@ def schedule_item(self, element, ignore_buffer=False, *args, **kwargs):
total_sleep += sleep_time
sleep(sleep_time)

def next_time_to_fire(self):
def next_time_to_fire(self, key):
random.seed(key)
return (
floor((time() + self._timer_frequency) / self._timer_frequency) *
self._timer_frequency)
self._timer_frequency) + (
random.random() * self._timer_frequency)
Comment on lines +260 to +265
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like doing all of the work to find a round increment of _timer_frequency is wasted compute once you add the extra fuzziness of random.random() * self._timer_frequency since you're no longer on a round increment afterwards

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started just having keys setting a timer now + 10s. That doesn't work because as new work arrives the timer firing time keeps getting pushed out. ie. an element arrives at t=1, we want to check back on it at t=11 so we set the timer, but then an element arrives at t=9 and overwrites the timer to t=19.

Next setup was having this round increment firing time. so any message that arrives between t=0 and t=10 sets the timer for 0:10. That way the element at t=9 doesn't override the timer to t=19 but keeps it at t=10.

That works but means we see a spike of timers at t=10, t=20, t=30 etc. There isn't any reason the timers all need to fire at these round increments so this is attempting to add fuzzing per key (since timers are per key). Ideally this means that any 1 key has buckets 10s apart so the overwriting problem is fixed but also means that across multiple keys the buckets don't all fire at the same time. I believe this is what the random.seed(key) on line 260 is doing but correct me if I'm wrong.

Also, let me know if you know an easier way to obtain this pattern.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, makes sense. There may be a way to queue up per-key firing times, but that's a more substantial piece of work since Beam timers don't work like that themselves. The compute here is negligible so I'm not particularly worried about it, if we wanted to use some memory instead we could pick an offset per-key and store it, that would eliminate the spikes but keep a consistent cadence of each key firing at the desired frequency when in steady state


def accepting_items(self):
with AsyncWrapper._lock:
Expand Down Expand Up @@ -301,7 +304,7 @@ def process(
# Set a timer to fire on the next round increment of timer_frequency_. Note
# we do this so that each messages timer doesn't get overwritten by the
# next.
time_to_fire = self.next_time_to_fire()
time_to_fire = self.next_time_to_fire(element[0])
timer.set(time_to_fire)

# Don't output any elements. This will be done in commit_finished_items.
Expand Down Expand Up @@ -346,6 +349,7 @@ def commit_finished_items(
# from local state and cancel their futures.
to_remove = []
key = None
to_reschedule = []
if to_process_local:
key = str(to_process_local[0][0])
else:
Expand Down Expand Up @@ -387,9 +391,13 @@ def commit_finished_items(
'item %s found in processing state but not local state,'
' scheduling now',
x)
self.schedule_item(x, ignore_buffer=True)
to_reschedule.append(x)
items_rescheduled += 1

# Reschedule the items not under a lock
for x in to_reschedule:
self.schedule_item(x, ignore_buffer=False)

# Update processing state to remove elements we've finished
to_process.clear()
for x in to_process_local:
Expand All @@ -408,8 +416,8 @@ def commit_finished_items(
# If there are items not yet finished then set a timer to fire in the
# future.
self._next_time_to_fire = Timestamp.now() + Duration(seconds=5)
if items_not_yet_finished > 0:
time_to_fire = self.next_time_to_fire()
if items_in_processing_state > 0:
time_to_fire = self.next_time_to_fire(key)
timer.set(time_to_fire)

# Each result is a list. We want to combine them into a single
Expand Down
39 changes: 17 additions & 22 deletions sdks/python/apache_beam/transforms/async_dofn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,15 @@ def add_item(i):
self.assertEqual(async_dofn._max_items_to_buffer, 5)
self.check_items_in_buffer(async_dofn, 5)

# After 55 seconds all items should be finished (including those which were
# waiting on the buffer).
# Wait for all buffered items to finish.
self.wait_for_empty(async_dofn, 100)
# This will commit buffered items and add new items which didn't fit in the
# buffer.
result = async_dofn.commit_finished_items(fake_bag_state, fake_timer)

# Wait for the new buffered items to finish.
self.wait_for_empty(async_dofn, 100)
result.extend(async_dofn.commit_finished_items(fake_bag_state, fake_timer))
self.check_output(result, expected_output)
self.check_items_in_buffer(async_dofn, 0)

Expand Down Expand Up @@ -414,33 +419,23 @@ def add_item(i):
# Run for a while. Should be enough to start all items but not finish them
# all.
time.sleep(random.randint(30, 50))
# Commit some stuff
pre_crash_results = []
for i in range(0, 10):
pre_crash_results.append(
async_dofn.commit_finished_items(
bag_states['key' + str(i)], timers['key' + str(i)]))

# Wait for all items to at least make it into the buffer.
done = False
results = [[] for _ in range(0, 10)]
while not done:
time.sleep(10)
done = True
for future in futures:
if not future.done():
for i in range(0, 10):
results[i].extend(
async_dofn.commit_finished_items(
bag_states['key' + str(i)], timers['key' + str(i)]))
if not bag_states['key' + str(i)].items:
self.check_output(results[i], expected_outputs['key' + str(i)])
else:
done = False
break

# Wait for all items to finish.
self.wait_for_empty(async_dofn)
time.sleep(random.randint(10, 30))

for i in range(0, 10):
result = async_dofn.commit_finished_items(
bag_states['key' + str(i)], timers['key' + str(i)])
logging.info('pre_crash_results %s', pre_crash_results[i])
logging.info('result %s', result)
self.check_output(
pre_crash_results[i] + result, expected_outputs['key' + str(i)])
self.check_output(results[i], expected_outputs['key' + str(i)])
self.assertEqual(bag_states['key' + str(i)].items, [])


Expand Down
Loading