Skip to content

Commit 8c44a9e

Browse files
authored
Support custom id function in async_dofn (#36779)
* Allow for a custom id function other than the default hashing funciton. * fix formatting errors * Formatting Fix 2 * fix linter errors * change element_ to _
1 parent addc06e commit 8c44a9e

File tree

2 files changed

+59
-17
lines changed

2 files changed

+59
-17
lines changed

sdks/python/apache_beam/transforms/async_dofn.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
max_items_to_buffer=None,
7878
timeout=1,
7979
max_wait_time=0.5,
80+
id_fn=None,
8081
):
8182
"""Wraps the sync_fn to create an asynchronous version.
8283
@@ -101,13 +102,16 @@ def __init__(
101102
locally before it goes in the queue of waiting work.
102103
max_wait_time: The maximum amount of sleep time while attempting to
103104
schedule an item. Used in testing to ensure timeouts are met.
105+
id_fn: A function that returns a hashable object from an element. This
106+
will be used to track items instead of the element's default hash.
104107
"""
105108
self._sync_fn = sync_fn
106109
self._uuid = uuid.uuid4().hex
107110
self._parallelism = parallelism
108111
self._timeout = timeout
109112
self._max_wait_time = max_wait_time
110113
self._timer_frequency = callback_frequency
114+
self._id_fn = id_fn or (lambda x: x)
111115
if max_items_to_buffer is None:
112116
self._max_items_to_buffer = max(parallelism * 2, 10)
113117
else:
@@ -205,7 +209,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
205209
True if the item was scheduled False otherwise.
206210
"""
207211
with AsyncWrapper._lock:
208-
if element in AsyncWrapper._processing_elements[self._uuid]:
212+
element_id = self._id_fn(element[1])
213+
if element_id in AsyncWrapper._processing_elements[self._uuid]:
209214
logging.info('item %s already in processing elements', element)
210215
return True
211216
if self.accepting_items() or ignore_buffer:
@@ -214,7 +219,8 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
214219
lambda: self.sync_fn_process(element, *args, **kwargs),
215220
)
216221
result.add_done_callback(self.decrement_items_in_buffer)
217-
AsyncWrapper._processing_elements[self._uuid][element] = result
222+
AsyncWrapper._processing_elements[self._uuid][element_id] = (
223+
element, result)
218224
AsyncWrapper._items_in_buffer[self._uuid] += 1
219225
return True
220226
else:
@@ -345,9 +351,6 @@ def commit_finished_items(
345351

346352
to_process_local = list(to_process.read())
347353

348-
# For all elements that in local state but not processing state delete them
349-
# from local state and cancel their futures.
350-
to_remove = []
351354
key = None
352355
to_reschedule = []
353356
if to_process_local:
@@ -362,27 +365,32 @@ def commit_finished_items(
362365
# given key. Skip items in processing_elements which are for a different
363366
# key.
364367
with AsyncWrapper._lock:
365-
for x in AsyncWrapper._processing_elements[self._uuid]:
366-
if x[0] == key and x not in to_process_local:
368+
processing_elements = AsyncWrapper._processing_elements[self._uuid]
369+
to_process_local_ids = {self._id_fn(e[1]) for e in to_process_local}
370+
to_remove_ids = []
371+
for element_id, (element, future) in processing_elements.items():
372+
if element[0] == key and element_id not in to_process_local_ids:
367373
items_cancelled += 1
368-
AsyncWrapper._processing_elements[self._uuid][x].cancel()
369-
to_remove.append(x)
374+
future.cancel()
375+
to_remove_ids.append(element_id)
370376
logging.info(
371-
'cancelling item %s which is no longer in processing state', x)
372-
for x in to_remove:
373-
AsyncWrapper._processing_elements[self._uuid].pop(x)
377+
'cancelling item %s which is no longer in processing state',
378+
element)
379+
for element_id in to_remove_ids:
380+
processing_elements.pop(element_id)
374381

375382
# For all elements which have finished processing output their result.
376383
to_return = []
377384
finished_items = []
378385
for x in to_process_local:
379386
items_in_se_state += 1
380-
if x in AsyncWrapper._processing_elements[self._uuid]:
381-
if AsyncWrapper._processing_elements[self._uuid][x].done():
382-
to_return.append(
383-
AsyncWrapper._processing_elements[self._uuid][x].result())
387+
x_id = self._id_fn(x[1])
388+
if x_id in processing_elements:
389+
_, future = processing_elements[x_id]
390+
if future.done():
391+
to_return.append(future.result())
384392
finished_items.append(x)
385-
AsyncWrapper._processing_elements[self._uuid].pop(x)
393+
processing_elements.pop(x_id)
386394
items_finished += 1
387395
else:
388396
items_not_yet_finished += 1

sdks/python/apache_beam/transforms/async_dofn_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,40 @@ def check_items_in_buffer(self, async_dofn, expected_count):
119119
expected_count,
120120
)
121121

122+
def test_custom_id_fn(self):
123+
class CustomIdObject:
124+
def __init__(self, element_id, value):
125+
self.element_id = element_id
126+
self.value = value
127+
128+
def __hash__(self):
129+
return hash(self.element_id)
130+
131+
def __eq__(self, other):
132+
return self.element_id == other.element_id
133+
134+
dofn = BasicDofn()
135+
async_dofn = async_lib.AsyncWrapper(dofn, id_fn=lambda x: x.element_id)
136+
async_dofn.setup()
137+
fake_bag_state = FakeBagState([])
138+
fake_timer = FakeTimer(0)
139+
msg1 = ('key1', CustomIdObject(1, 'a'))
140+
msg2 = ('key1', CustomIdObject(1, 'b'))
141+
142+
result = async_dofn.process(
143+
msg1, to_process=fake_bag_state, timer=fake_timer)
144+
self.assertEqual(result, [])
145+
146+
# The second message should be a no-op as it has the same id.
147+
result = async_dofn.process(
148+
msg2, to_process=fake_bag_state, timer=fake_timer)
149+
self.assertEqual(result, [])
150+
151+
self.wait_for_empty(async_dofn)
152+
result = async_dofn.commit_finished_items(fake_bag_state, fake_timer)
153+
self.check_output(result, [('key1', msg1[1])])
154+
self.assertEqual(fake_bag_state.items, [])
155+
122156
def test_basic(self):
123157
# Setup an async dofn and send a message in to process.
124158
dofn = BasicDofn()

0 commit comments

Comments
 (0)