Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions sdks/python/apache_beam/transforms/async_dofn.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(
max_items_to_buffer=None,
timeout=1,
max_wait_time=0.5,
id_fn=None,
):
"""Wraps the sync_fn to create an asynchronous version.

Expand All @@ -101,13 +102,16 @@ def __init__(
locally before it goes in the queue of waiting work.
max_wait_time: The maximum amount of sleep time while attempting to
schedule an item. Used in testing to ensure timeouts are met.
id_fn: A function that returns a hashable object from an element. This
will be used to track items instead of the element's default hash.
"""
self._sync_fn = sync_fn
self._uuid = uuid.uuid4().hex
self._parallelism = parallelism
self._timeout = timeout
self._max_wait_time = max_wait_time
self._timer_frequency = callback_frequency
self._id_fn = id_fn or (lambda x: x)
if max_items_to_buffer is None:
self._max_items_to_buffer = max(parallelism * 2, 10)
else:
Expand Down Expand Up @@ -205,7 +209,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
True if the item was scheduled False otherwise.
"""
with AsyncWrapper._lock:
if element in AsyncWrapper._processing_elements[self._uuid]:
element_id = self._id_fn(element[1])
if element_id in AsyncWrapper._processing_elements[self._uuid]:
logging.info('item %s already in processing elements', element)
return True
if self.accepting_items() or ignore_buffer:
Expand All @@ -214,7 +219,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
lambda: self.sync_fn_process(element, *args, **kwargs),
)
result.add_done_callback(self.decrement_items_in_buffer)
AsyncWrapper._processing_elements[self._uuid][element] = result
AsyncWrapper._processing_elements[self._uuid][element_id] = (
element, result)
AsyncWrapper._items_in_buffer[self._uuid] += 1
return True
else:
Expand Down Expand Up @@ -345,9 +351,6 @@ def commit_finished_items(

to_process_local = list(to_process.read())

# For all elements that in local state but not processing state delete them
# from local state and cancel their futures.
to_remove = []
key = None
to_reschedule = []
if to_process_local:
Expand All @@ -362,27 +365,32 @@ def commit_finished_items(
# given key. Skip items in processing_elements which are for a different
# key.
with AsyncWrapper._lock:
for x in AsyncWrapper._processing_elements[self._uuid]:
if x[0] == key and x not in to_process_local:
processing_elements = AsyncWrapper._processing_elements[self._uuid]
to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local}
to_remove_ids = []
for element_id, (element, future) in processing_elements.items():
if element[0] == key and element_id not in to_process_local_ids:
items_cancelled += 1
AsyncWrapper._processing_elements[self._uuid][x].cancel()
to_remove.append(x)
future.cancel()
to_remove_ids.append(element_id)
logging.info(
'cancelling item %s which is no longer in processing state', x)
for x in to_remove:
AsyncWrapper._processing_elements[self._uuid].pop(x)
'cancelling item %s which is no longer in processing state',
element)
for element_id in to_remove_ids:
processing_elements.pop(element_id)

# For all elements which have finished processing output their result.
to_return = []
finished_items = []
for x in to_process_local:
items_in_se_state += 1
if x in AsyncWrapper._processing_elements[self._uuid]:
if AsyncWrapper._processing_elements[self._uuid][x].done():
to_return.append(
AsyncWrapper._processing_elements[self._uuid][x].result())
x_id = self._id_fn(x[1])
if x_id in processing_elements:
_, future = processing_elements[x_id]
if future.done():
to_return.append(future.result())
finished_items.append(x)
AsyncWrapper._processing_elements[self._uuid].pop(x)
processing_elements.pop(x_id)
items_finished += 1
else:
items_not_yet_finished += 1
Expand Down
34 changes: 34 additions & 0 deletions sdks/python/apache_beam/transforms/async_dofn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,40 @@ def check_items_in_buffer(self, async_dofn, expected_count):
expected_count,
)

def test_custom_id_fn(self):
  """Checks that a custom ``id_fn`` is used to de-duplicate elements.

  Two messages are sent for the same key. Their payloads carry different
  ``value`` fields but the same ``element_id``, and ``id_fn`` returns that
  ``element_id``. The test expects a single output once processing
  finishes and an empty bag state afterwards.
  """
  class CustomIdObject:
    """Payload whose equality and hash depend only on ``element_id``."""
    def __init__(self, element_id, value):
      self.element_id = element_id
      self.value = value

    def __hash__(self):
      return hash(self.element_id)

    def __eq__(self, other):
      # Let Python fall back to the other operand (and ultimately to
      # ``False``) instead of raising AttributeError when compared with an
      # unrelated type.
      if not isinstance(other, CustomIdObject):
        return NotImplemented
      return self.element_id == other.element_id

    def __repr__(self):
      # Readable output for log messages and assertion failures.
      return 'CustomIdObject(element_id=%r, value=%r)' % (
          self.element_id, self.value)

  # NOTE(review): CustomIdObject's own __eq__/__hash__ are also based on
  # element_id, so msg1 and msg2 already compare equal without id_fn. An
  # object with default equality may isolate id_fn better -- TODO confirm.
  dofn = BasicDofn()
  async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id)
  async_dofn.setup()
  fake_bag_state = FakeBagState([])
  fake_timer = FakeTimer(0)
  msg1 = ('key1', CustomIdObject(1, 'a'))
  msg2 = ('key1', CustomIdObject(1, 'b'))

  result = async_dofn.process(
      msg1, to_process=fake_bag_state, timer=fake_timer)
  self.assertEqual(result, [])

  # The second message should be a no-op as it has the same id.
  result = async_dofn.process(
      msg2, to_process=fake_bag_state, timer=fake_timer)
  self.assertEqual(result, [])

  self.wait_for_empty(async_dofn)
  result = async_dofn.commit_finished_items(fake_bag_state, fake_timer)
  self.check_output(result, [('key1', msg1[1])])
  self.assertEqual(fake_bag_state.items, [])

def test_basic(self):
# Setup an async dofn and send a message in to process.
dofn = BasicDofn()
Expand Down
Loading