Skip to content

Commit cbfe3bf

Browse files
authored
DBOS Client (#265)
DBOS Client, but this time in Python!
1 parent d358c49 commit cbfe3bf

File tree

12 files changed

+760
-52
lines changed

12 files changed

+760
-52
lines changed

dbos/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from . import _error as error
2+
from ._client import DBOSClient, EnqueueOptions
23
from ._context import DBOSContextEnsure, DBOSContextSetAuth, SetWorkflowID
34
from ._dbos import DBOS, DBOSConfiguredInstance, WorkflowHandle
45
from ._dbos_config import ConfigFile, DBOSConfig, get_dbos_database_url, load_config
@@ -11,9 +12,11 @@
1112
"ConfigFile",
1213
"DBOSConfig",
1314
"DBOS",
15+
"DBOSClient",
1416
"DBOSConfiguredInstance",
1517
"DBOSContextEnsure",
1618
"DBOSContextSetAuth",
19+
"EnqueueOptions",
1720
"GetWorkflowsInput",
1821
"KafkaMessage",
1922
"SetWorkflowID",

dbos/_app_db.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,18 @@ class RecordedResult(TypedDict):
2727

2828
class ApplicationDatabase:
2929

30-
def __init__(self, config: ConfigFile, *, debug_mode: bool = False):
31-
self.config = config
30+
def __init__(self, database: DatabaseConfig, *, debug_mode: bool = False):
3231

33-
app_db_name = config["database"]["app_db_name"]
32+
app_db_name = database["app_db_name"]
3433

3534
# If the application database does not already exist, create it
3635
if not debug_mode:
3736
postgres_db_url = sa.URL.create(
3837
"postgresql+psycopg",
39-
username=config["database"]["username"],
40-
password=config["database"]["password"],
41-
host=config["database"]["hostname"],
42-
port=config["database"]["port"],
38+
username=database["username"],
39+
password=database["password"],
40+
host=database["hostname"],
41+
port=database["port"],
4342
database="postgres",
4443
)
4544
postgres_db_engine = sa.create_engine(postgres_db_url)
@@ -55,25 +54,25 @@ def __init__(self, config: ConfigFile, *, debug_mode: bool = False):
5554
# Create a connection pool for the application database
5655
app_db_url = sa.URL.create(
5756
"postgresql+psycopg",
58-
username=config["database"]["username"],
59-
password=config["database"]["password"],
60-
host=config["database"]["hostname"],
61-
port=config["database"]["port"],
57+
username=database["username"],
58+
password=database["password"],
59+
host=database["hostname"],
60+
port=database["port"],
6261
database=app_db_name,
6362
)
6463

6564
connect_args = {}
6665
if (
67-
"connectionTimeoutMillis" in config["database"]
68-
and config["database"]["connectionTimeoutMillis"]
66+
"connectionTimeoutMillis" in database
67+
and database["connectionTimeoutMillis"]
6968
):
7069
connect_args["connect_timeout"] = int(
71-
config["database"]["connectionTimeoutMillis"] / 1000
70+
database["connectionTimeoutMillis"] / 1000
7271
)
7372

7473
self.engine = sa.create_engine(
7574
app_db_url,
76-
pool_size=config["database"]["app_db_pool_size"],
75+
pool_size=database["app_db_pool_size"],
7776
max_overflow=0,
7877
pool_timeout=30,
7978
connect_args=connect_args,

dbos/_client.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import asyncio
2+
import sys
3+
import uuid
4+
from typing import Any, Generic, Optional, TypedDict, TypeVar
5+
6+
if sys.version_info < (3, 11):
7+
from typing_extensions import NotRequired
8+
else:
9+
from typing import NotRequired
10+
11+
from dbos import _serialization
12+
from dbos._dbos import WorkflowHandle, WorkflowHandleAsync
13+
from dbos._dbos_config import parse_database_url_to_dbconfig
14+
from dbos._error import DBOSNonExistentWorkflowError
15+
from dbos._registrations import DEFAULT_MAX_RECOVERY_ATTEMPTS
16+
from dbos._serialization import WorkflowInputs
17+
from dbos._sys_db import SystemDatabase, WorkflowStatusInternal, WorkflowStatusString
18+
from dbos._workflow_commands import WorkflowStatus, get_workflow
19+
20+
R = TypeVar("R", covariant=True) # A generic type for workflow return values
21+
22+
23+
class EnqueueOptions(TypedDict):
24+
workflow_name: str
25+
workflow_class_name: NotRequired[str]
26+
queue_name: str
27+
app_version: NotRequired[str]
28+
workflow_id: NotRequired[str]
29+
30+
31+
class WorkflowHandleClientPolling(Generic[R]):
32+
33+
def __init__(self, workflow_id: str, sys_db: SystemDatabase):
34+
self.workflow_id = workflow_id
35+
self._sys_db = sys_db
36+
37+
def get_workflow_id(self) -> str:
38+
return self.workflow_id
39+
40+
def get_result(self) -> R:
41+
res: R = self._sys_db.await_workflow_result(self.workflow_id)
42+
return res
43+
44+
def get_status(self) -> "WorkflowStatus":
45+
status = get_workflow(self._sys_db, self.workflow_id, True)
46+
if status is None:
47+
raise DBOSNonExistentWorkflowError(self.workflow_id)
48+
return status
49+
50+
51+
class WorkflowHandleClientAsyncPolling(Generic[R]):
52+
53+
def __init__(self, workflow_id: str, sys_db: SystemDatabase):
54+
self.workflow_id = workflow_id
55+
self._sys_db = sys_db
56+
57+
def get_workflow_id(self) -> str:
58+
return self.workflow_id
59+
60+
async def get_result(self) -> R:
61+
res: R = await asyncio.to_thread(
62+
self._sys_db.await_workflow_result, self.workflow_id
63+
)
64+
return res
65+
66+
async def get_status(self) -> "WorkflowStatus":
67+
status = await asyncio.to_thread(
68+
get_workflow, self._sys_db, self.workflow_id, True
69+
)
70+
if status is None:
71+
raise DBOSNonExistentWorkflowError(self.workflow_id)
72+
return status
73+
74+
75+
class DBOSClient:
76+
def __init__(self, database_url: str, *, system_database: Optional[str] = None):
77+
db_config = parse_database_url_to_dbconfig(database_url)
78+
if system_database is not None:
79+
db_config["sys_db_name"] = system_database
80+
self._sys_db = SystemDatabase(db_config)
81+
82+
def destroy(self) -> None:
83+
self._sys_db.destroy()
84+
85+
def _enqueue(self, options: EnqueueOptions, *args: Any, **kwargs: Any) -> str:
86+
workflow_name = options["workflow_name"]
87+
queue_name = options["queue_name"]
88+
89+
workflow_class_name = options.get("workflow_class_name")
90+
app_version = options.get("app_version")
91+
max_recovery_attempts = options.get("max_recovery_attempts")
92+
if max_recovery_attempts is None:
93+
max_recovery_attempts = DEFAULT_MAX_RECOVERY_ATTEMPTS
94+
workflow_id = options.get("workflow_id")
95+
if workflow_id is None:
96+
workflow_id = str(uuid.uuid4())
97+
98+
status: WorkflowStatusInternal = {
99+
"workflow_uuid": workflow_id,
100+
"status": WorkflowStatusString.ENQUEUED.value,
101+
"name": workflow_name,
102+
"class_name": workflow_class_name,
103+
"queue_name": queue_name,
104+
"app_version": app_version,
105+
"config_name": None,
106+
"authenticated_user": None,
107+
"assumed_role": None,
108+
"authenticated_roles": None,
109+
"request": None,
110+
"output": None,
111+
"error": None,
112+
"created_at": None,
113+
"updated_at": None,
114+
"executor_id": None,
115+
"recovery_attempts": None,
116+
"app_id": None,
117+
}
118+
119+
inputs: WorkflowInputs = {
120+
"args": args,
121+
"kwargs": kwargs,
122+
}
123+
124+
wf_status = self._sys_db.insert_workflow_status(status)
125+
self._sys_db.update_workflow_inputs(
126+
workflow_id, _serialization.serialize_args(inputs)
127+
)
128+
if wf_status == WorkflowStatusString.ENQUEUED.value:
129+
self._sys_db.enqueue(workflow_id, queue_name)
130+
return workflow_id
131+
132+
def enqueue(
133+
self, options: EnqueueOptions, *args: Any, **kwargs: Any
134+
) -> WorkflowHandle[R]:
135+
workflow_id = self._enqueue(options, *args, **kwargs)
136+
return WorkflowHandleClientPolling[R](workflow_id, self._sys_db)
137+
138+
async def enqueue_async(
139+
self, options: EnqueueOptions, *args: Any, **kwargs: Any
140+
) -> WorkflowHandleAsync[R]:
141+
workflow_id = await asyncio.to_thread(self._enqueue, options, *args, **kwargs)
142+
return WorkflowHandleClientAsyncPolling[R](workflow_id, self._sys_db)
143+
144+
def retrieve_workflow(self, workflow_id: str) -> WorkflowHandle[R]:
145+
status = get_workflow(self._sys_db, workflow_id, True)
146+
if status is None:
147+
raise DBOSNonExistentWorkflowError(workflow_id)
148+
return WorkflowHandleClientPolling[R](workflow_id, self._sys_db)
149+
150+
async def retrieve_workflow_async(self, workflow_id: str) -> WorkflowHandleAsync[R]:
151+
status = asyncio.to_thread(get_workflow, self._sys_db, workflow_id, True)
152+
if status is None:
153+
raise DBOSNonExistentWorkflowError(workflow_id)
154+
return WorkflowHandleClientAsyncPolling[R](workflow_id, self._sys_db)
155+
156+
def send(
157+
self,
158+
destination_id: str,
159+
message: Any,
160+
topic: Optional[str] = None,
161+
idempotency_key: Optional[str] = None,
162+
) -> None:
163+
idempotency_key = idempotency_key if idempotency_key else str(uuid.uuid4())
164+
status: WorkflowStatusInternal = {
165+
"workflow_uuid": f"{destination_id}-{idempotency_key}",
166+
"status": WorkflowStatusString.SUCCESS.value,
167+
"name": "temp_workflow-send-client",
168+
"class_name": None,
169+
"queue_name": None,
170+
"config_name": None,
171+
"authenticated_user": None,
172+
"assumed_role": None,
173+
"authenticated_roles": None,
174+
"request": None,
175+
"output": None,
176+
"error": None,
177+
"created_at": None,
178+
"updated_at": None,
179+
"executor_id": None,
180+
"recovery_attempts": None,
181+
"app_id": None,
182+
"app_version": None,
183+
}
184+
self._sys_db.insert_workflow_status(status)
185+
self._sys_db.send(status["workflow_uuid"], 0, destination_id, message, topic)
186+
187+
async def send_async(
188+
self,
189+
destination_id: str,
190+
message: Any,
191+
topic: Optional[str] = None,
192+
idempotency_key: Optional[str] = None,
193+
) -> None:
194+
return await asyncio.to_thread(
195+
self.send, destination_id, message, topic, idempotency_key
196+
)
197+
198+
def get_event(self, workflow_id: str, key: str, timeout_seconds: float = 60) -> Any:
199+
return self._sys_db.get_event(workflow_id, key, timeout_seconds)
200+
201+
async def get_event_async(
202+
self, workflow_id: str, key: str, timeout_seconds: float = 60
203+
) -> Any:
204+
return await asyncio.to_thread(
205+
self.get_event, workflow_id, key, timeout_seconds
206+
)

dbos/_dbos.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,12 @@ def _launch(self, *, debug_mode: bool = False) -> None:
445445
dbos_logger.info(f"Executor ID: {GlobalParams.executor_id}")
446446
dbos_logger.info(f"Application version: {GlobalParams.app_version}")
447447
self._executor_field = ThreadPoolExecutor(max_workers=64)
448-
self._sys_db_field = SystemDatabase(self.config, debug_mode=debug_mode)
449-
self._app_db_field = ApplicationDatabase(self.config, debug_mode=debug_mode)
448+
self._sys_db_field = SystemDatabase(
449+
self.config["database"], debug_mode=debug_mode
450+
)
451+
self._app_db_field = ApplicationDatabase(
452+
self.config["database"], debug_mode=debug_mode
453+
)
450454

451455
if debug_mode:
452456
return

dbos/_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
8585
for _, queue in dbos._registry.queue_info_map.items():
8686
try:
8787
wf_ids = dbos._sys_db.start_queued_workflows(
88-
queue, GlobalParams.executor_id
88+
queue, GlobalParams.executor_id, GlobalParams.app_version
8989
)
9090
for id in wf_ids:
9191
execute_workflow_by_id(dbos, id)

0 commit comments

Comments
 (0)