Skip to content

Commit bcc9cd7

Browse files
authored
Fix Classes (#212)
- Support directly calling (synchronously or asynchronously) steps or transactions in configured classes. - Fix an issue where the arguments of an asynchronously called configured class method weren't properly saved. - Many new tests for configured classes.
1 parent 5c221ac commit bcc9cd7

File tree

5 files changed

+495
-115
lines changed

5 files changed

+495
-115
lines changed

dbos/_core.py

Lines changed: 37 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
get_or_create_func_info,
6464
get_temp_workflow_type,
6565
set_dbos_func_name,
66+
set_func_info,
6667
set_temp_workflow_type,
6768
)
6869
from ._roles import check_required_roles
@@ -286,6 +287,7 @@ def execute_workflow_by_id(
286287
ctx.request = (
287288
_serialization.deserialize(request) if request is not None else None
288289
)
290+
# If this function belongs to a configured class, add that class instance as its first argument
289291
if status["config_name"] is not None:
290292
config_name = status["config_name"]
291293
class_name = status["class_name"]
@@ -295,59 +297,30 @@ def execute_workflow_by_id(
295297
workflow_id,
296298
f"Cannot execute workflow because instance '{iname}' is not registered",
297299
)
298-
299-
if startNew:
300-
return start_workflow(
301-
dbos,
302-
wf_func,
303-
status["queue_name"],
304-
True,
305-
dbos._registry.instance_info_map[iname],
306-
*inputs["args"],
307-
**inputs["kwargs"],
308-
)
309-
else:
310-
with SetWorkflowID(workflow_id):
311-
return start_workflow(
312-
dbos,
313-
wf_func,
314-
status["queue_name"],
315-
True,
316-
dbos._registry.instance_info_map[iname],
317-
*inputs["args"],
318-
**inputs["kwargs"],
319-
)
300+
class_instance = dbos._registry.instance_info_map[iname]
301+
inputs["args"] = (class_instance,) + inputs["args"]
302+
# If this function is a class method, add that class object as its first argument
320303
elif status["class_name"] is not None:
321304
class_name = status["class_name"]
322305
if class_name not in dbos._registry.class_info_map:
323306
raise DBOSWorkflowFunctionNotFoundError(
324307
workflow_id,
325308
f"Cannot execute workflow because class '{class_name}' is not registered",
326309
)
310+
class_object = dbos._registry.class_info_map[class_name]
311+
inputs["args"] = (class_object,) + inputs["args"]
327312

328-
if startNew:
329-
return start_workflow(
330-
dbos,
331-
wf_func,
332-
status["queue_name"],
333-
True,
334-
dbos._registry.class_info_map[class_name],
335-
*inputs["args"],
336-
**inputs["kwargs"],
337-
)
338-
else:
339-
with SetWorkflowID(workflow_id):
340-
return start_workflow(
341-
dbos,
342-
wf_func,
343-
status["queue_name"],
344-
True,
345-
dbos._registry.class_info_map[class_name],
346-
*inputs["args"],
347-
**inputs["kwargs"],
348-
)
313+
if startNew:
314+
return start_workflow(
315+
dbos,
316+
wf_func,
317+
status["queue_name"],
318+
True,
319+
*inputs["args"],
320+
**inputs["kwargs"],
321+
)
349322
else:
350-
if startNew:
323+
with SetWorkflowID(workflow_id):
351324
return start_workflow(
352325
dbos,
353326
wf_func,
@@ -356,16 +329,6 @@ def execute_workflow_by_id(
356329
*inputs["args"],
357330
**inputs["kwargs"],
358331
)
359-
else:
360-
with SetWorkflowID(workflow_id):
361-
return start_workflow(
362-
dbos,
363-
wf_func,
364-
status["queue_name"],
365-
True,
366-
*inputs["args"],
367-
**inputs["kwargs"],
368-
)
369332

370333

371334
@overload
@@ -398,9 +361,12 @@ def start_workflow(
398361
*args: P.args,
399362
**kwargs: P.kwargs,
400363
) -> "WorkflowHandle[R]":
364+
# If the function has a class, add the class object as its first argument
401365
fself: Optional[object] = None
402366
if hasattr(func, "__self__"):
403367
fself = func.__self__
368+
if fself is not None:
369+
args = (fself,) + args # type: ignore
404370

405371
fi = get_func_info(func)
406372
if fi is None:
@@ -436,17 +402,13 @@ def start_workflow(
436402
new_wf_ctx.id_assigned_for_next_workflow = new_wf_ctx.assign_workflow_id()
437403
new_wf_id = new_wf_ctx.id_assigned_for_next_workflow
438404

439-
gin_args: Tuple[Any, ...] = args
440-
if fself is not None:
441-
gin_args = (fself,)
442-
443405
status = _init_workflow(
444406
dbos,
445407
new_wf_ctx,
446408
inputs=inputs,
447409
wf_name=get_dbos_func_name(func),
448-
class_name=get_dbos_class_name(fi, func, gin_args),
449-
config_name=get_config_name(fi, func, gin_args),
410+
class_name=get_dbos_class_name(fi, func, args),
411+
config_name=get_config_name(fi, func, args),
450412
temp_wf_type=get_temp_workflow_type(func),
451413
queue=queue_name,
452414
max_recovery_attempts=fi.max_recovery_attempts,
@@ -464,27 +426,15 @@ def start_workflow(
464426
)
465427
return WorkflowHandlePolling(new_wf_id, dbos)
466428

467-
if fself is not None:
468-
future = dbos._executor.submit(
469-
cast(Callable[..., R], _execute_workflow_wthread),
470-
dbos,
471-
status,
472-
func,
473-
new_wf_ctx,
474-
fself,
475-
*args,
476-
**kwargs,
477-
)
478-
else:
479-
future = dbos._executor.submit(
480-
cast(Callable[..., R], _execute_workflow_wthread),
481-
dbos,
482-
status,
483-
func,
484-
new_wf_ctx,
485-
*args,
486-
**kwargs,
487-
)
429+
future = dbos._executor.submit(
430+
cast(Callable[..., R], _execute_workflow_wthread),
431+
dbos,
432+
status,
433+
func,
434+
new_wf_ctx,
435+
*args,
436+
**kwargs,
437+
)
488438
return WorkflowHandleFuture(new_wf_id, future, dbos)
489439

490440

@@ -516,6 +466,8 @@ def workflow_wrapper(
516466

517467
@wraps(func)
518468
def wrapper(*args: Any, **kwargs: Any) -> R:
469+
fi = get_func_info(func)
470+
assert fi is not None
519471
if dbosreg.dbos is None:
520472
raise DBOSException(
521473
f"Function {func.__name__} invoked before DBOS initialized"
@@ -726,6 +678,8 @@ def temp_wf(*args: Any, **kwargs: Any) -> Any:
726678
set_temp_workflow_type(temp_wf, "transaction")
727679
dbosreg.register_wf_function(get_dbos_func_name(temp_wf), wrapped_wf)
728680
wrapper.__orig_func = temp_wf # type: ignore
681+
set_func_info(wrapped_wf, get_or_create_func_info(func))
682+
set_func_info(temp_wf, get_or_create_func_info(func))
729683

730684
return cast(F, wrapper)
731685

@@ -875,6 +829,8 @@ async def temp_wf_async(*args: Any, **kwargs: Any) -> Any:
875829
set_temp_workflow_type(temp_wf, "step")
876830
dbosreg.register_wf_function(get_dbos_func_name(temp_wf), wrapped_wf)
877831
wrapper.__orig_func = temp_wf # type: ignore
832+
set_func_info(wrapped_wf, get_or_create_func_info(func))
833+
set_func_info(temp_wf, get_or_create_func_info(func))
878834

879835
return cast(Callable[P, R], wrapper)
880836

dbos/_registrations.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from dataclasses import dataclass
23
from enum import Enum
34
from types import FunctionType
45
from typing import Any, Callable, List, Literal, Optional, Tuple, Type, cast
@@ -31,9 +32,9 @@ def set_temp_workflow_type(f: Any, name: TempWorkflowType) -> None:
3132
setattr(f, "dbos_temp_workflow_type", name)
3233

3334

35+
@dataclass
3436
class DBOSClassInfo:
35-
def __init__(self) -> None:
36-
self.def_required_roles: Optional[List[str]] = None
37+
def_required_roles: Optional[List[str]] = None
3738

3839

3940
class DBOSFuncType(Enum):
@@ -44,12 +45,12 @@ class DBOSFuncType(Enum):
4445
Instance = 4
4546

4647

48+
@dataclass
4749
class DBOSFuncInfo:
48-
def __init__(self) -> None:
49-
self.class_info: Optional[DBOSClassInfo] = None
50-
self.func_type: DBOSFuncType = DBOSFuncType.Unknown
51-
self.required_roles: Optional[List[str]] = None
52-
self.max_recovery_attempts = DEFAULT_MAX_RECOVERY_ATTEMPTS
50+
class_info: Optional[DBOSClassInfo] = None
51+
func_type: DBOSFuncType = DBOSFuncType.Unknown
52+
required_roles: Optional[List[str]] = None
53+
max_recovery_attempts: int = DEFAULT_MAX_RECOVERY_ATTEMPTS
5354

5455

5556
def get_or_create_class_info(cls: Type[Any]) -> DBOSClassInfo:
@@ -110,6 +111,10 @@ def get_or_create_func_info(func: Callable[..., Any]) -> DBOSFuncInfo:
110111
return fi
111112

112113

114+
def set_func_info(func: Callable[..., Any], fi: DBOSFuncInfo) -> None:
115+
setattr(func, "dbos_func_decorator_info", fi)
116+
117+
113118
def get_class_info(cls: Type[Any]) -> Optional[DBOSClassInfo]:
114119
if hasattr(cls, "dbos_class_decorator_info"):
115120
ci: DBOSClassInfo = getattr(cls, "dbos_class_decorator_info")

tests/conftest.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import glob
22
import os
33
import subprocess
4-
import warnings
4+
import time
55
from typing import Any, Generator, Tuple
66

77
import pytest
@@ -10,6 +10,7 @@
1010
from flask import Flask
1111

1212
from dbos import DBOS, ConfigFile
13+
from dbos._schemas.system_database import SystemSchema
1314

1415

1516
@pytest.fixture(scope="session")
@@ -149,3 +150,19 @@ def dbos_flask(
149150
def pytest_collection_modifyitems(session: Any, config: Any, items: Any) -> None:
150151
for item in items:
151152
item._nodeid = "\n" + item.nodeid + "\n"
153+
154+
155+
def queue_entries_are_cleaned_up(dbos: DBOS) -> bool:
156+
max_tries = 10
157+
success = False
158+
for i in range(max_tries):
159+
with dbos._sys_db.engine.begin() as c:
160+
query = sa.select(sa.func.count()).select_from(SystemSchema.workflow_queue)
161+
row = c.execute(query).fetchone()
162+
assert row is not None
163+
count = row[0]
164+
if count == 0:
165+
success = True
166+
break
167+
time.sleep(1)
168+
return success

0 commit comments

Comments
 (0)