Skip to content

Commit 752569d

Browse files
authored
Support concurrent recv and get_event (#345)
This PR fixes a race condition when multiple workflow instances are waiting on `recv` or multiple callers are waiting on `get_event`. The solution is to maintain a thread safe map. - For `recv`, only one workflow instance should be waiting, because `recv` consumes messages. If a workflow is already waiting for `recv`, directly raise a `DBOSWorkflowConflictIDError` and wait for the existing workflow to finish. - For `get_event`, multiple callers can wait on the same event. In this case, we maintain a reference counter and only delete the condition variable when nobody is waiting. - Add unit tests for concurrent recv and get_event.
1 parent 327eb0a commit 752569d

File tree

2 files changed

+130
-15
lines changed

2 files changed

+130
-15
lines changed

dbos/_sys_db.py

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,47 @@ class StepInfo(TypedDict):
222222
_dbos_null_topic = "__null__topic__"
223223

224224

225+
class ConditionCount(TypedDict):
226+
condition: threading.Condition
227+
count: int
228+
229+
230+
class ThreadSafeConditionDict:
231+
def __init__(self) -> None:
232+
self._dict: Dict[str, ConditionCount] = {}
233+
self._lock = threading.Lock()
234+
235+
def get(self, key: str) -> Optional[threading.Condition]:
236+
with self._lock:
237+
if key not in self._dict:
238+
# Key does not exist, return None
239+
return None
240+
return self._dict[key]["condition"]
241+
242+
def set(
243+
self, key: str, value: threading.Condition
244+
) -> tuple[bool, threading.Condition]:
245+
with self._lock:
246+
if key in self._dict:
247+
# Key already exists, do not overwrite. Increment the wait count.
248+
cc = self._dict[key]
249+
cc["count"] += 1
250+
return False, cc["condition"]
251+
self._dict[key] = ConditionCount(condition=value, count=1)
252+
return True, value
253+
254+
def pop(self, key: str) -> None:
255+
with self._lock:
256+
if key in self._dict:
257+
cc = self._dict[key]
258+
cc["count"] -= 1
259+
if cc["count"] == 0:
260+
# No more threads waiting on this condition, remove it
261+
del self._dict[key]
262+
else:
263+
dbos_logger.warning(f"Key {key} not found in condition dictionary.")
264+
265+
225266
class SystemDatabase:
226267

227268
def __init__(
@@ -248,8 +289,8 @@ def __init__(
248289
self._engine_kwargs = engine_kwargs
249290

250291
self.notification_conn: Optional[psycopg.connection.Connection] = None
251-
self.notifications_map: Dict[str, threading.Condition] = {}
252-
self.workflow_events_map: Dict[str, threading.Condition] = {}
292+
self.notifications_map = ThreadSafeConditionDict()
293+
self.workflow_events_map = ThreadSafeConditionDict()
253294

254295
# Now we can run background processes
255296
self._run_background_processes = True
@@ -1288,7 +1329,12 @@ def recv(
12881329
condition = threading.Condition()
12891330
# Must acquire first before adding to the map. Otherwise, the notification listener may notify it before the condition is acquired and waited.
12901331
condition.acquire()
1291-
self.notifications_map[payload] = condition
1332+
success, _ = self.notifications_map.set(payload, condition)
1333+
if not success:
1334+
# This should not happen, but if it does, it means the workflow is executed concurrently.
1335+
condition.release()
1336+
self.notifications_map.pop(payload)
1337+
raise DBOSWorkflowConflictIDError(workflow_uuid)
12921338

12931339
# Check if the key is already in the database. If not, wait for the notification.
12941340
init_recv: Sequence[Any]
@@ -1381,23 +1427,23 @@ def _notification_listener(self) -> None:
13811427
f"Received notification on channel: {channel}, payload: {notify.payload}"
13821428
)
13831429
if channel == "dbos_notifications_channel":
1384-
if (
1385-
notify.payload
1386-
and notify.payload in self.notifications_map
1387-
):
1388-
condition = self.notifications_map[notify.payload]
1430+
if notify.payload:
1431+
condition = self.notifications_map.get(notify.payload)
1432+
if condition is None:
1433+
# No condition found for this payload
1434+
continue
13891435
condition.acquire()
13901436
condition.notify_all()
13911437
condition.release()
13921438
dbos_logger.debug(
13931439
f"Signaled notifications condition for {notify.payload}"
13941440
)
13951441
elif channel == "dbos_workflow_events_channel":
1396-
if (
1397-
notify.payload
1398-
and notify.payload in self.workflow_events_map
1399-
):
1400-
condition = self.workflow_events_map[notify.payload]
1442+
if notify.payload:
1443+
condition = self.workflow_events_map.get(notify.payload)
1444+
if condition is None:
1445+
# No condition found for this payload
1446+
continue
14011447
condition.acquire()
14021448
condition.notify_all()
14031449
condition.release()
@@ -1535,8 +1581,13 @@ def get_event(
15351581

15361582
payload = f"{target_uuid}::{key}"
15371583
condition = threading.Condition()
1538-
self.workflow_events_map[payload] = condition
15391584
condition.acquire()
1585+
success, existing_condition = self.workflow_events_map.set(payload, condition)
1586+
if not success:
1587+
# Wait on the existing condition
1588+
condition.release()
1589+
condition = existing_condition
1590+
condition.acquire()
15401591

15411592
# Check if the key is already in the database. If not, wait for the notification.
15421593
init_recv: Sequence[Any]

tests/test_concurrency.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
import uuid
44
from concurrent.futures import Future, ThreadPoolExecutor
5-
from typing import Tuple
5+
from typing import Tuple, cast
66

77
from sqlalchemy import text
88

@@ -108,3 +108,67 @@ def test_txn_thread(id: str) -> str:
108108

109109
assert future1.result() == wfuuid
110110
assert future2.result() == wfuuid
111+
112+
113+
def test_concurrent_recv(dbos: DBOS) -> None:
114+
condition = threading.Condition()
115+
counter = 0
116+
117+
@DBOS.workflow()
118+
def test_workflow(topic: str) -> str:
119+
nonlocal counter
120+
condition.acquire()
121+
counter += 1
122+
if counter % 2 == 1:
123+
# Wait for the other one to notify
124+
condition.wait()
125+
else:
126+
# Notify the other one
127+
condition.notify()
128+
condition.release()
129+
m = cast(str, DBOS.recv(topic, 5))
130+
return m
131+
132+
def test_thread(id: str, topic: str) -> str:
133+
with SetWorkflowID(id):
134+
return test_workflow(topic)
135+
136+
wfuuid = str(uuid.uuid4())
137+
topic = "test_topic"
138+
with ThreadPoolExecutor(max_workers=2) as executor:
139+
future1 = executor.submit(test_thread, wfuuid, topic)
140+
future2 = executor.submit(test_thread, wfuuid, topic)
141+
142+
expected_message = "test message"
143+
DBOS.send(wfuuid, expected_message, topic)
144+
# Both should return the same message
145+
assert future1.result() == future2.result()
146+
assert future1.result() == expected_message
147+
# Make sure the notification map is empty
148+
assert not dbos._sys_db.notifications_map._dict
149+
150+
151+
def test_concurrent_getevent(dbos: DBOS) -> None:
152+
@DBOS.workflow()
153+
def test_workflow(event_name: str, value: str) -> str:
154+
DBOS.set_event(event_name, value)
155+
return value
156+
157+
def test_thread(id: str, event_name: str) -> str:
158+
return cast(str, DBOS.get_event(id, event_name, 5))
159+
160+
wfuuid = str(uuid.uuid4())
161+
event_name = "test_event"
162+
with ThreadPoolExecutor(max_workers=2) as executor:
163+
future1 = executor.submit(test_thread, wfuuid, event_name)
164+
future2 = executor.submit(test_thread, wfuuid, event_name)
165+
166+
expected_message = "test message"
167+
with SetWorkflowID(wfuuid):
168+
test_workflow(event_name, expected_message)
169+
170+
# Both should return the same message
171+
assert future1.result() == future2.result()
172+
assert future1.result() == expected_message
173+
# Make sure the event map is empty
174+
assert not dbos._sys_db.workflow_events_map._dict

0 commit comments

Comments
 (0)