Skip to content

Commit 13a88a2

Browse files
authored
Remove Temp Workflows for Direct Calls (#430)
Calling a step directly from outside a workflow now always results in it being called as a normal Python function. Steps are now only retried if called from within workflows. Steps can still be enqueued.
1 parent 67d1d9e commit 13a88a2

File tree

8 files changed

+59
-120
lines changed

8 files changed

+59
-120
lines changed

dbos/_core.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,27 +1156,15 @@ def check_existing_result() -> Union[NoResult, R]:
11561156

11571157
@wraps(func)
11581158
def wrapper(*args: Any, **kwargs: Any) -> Any:
1159-
rr: Optional[str] = check_required_roles(func, fi)
1160-
# Entering step is allowed:
1161-
# No DBOS, just call the original function directly
1162-
# In a step already, just call the original function directly.
1163-
# In a workflow (that is not in a step already)
1164-
# Not in a workflow (we will start the single op workflow)
1165-
if not dbosreg.dbos or not dbosreg.dbos._launched:
1166-
# Call the original function directly
1167-
return func(*args, **kwargs)
1159+
# If the step is called from a workflow, run it as a step.
1160+
# Otherwise, run it as a normal function.
11681161
ctx = get_local_dbos_context()
1169-
if ctx and ctx.is_step():
1170-
# Call the original function directly
1171-
return func(*args, **kwargs)
1172-
if ctx and ctx.is_within_workflow():
1173-
assert ctx.is_workflow(), "Steps must be called from within workflows"
1162+
if ctx and ctx.is_workflow():
1163+
rr: Optional[str] = check_required_roles(func, fi)
11741164
with DBOSAssumeRole(rr):
11751165
return invoke_step(*args, **kwargs)
11761166
else:
1177-
tempwf = dbosreg.workflow_info_map.get("<temp>." + step_name)
1178-
assert tempwf
1179-
return tempwf(*args, **kwargs)
1167+
return func(*args, **kwargs)
11801168

11811169
wrapper = (
11821170
_mark_coroutine(wrapper) if inspect.iscoroutinefunction(func) else wrapper # type: ignore

tests/test_async.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -329,30 +329,6 @@ async def test_async_tx() -> None:
329329
DBOS.destroy(destroy_registry=True)
330330

331331

332-
@pytest.mark.asyncio
333-
async def test_async_step_temp(dbos: DBOS) -> None:
334-
step_counter: int = 0
335-
336-
@DBOS.step()
337-
async def test_step(var: str) -> str:
338-
await asyncio.sleep(0.1)
339-
nonlocal step_counter
340-
step_counter += 1
341-
DBOS.logger.info("I'm test_step")
342-
return var + f"step{step_counter}"
343-
344-
wfuuid = f"test_async_step_temp-{time.time_ns()}"
345-
with SetWorkflowID(wfuuid):
346-
result = await test_step("alice")
347-
assert result == "alicestep1"
348-
349-
with SetWorkflowID(wfuuid):
350-
result = await test_step("alice")
351-
assert result == "alicestep1"
352-
353-
assert step_counter == 1
354-
355-
356332
@pytest.mark.asyncio
357333
async def test_start_workflow_async(dbos: DBOS) -> None:
358334
wf_counter: int = 0

tests/test_classdecorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def step(self, x: int) -> int:
574574
def call_step() -> None:
575575
with SetWorkflowID(wfid):
576576
nonlocal return_value
577-
return_value = inst.step(input)
577+
return_value = DBOS.start_workflow(inst.step, input).get_result()
578578

579579
thread = threading.Thread(target=call_step)
580580
thread.start()

tests/test_concurrency.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ def test_workflow() -> str:
5959
res = test_step()
6060
return res
6161

62-
def test_comm_thread(id: str) -> str:
63-
with SetWorkflowID(id):
64-
return test_step()
65-
6662
# Need to set isolation level to a lower one, otherwise it gets serialization error instead (we already handle it correctly by automatic retries).
6763
@DBOS.transaction(isolation_level="REPEATABLE READ")
6864
def test_transaction() -> str:
@@ -97,15 +93,6 @@ def test_txn_thread(id: str) -> str:
9793
assert wf_handle1.get_result() == wfuuid
9894
assert wf_handle2.get_result() == wfuuid
9995

100-
# Make sure temp workflows can handle conflicts as well.
101-
wfuuid = str(uuid.uuid4())
102-
with ThreadPoolExecutor(max_workers=2) as executor:
103-
future1 = executor.submit(test_comm_thread, wfuuid)
104-
future2 = executor.submit(test_comm_thread, wfuuid)
105-
106-
assert future1.result() == wfuuid
107-
assert future2.result() == wfuuid
108-
10996
# Make sure temp transactions can handle conflicts as well.
11097
wfuuid = str(uuid.uuid4())
11198
with ThreadPoolExecutor(max_workers=2) as executor:

tests/test_dbos.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -307,16 +307,12 @@ def call_step(var: str) -> str:
307307
assert res == "var"
308308

309309
wfs = dbos._sys_db.get_workflows(gwi)
310-
assert len(wfs) == 2
310+
assert len(wfs) == 1
311311

312312
wfi1 = dbos._sys_db.get_workflow_status(wfs[0].workflow_id)
313313
assert wfi1
314314
assert wfi1["name"].startswith("<temp>")
315315

316-
wfi2 = dbos._sys_db.get_workflow_status(wfs[1].workflow_id)
317-
assert wfi2
318-
assert wfi2["name"].startswith("<temp>")
319-
320316
assert txn_counter == 1
321317
assert step_counter == 1
322318

@@ -350,7 +346,7 @@ def test_step(var: str) -> str:
350346
def test_retried_step(var: str) -> str:
351347
nonlocal retried_step_counter
352348
retried_step_counter += 1
353-
raise Exception(var)
349+
raise ValueError(var)
354350

355351
with pytest.raises(Exception) as exc_info:
356352
test_transaction("tval")
@@ -360,12 +356,12 @@ def test_retried_step(var: str) -> str:
360356
test_step("cval")
361357
assert "cval" == str(exc_info.value)
362358

363-
with pytest.raises(DBOSMaxStepRetriesExceeded) as exc_info:
359+
with pytest.raises(ValueError) as exc_info:
364360
test_retried_step("rval")
365361

366362
assert txn_counter == 1
367363
assert step_counter == 1
368-
assert retried_step_counter == 3
364+
assert retried_step_counter == 1
369365

370366

371367
def test_recovery_workflow(dbos: DBOS) -> None:
@@ -1102,9 +1098,6 @@ def test_bad_wf4(var: str) -> str:
11021098
with pytest.raises(Exception) as exc_info:
11031099
test_ns_transaction("h")
11041100
assert "data item should not be a function" in str(exc_info.value)
1105-
with pytest.raises(Exception) as exc_info:
1106-
test_ns_step("f")
1107-
assert "data item should not be a function" in str(exc_info.value)
11081101
with pytest.raises(Exception) as exc_info:
11091102
test_ns_wf("g")
11101103
assert "data item should not be a function" in str(exc_info.value)
@@ -1645,22 +1638,14 @@ def workflow(x: int) -> int:
16451638
async def test_step_without_dbos(dbos: DBOS, config: DBOSConfig) -> None:
16461639
DBOS.destroy(destroy_registry=True)
16471640

1648-
is_dbos_active = False
1649-
16501641
@DBOS.step()
16511642
def step(x: int) -> int:
1652-
if is_dbos_active:
1653-
assert DBOS.workflow_id is not None
1654-
else:
1655-
assert DBOS.workflow_id is None
1643+
assert DBOS.workflow_id is None
16561644
return x
16571645

16581646
@DBOS.step()
16591647
async def async_step(x: int) -> int:
1660-
if is_dbos_active:
1661-
assert DBOS.workflow_id is not None
1662-
else:
1663-
assert DBOS.workflow_id is None
1648+
assert DBOS.workflow_id is None
16641649
return x
16651650

16661651
assert step(5) == 5
@@ -1672,7 +1657,30 @@ async def async_step(x: int) -> int:
16721657
assert await async_step(5) == 5
16731658

16741659
DBOS.launch()
1675-
is_dbos_active = True
16761660

16771661
assert step(5) == 5
16781662
assert await async_step(5) == 5
1663+
1664+
assert len(DBOS.list_workflows()) == 0
1665+
1666+
1667+
def test_nested_steps(dbos: DBOS) -> None:
1668+
1669+
@DBOS.step()
1670+
def outer_step() -> str:
1671+
return inner_step()
1672+
1673+
@DBOS.step()
1674+
def inner_step() -> str:
1675+
id = DBOS.workflow_id
1676+
assert id is not None
1677+
return id
1678+
1679+
@DBOS.workflow()
1680+
def workflow() -> str:
1681+
return outer_step()
1682+
1683+
id = workflow()
1684+
steps = DBOS.list_workflow_steps(id)
1685+
assert len(steps) == 1
1686+
assert steps[0]["function_name"] == outer_step.__qualname__

tests/test_failures.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -333,22 +333,20 @@ def enqueue_failing_step() -> None:
333333
error_message = f"Step {failing_step.__qualname__} has exceeded its maximum of {max_attempts} retries"
334334

335335
# Test calling the step directly
336-
with pytest.raises(DBOSMaxStepRetriesExceeded) as excinfo:
336+
with pytest.raises(Exception) as excinfo:
337337
failing_step()
338-
assert error_message in str(excinfo.value)
339-
assert step_counter == max_attempts
340-
assert len(excinfo.value.errors) == max_attempts
341-
for error in excinfo.value.errors:
342-
assert isinstance(error, Exception)
343-
assert error
344-
assert "fail" in str(error)
345338

346339
# Test calling the workflow
347340
step_counter = 0
348341
with pytest.raises(DBOSMaxStepRetriesExceeded) as excinfo:
349342
failing_workflow()
350343
assert error_message in str(excinfo.value)
351344
assert step_counter == max_attempts
345+
assert len(excinfo.value.errors) == max_attempts
346+
for error in excinfo.value.errors:
347+
assert isinstance(error, Exception)
348+
assert error
349+
assert "fail" in str(error)
352350

353351
# Test enqueueing the step
354352
step_counter = 0
@@ -399,7 +397,6 @@ def failing_workflow() -> None:
399397

400398
assert failing_workflow() == None
401399
step_counter = 0
402-
assert failing_step() == None
403400

404401

405402
def test_recovery_during_retries(dbos: DBOS) -> None:

tests/test_queue.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def test_step(var: str) -> str:
181181
handle = queue.enqueue(test_step, "abc")
182182
assert handle.get_result() == "abc1"
183183
with SetWorkflowID(wfid):
184-
assert test_step("abc") == "abc1"
184+
handle = queue.enqueue(test_step, "abc")
185+
assert handle.get_result() == "abc1"
185186
assert step_counter == 1
186187

187188

0 commit comments

Comments
 (0)