diff --git a/sdks/python/apache_beam/transforms/async_dofn.py b/sdks/python/apache_beam/transforms/async_dofn.py index d2fa90c85085..5e1c6d219f4b 100644 --- a/sdks/python/apache_beam/transforms/async_dofn.py +++ b/sdks/python/apache_beam/transforms/async_dofn.py @@ -77,6 +77,7 @@ def __init__( max_items_to_buffer=None, timeout=1, max_wait_time=0.5, + id_fn=None, ): """Wraps the sync_fn to create an asynchronous version. @@ -101,6 +102,8 @@ def __init__( locally before it goes in the queue of waiting work. max_wait_time: The maximum amount of sleep time while attempting to schedule an item. Used in testing to ensure timeouts are met. + id_fn: A function that returns a hashable object from an element. This + will be used to track items instead of the element's default hash. """ self._sync_fn = sync_fn self._uuid = uuid.uuid4().hex @@ -108,6 +111,7 @@ def __init__( self._timeout = timeout self._max_wait_time = max_wait_time self._timer_frequency = callback_frequency + self._id_fn = id_fn or (lambda x: x) if max_items_to_buffer is None: self._max_items_to_buffer = max(parallelism * 2, 10) else: @@ -205,7 +209,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): True if the item was scheduled False otherwise. """ with AsyncWrapper._lock: - if element in AsyncWrapper._processing_elements[self._uuid]: + element_id = self._id_fn(element[1]) + if element_id in AsyncWrapper._processing_elements[self._uuid]: logging.info('item %s already in processing elements', element) return True if self.accepting_items() or ignore_buffer: @@ -214,7 +219,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs): lambda: self.sync_fn_process(element, *args, **kwargs), ) result.add_done_callback(self.decrement_items_in_buffer) - AsyncWrapper._processing_elements[self._uuid][element] = result + AsyncWrapper._processing_elements[self._uuid][element_id] = ( + element, result) AsyncWrapper._items_in_buffer[self._uuid] += 1 return True else: @@ -345,9 +351,6 @@ def commit_finished_items( to_process_local = list(to_process.read()) - # For all elements that in local state but not processing state delete them - # from local state and cancel their futures. - to_remove = [] key = None to_reschedule = [] if to_process_local: @@ -362,27 +365,32 @@ def commit_finished_items( # given key. Skip items in processing_elements which are for a different # key. with AsyncWrapper._lock: - for x in AsyncWrapper._processing_elements[self._uuid]: - if x[0] == key and x not in to_process_local: + processing_elements = AsyncWrapper._processing_elements[self._uuid] + to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local} + to_remove_ids = [] + for element_id, (element, future) in processing_elements.items(): + if element[0] == key and element_id not in to_process_local_ids: items_cancelled += 1 - AsyncWrapper._processing_elements[self._uuid][x].cancel() - to_remove.append(x) + future.cancel() + to_remove_ids.append(element_id) logging.info( - 'cancelling item %s which is no longer in processing state', x) - for x in to_remove: - AsyncWrapper._processing_elements[self._uuid].pop(x) + 'cancelling item %s which is no longer in processing state', + element) + for element_id in to_remove_ids: + processing_elements.pop(element_id) # For all elements which have finished processing output their result. to_return = [] finished_items = [] for x in to_process_local: items_in_se_state += 1 - if x in AsyncWrapper._processing_elements[self._uuid]: - if AsyncWrapper._processing_elements[self._uuid][x].done(): - to_return.append( - AsyncWrapper._processing_elements[self._uuid][x].result()) + x_id = self._id_fn(x[1]) + if x_id in processing_elements: + _, future = processing_elements[x_id] + if future.done(): + to_return.append(future.result()) finished_items.append(x) - AsyncWrapper._processing_elements[self._uuid].pop(x) + processing_elements.pop(x_id) items_finished += 1 else: items_not_yet_finished += 1 diff --git a/sdks/python/apache_beam/transforms/async_dofn_test.py b/sdks/python/apache_beam/transforms/async_dofn_test.py index 7577e215d1c7..fe75de05ccd5 100644 --- a/sdks/python/apache_beam/transforms/async_dofn_test.py +++ b/sdks/python/apache_beam/transforms/async_dofn_test.py @@ -119,6 +119,40 @@ def check_items_in_buffer(self, async_dofn, expected_count): expected_count, ) + def test_custom_id_fn(self): + class CustomIdObject: + def __init__(self, element_id, value): + self.element_id = element_id + self.value = value + + def __hash__(self): + return hash(self.element_id) + + def __eq__(self, other): + return self.element_id == other.element_id + + dofn = BasicDofn() + async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id) + async_dofn.setup() + fake_bag_state = FakeBagState([]) + fake_timer = FakeTimer(0) + msg1 = ('key1', CustomIdObject(1, 'a')) + msg2 = ('key1', CustomIdObject(1, 'b')) + + result = async_dofn.process( + msg1, to_process=fake_bag_state, timer=fake_timer) + self.assertEqual(result, []) + + # The second message should be a no-op as it has the same id. + result = async_dofn.process( + msg2, to_process=fake_bag_state, timer=fake_timer) + self.assertEqual(result, []) + + self.wait_for_empty(async_dofn) + result = async_dofn.commit_finished_items(fake_bag_state, fake_timer) + self.check_output(result, [('key1', msg1[1])]) + self.assertEqual(fake_bag_state.items, []) + def test_basic(self): # Setup an async dofn and send a message in to process. dofn = BasicDofn()