Skip to content

Commit e96ade6

Browse files
authored
Check for conflicting decorators (#220)
1 parent 9945afc commit e96ade6

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

dbos/_core.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def decorate_workflow(
524524
) -> Callable[[Callable[P, R]], Callable[P, R]]:
525525
def _workflow_decorator(func: Callable[P, R]) -> Callable[P, R]:
526526
wrapped_func = workflow_wrapper(reg, func, max_recovery_attempts)
527-
reg.register_wf_function(func.__qualname__, wrapped_func)
527+
reg.register_wf_function(func.__qualname__, wrapped_func, "workflow")
528528
return wrapped_func
529529

530530
return _workflow_decorator
@@ -676,7 +676,9 @@ def temp_wf(*args: Any, **kwargs: Any) -> Any:
676676
wrapped_wf = workflow_wrapper(dbosreg, temp_wf)
677677
set_dbos_func_name(temp_wf, "<temp>." + func.__qualname__)
678678
set_temp_workflow_type(temp_wf, "transaction")
679-
dbosreg.register_wf_function(get_dbos_func_name(temp_wf), wrapped_wf)
679+
dbosreg.register_wf_function(
680+
get_dbos_func_name(temp_wf), wrapped_wf, "transaction"
681+
)
680682
wrapper.__orig_func = temp_wf # type: ignore
681683
set_func_info(wrapped_wf, get_or_create_func_info(func))
682684
set_func_info(temp_wf, get_or_create_func_info(func))
@@ -827,7 +829,7 @@ async def temp_wf_async(*args: Any, **kwargs: Any) -> Any:
827829
wrapped_wf = workflow_wrapper(dbosreg, temp_wf)
828830
set_dbos_func_name(temp_wf, "<temp>." + func.__qualname__)
829831
set_temp_workflow_type(temp_wf, "step")
830-
dbosreg.register_wf_function(get_dbos_func_name(temp_wf), wrapped_wf)
832+
dbosreg.register_wf_function(get_dbos_func_name(temp_wf), wrapped_wf, "step")
831833
wrapper.__orig_func = temp_wf # type: ignore
832834
set_func_info(wrapped_wf, get_or_create_func_info(func))
833835
set_func_info(temp_wf, get_or_create_func_info(func))

dbos/_dbos.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@
8585
get_local_dbos_context,
8686
)
8787
from ._dbos_config import ConfigFile, load_config, set_env_vars
88-
from ._error import DBOSException, DBOSNonExistentWorkflowError
88+
from ._error import (
89+
DBOSConflictingRegistrationError,
90+
DBOSException,
91+
DBOSNonExistentWorkflowError,
92+
)
8993
from ._logger import add_otlp_to_all_loggers, dbos_logger
9094
from ._sys_db import SystemDatabase
9195

@@ -144,14 +148,19 @@ def _get_or_create_dbos_registry() -> DBOSRegistry:
144148
class DBOSRegistry:
145149
def __init__(self) -> None:
146150
self.workflow_info_map: dict[str, Workflow[..., Any]] = {}
151+
self.function_type_map: dict[str, str] = {}
147152
self.class_info_map: dict[str, type] = {}
148153
self.instance_info_map: dict[str, object] = {}
149154
self.queue_info_map: dict[str, Queue] = {}
150155
self.pollers: list[RegisteredJob] = []
151156
self.dbos: Optional[DBOS] = None
152157
self.config: Optional[ConfigFile] = None
153158

154-
def register_wf_function(self, name: str, wrapped_func: F) -> None:
159+
def register_wf_function(self, name: str, wrapped_func: F, functype: str) -> None:
160+
if name in self.function_type_map:
161+
if self.function_type_map[name] != functype:
162+
raise DBOSConflictingRegistrationError(name)
163+
self.function_type_map[name] = functype
155164
self.workflow_info_map[name] = wrapped_func
156165

157166
def register_class(self, cls: type, ci: DBOSClassInfo) -> None:
@@ -324,7 +333,7 @@ def send_temp_workflow(
324333
temp_send_wf = workflow_wrapper(self._registry, send_temp_workflow)
325334
set_dbos_func_name(send_temp_workflow, TEMP_SEND_WF_NAME)
326335
set_temp_workflow_type(send_temp_workflow, "send")
327-
self._registry.register_wf_function(TEMP_SEND_WF_NAME, temp_send_wf)
336+
self._registry.register_wf_function(TEMP_SEND_WF_NAME, temp_send_wf, "send")
328337

329338
for handler in dbos_logger.handlers:
330339
handler.flush()

dbos/_error.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class DBOSErrorCode(Enum):
3636
MaxStepRetriesExceeded = 7
3737
NotAuthorized = 8
3838
ConflictingWorkflowError = 9
39+
ConflictingRegistrationError = 25
3940

4041

4142
class DBOSWorkflowConflictIDError(DBOSException):
@@ -127,3 +128,13 @@ def __init__(self) -> None:
127128
"Step reached maximum retries.",
128129
dbos_error_code=DBOSErrorCode.MaxStepRetriesExceeded.value,
129130
)
131+
132+
133+
class DBOSConflictingRegistrationError(DBOSException):
134+
"""Exception raised when conflicting decorators are applied to the same function."""
135+
136+
def __init__(self, name: str) -> None:
137+
super().__init__(
138+
f"Operation (Name: {name}) is already registered with a conflicting function type",
139+
dbos_error_code=DBOSErrorCode.ConflictingRegistrationError.value,
140+
)

tests/test_dbos.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# Private API because this is a test
1818
from dbos._context import assert_current_dbos_context, get_local_dbos_context
19-
from dbos._error import DBOSMaxStepRetriesExceeded
19+
from dbos._error import DBOSConflictingRegistrationError, DBOSMaxStepRetriesExceeded
2020
from dbos._schemas.system_database import SystemSchema
2121
from dbos._sys_db import GetWorkflowsInput
2222

@@ -1208,6 +1208,20 @@ def test_workflow(var: str) -> str:
12081208
assert test_workflow(var) == var
12091209

12101210

1211+
def test_double_decoration(dbos: DBOS) -> None:
1212+
with pytest.raises(
1213+
DBOSConflictingRegistrationError,
1214+
match="is already registered with a conflicting function type",
1215+
):
1216+
1217+
@DBOS.step()
1218+
@DBOS.transaction()
1219+
def my_function() -> None:
1220+
pass
1221+
1222+
my_function()
1223+
1224+
12111225
def test_app_version(config: ConfigFile) -> None:
12121226
def is_hex(s: str) -> bool:
12131227
return all(c in "0123456789abcdefABCDEF" for c in s)

0 commit comments

Comments
 (0)