Skip to content

Commit 4651b40

Browse files
authored
Set Event From Step (#483)
1 parent 205566d commit 4651b40

File tree

3 files changed

+68
-29
lines changed

3 files changed

+68
-29
lines changed

dbos/_core.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,21 +1278,24 @@ def recv(dbos: "DBOS", topic: Optional[str] = None, timeout_seconds: float = 60)
12781278
def set_event(dbos: "DBOS", key: str, value: Any) -> None:
12791279
cur_ctx = get_local_dbos_context()
12801280
if cur_ctx is not None:
1281-
# Must call it within a workflow
1282-
assert (
1283-
cur_ctx.is_workflow()
1284-
), "set_event() must be called from within a workflow"
1285-
attributes: TracedAttributes = {
1286-
"name": "set_event",
1287-
}
1288-
with EnterDBOSStep(attributes):
1289-
ctx = assert_current_dbos_context()
1290-
dbos._sys_db.set_event(
1291-
ctx.workflow_id, ctx.curr_step_function_id, key, value
1281+
if cur_ctx.is_workflow():
1282+
# If called from a workflow function, run as a step
1283+
attributes: TracedAttributes = {
1284+
"name": "set_event",
1285+
}
1286+
with EnterDBOSStep(attributes):
1287+
ctx = assert_current_dbos_context()
1288+
dbos._sys_db.set_event_from_workflow(
1289+
ctx.workflow_id, ctx.curr_step_function_id, key, value
1290+
)
1291+
elif cur_ctx.is_step():
1292+
dbos._sys_db.set_event_from_step(cur_ctx.workflow_id, key, value)
1293+
else:
1294+
raise DBOSException(
1295+
"set_event() must be called from within a workflow or step"
12921296
)
12931297
else:
1294-
# Cannot call it from outside of a workflow
1295-
raise DBOSException("set_event() must be called from within a workflow")
1298+
raise DBOSException("set_event() must be called from within a workflow or step")
12961299

12971300

12981301
def get_event(

dbos/_sys_db.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,7 @@ def sleep(
15241524
return duration
15251525

15261526
@db_retry()
1527-
def set_event(
1527+
def set_event_from_workflow(
15281528
self,
15291529
workflow_uuid: str,
15301530
function_id: int,
@@ -1566,6 +1566,26 @@ def set_event(
15661566
}
15671567
self._record_operation_result_txn(output, conn=c)
15681568

1569+
def set_event_from_step(
1570+
self,
1571+
workflow_uuid: str,
1572+
key: str,
1573+
message: Any,
1574+
) -> None:
1575+
with self.engine.begin() as c:
1576+
c.execute(
1577+
self.dialect.insert(SystemSchema.workflow_events)
1578+
.values(
1579+
workflow_uuid=workflow_uuid,
1580+
key=key,
1581+
value=_serialization.serialize(message),
1582+
)
1583+
.on_conflict_do_update(
1584+
index_elements=["workflow_uuid", "key"],
1585+
set_={"value": _serialization.serialize(message)},
1586+
)
1587+
)
1588+
15691589
def get_all_events(self, workflow_id: str) -> Dict[str, Any]:
15701590
"""
15711591
Get all events currently present for a workflow ID.

tests/test_dbos.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -965,41 +965,57 @@ def test_send_recv_workflow(topic: str) -> str:
965965
def test_set_get_events(dbos: DBOS) -> None:
966966
@DBOS.workflow()
967967
def test_setevent_workflow() -> None:
968-
dbos.set_event("key1", "value1")
969-
dbos.set_event("key2", "value2")
970-
dbos.set_event("key3", None)
968+
DBOS.set_event("key1", "value1")
969+
DBOS.set_event("key2", "value2")
970+
DBOS.set_event("key3", None)
971+
set_event_step()
972+
973+
@DBOS.step()
974+
def set_event_step() -> None:
975+
DBOS.set_event("key4", "value4")
971976

972977
@DBOS.workflow()
973978
def test_getevent_workflow(
974-
target_uuid: str, key: str, timeout_seconds: float = 10
979+
target_uuid: str, key: str, timeout: float = 0.0
975980
) -> Optional[str]:
976-
msg = dbos.get_event(target_uuid, key, timeout_seconds)
981+
msg = dbos.get_event(target_uuid, key, timeout)
977982
return str(msg) if msg is not None else None
978983

979-
wfuuid = str(uuid.uuid4())
980-
with SetWorkflowID(wfuuid):
984+
wfid = str(uuid.uuid4())
985+
with SetWorkflowID(wfid):
981986
test_setevent_workflow()
982-
with SetWorkflowID(wfuuid):
987+
with SetWorkflowID(wfid):
983988
test_setevent_workflow()
984989

985-
value1 = test_getevent_workflow(wfuuid, "key1")
990+
value1 = test_getevent_workflow(wfid, "key1")
986991
assert value1 == "value1"
987992

988-
value2 = test_getevent_workflow(wfuuid, "key2")
993+
value2 = test_getevent_workflow(wfid, "key2")
989994
assert value2 == "value2"
990995

991996
# Run getEvent outside of a workflow
992-
value1 = dbos.get_event(wfuuid, "key1")
997+
value1 = DBOS.get_event(wfid, "key1", 0)
993998
assert value1 == "value1"
994999

995-
value2 = dbos.get_event(wfuuid, "key2")
1000+
value2 = DBOS.get_event(wfid, "key2", 0)
9961001
assert value2 == "value2"
9971002

9981003
begin_time = time.time()
999-
value3 = test_getevent_workflow(wfuuid, "key3")
1004+
value3 = test_getevent_workflow(wfid, "key3")
10001005
assert value3 is None
1001-
duration = time.time() - begin_time
1002-
assert duration < 1 # None is from the event not from the timeout
1006+
1007+
value4 = DBOS.get_event(wfid, "key4", 0)
1008+
assert value4 == "value4"
1009+
1010+
steps = DBOS.list_workflow_steps(wfid)
1011+
assert len(steps) == 4
1012+
assert (
1013+
steps[0]["function_name"]
1014+
== steps[1]["function_name"]
1015+
== steps[2]["function_name"]
1016+
== "DBOS.setEvent"
1017+
)
1018+
assert steps[3]["function_name"] == set_event_step.__qualname__
10031019

10041020
# Test OAOO
10051021
timeout_uuid = str(uuid.uuid4())

0 commit comments

Comments
 (0)