1+ from abc import ABC , abstractmethod
12from typing import Any , Dict , List , Optional , TypedDict
23
4+ import psycopg
35import sqlalchemy as sa
4- import sqlalchemy .dialects .postgresql as pg
56from sqlalchemy import inspect , text
67from sqlalchemy .exc import DBAPIError
78from sqlalchemy .orm import Session , sessionmaker
89
10+ from dbos ._migration import get_sqlite_timestamp_expr
11+
912from . import _serialization
1013from ._error import DBOSUnexpectedStepError , DBOSWorkflowConflictIDError
1114from ._logger import dbos_logger
@@ -29,7 +32,7 @@ class RecordedResult(TypedDict):
2932 error : Optional [str ] # JSON (jsonpickle)
3033
3134
32- class ApplicationDatabase :
35+ class ApplicationDatabase ( ABC ) :
3336
3437 def __init__ (
3538 self ,
@@ -38,95 +41,37 @@ def __init__(
3841 engine_kwargs : Dict [str , Any ],
3942 debug_mode : bool = False ,
4043 ):
41- app_db_url = sa .make_url (database_url ).set (drivername = "postgresql+psycopg" )
42-
43- if engine_kwargs is None :
44- engine_kwargs = {}
45-
46- self .engine = sa .create_engine (
47- app_db_url ,
48- ** engine_kwargs ,
49- )
44+ self .engine = self ._create_engine (database_url , engine_kwargs )
5045 self ._engine_kwargs = engine_kwargs
5146 self .sessionmaker = sessionmaker (bind = self .engine )
5247 self .debug_mode = debug_mode
5348
54- def run_migrations (self ) -> None :
55- if self .debug_mode :
56- dbos_logger .warning (
57- "Application database migrations are skipped in debug mode."
58- )
59- return
60- # Check if the database exists
61- app_db_url = self .engine .url
62- postgres_db_engine = sa .create_engine (
63- app_db_url .set (database = "postgres" ),
64- ** self ._engine_kwargs ,
65- )
66- with postgres_db_engine .connect () as conn :
67- conn .execution_options (isolation_level = "AUTOCOMMIT" )
68- if not conn .execute (
69- sa .text ("SELECT 1 FROM pg_database WHERE datname=:db_name" ),
70- parameters = {"db_name" : app_db_url .database },
71- ).scalar ():
72- conn .execute (sa .text (f"CREATE DATABASE { app_db_url .database } " ))
73- postgres_db_engine .dispose ()
74-
75- # Create the dbos schema and transaction_outputs table in the application database
76- with self .engine .begin () as conn :
77- # Check if schema exists first
78- schema_exists = conn .execute (
79- sa .text (
80- "SELECT 1 FROM information_schema.schemata WHERE schema_name = :schema_name"
81- ),
82- parameters = {"schema_name" : ApplicationSchema .schema },
83- ).scalar ()
84-
85- if not schema_exists :
86- schema_creation_query = sa .text (
87- f"CREATE SCHEMA { ApplicationSchema .schema } "
88- )
89- conn .execute (schema_creation_query )
90-
91- inspector = inspect (self .engine )
92- if not inspector .has_table (
93- "transaction_outputs" , schema = ApplicationSchema .schema
94- ):
95- ApplicationSchema .metadata_obj .create_all (self .engine )
96- else :
97- columns = inspector .get_columns (
98- "transaction_outputs" , schema = ApplicationSchema .schema
99- )
100- column_names = [col ["name" ] for col in columns ]
49+ @abstractmethod
50+ def _create_engine (
51+ self , database_url : str , engine_kwargs : Dict [str , Any ]
52+ ) -> sa .Engine :
53+ """Create a database engine specific to the database type."""
54+ pass
10155
102- if "function_name" not in column_names :
103- # Column missing, alter table to add it
104- with self .engine .connect () as conn :
105- conn .execute (
106- text (
107- f"""
108- ALTER TABLE { ApplicationSchema .schema } .transaction_outputs
109- ADD COLUMN function_name TEXT NOT NULL DEFAULT '';
110- """
111- )
112- )
113- conn .commit ()
56+ @abstractmethod
57+ def run_migrations (self ) -> None :
58+ """Run database migrations specific to the database type."""
59+ pass
11460
11561 def destroy (self ) -> None :
11662 self .engine .dispose ()
11763
118- @staticmethod
11964 def record_transaction_output (
120- session : Session , output : TransactionResultInternal
65+ self , session : Session , output : TransactionResultInternal
12166 ) -> None :
12267 try :
12368 session .execute (
124- pg .insert (ApplicationSchema .transaction_outputs ).values (
69+ sa .insert (ApplicationSchema .transaction_outputs ).values (
12570 workflow_uuid = output ["workflow_uuid" ],
12671 function_id = output ["function_id" ],
12772 output = output ["output" ],
12873 error = None ,
129- txn_id = sa . text ( "(select pg_current_xact_id_if_assigned()::text)" ) ,
74+ txn_id = "" ,
13075 txn_snapshot = output ["txn_snapshot" ],
13176 executor_id = (
13277 output ["executor_id" ] if output ["executor_id" ] else None
@@ -135,7 +80,7 @@ def record_transaction_output(
13580 )
13681 )
13782 except DBAPIError as dbapi_error :
138- if dbapi_error . orig . sqlstate == "23505" : # type: ignore
83+ if self . _is_unique_constraint_violation ( dbapi_error ):
13984 raise DBOSWorkflowConflictIDError (output ["workflow_uuid" ])
14085 raise
14186
@@ -145,14 +90,12 @@ def record_transaction_error(self, output: TransactionResultInternal) -> None:
14590 try :
14691 with self .engine .begin () as conn :
14792 conn .execute (
148- pg .insert (ApplicationSchema .transaction_outputs ).values (
93+ sa .insert (ApplicationSchema .transaction_outputs ).values (
14994 workflow_uuid = output ["workflow_uuid" ],
15095 function_id = output ["function_id" ],
15196 output = None ,
15297 error = output ["error" ],
153- txn_id = sa .text (
154- "(select pg_current_xact_id_if_assigned()::text)"
155- ),
98+ txn_id = "" ,
15699 txn_snapshot = output ["txn_snapshot" ],
157100 executor_id = (
158101 output ["executor_id" ] if output ["executor_id" ] else None
@@ -161,7 +104,7 @@ def record_transaction_error(self, output: TransactionResultInternal) -> None:
161104 )
162105 )
163106 except DBAPIError as dbapi_error :
164- if dbapi_error . orig . sqlstate == "23505" : # type: ignore
107+ if self . _is_unique_constraint_violation ( dbapi_error ):
165108 raise DBOSWorkflowConflictIDError (output ["workflow_uuid" ])
166109 raise
167110
@@ -283,3 +226,197 @@ def garbage_collect(
283226 )
284227
285228 c .execute (delete_query )
229+
230+ @abstractmethod
231+ def _is_unique_constraint_violation (self , dbapi_error : DBAPIError ) -> bool :
232+ """Check if the error is a unique constraint violation."""
233+ pass
234+
235+ @abstractmethod
236+ def _is_serialization_error (self , dbapi_error : DBAPIError ) -> bool :
237+ """Check if the error is a serialization/concurrency error."""
238+ pass
239+
240+ @staticmethod
241+ def create (
242+ database_url : str ,
243+ engine_kwargs : Dict [str , Any ],
244+ debug_mode : bool = False ,
245+ ) -> "ApplicationDatabase" :
246+ """Factory method to create the appropriate ApplicationDatabase implementation based on URL."""
247+ if database_url .startswith ("sqlite" ):
248+ return SQLiteApplicationDatabase (
249+ database_url = database_url ,
250+ engine_kwargs = engine_kwargs ,
251+ debug_mode = debug_mode ,
252+ )
253+ else :
254+ # Default to PostgreSQL for postgresql://, postgres://, or other URLs
255+ return PostgresApplicationDatabase (
256+ database_url = database_url ,
257+ engine_kwargs = engine_kwargs ,
258+ debug_mode = debug_mode ,
259+ )
260+
261+
262+ class PostgresApplicationDatabase (ApplicationDatabase ):
263+ """PostgreSQL-specific implementation of ApplicationDatabase."""
264+
265+ def _create_engine (
266+ self , database_url : str , engine_kwargs : Dict [str , Any ]
267+ ) -> sa .Engine :
268+ """Create a PostgreSQL engine."""
269+ app_db_url = sa .make_url (database_url ).set (drivername = "postgresql+psycopg" )
270+
271+ if engine_kwargs is None :
272+ engine_kwargs = {}
273+
274+ # TODO: Make the schema dynamic so this isn't needed
275+ ApplicationSchema .transaction_outputs .schema = "dbos"
276+
277+ return sa .create_engine (
278+ app_db_url ,
279+ ** engine_kwargs ,
280+ )
281+
282+ def run_migrations (self ) -> None :
283+ if self .debug_mode :
284+ dbos_logger .warning (
285+ "Application database migrations are skipped in debug mode."
286+ )
287+ return
288+ # Check if the database exists
289+ app_db_url = self .engine .url
290+ postgres_db_engine = sa .create_engine (
291+ app_db_url .set (database = "postgres" ),
292+ ** self ._engine_kwargs ,
293+ )
294+ with postgres_db_engine .connect () as conn :
295+ conn .execution_options (isolation_level = "AUTOCOMMIT" )
296+ if not conn .execute (
297+ sa .text ("SELECT 1 FROM pg_database WHERE datname=:db_name" ),
298+ parameters = {"db_name" : app_db_url .database },
299+ ).scalar ():
300+ conn .execute (sa .text (f"CREATE DATABASE { app_db_url .database } " ))
301+ postgres_db_engine .dispose ()
302+
303+ # Create the dbos schema and transaction_outputs table in the application database
304+ with self .engine .begin () as conn :
305+ # Check if schema exists first
306+ schema_exists = conn .execute (
307+ sa .text (
308+ "SELECT 1 FROM information_schema.schemata WHERE schema_name = :schema_name"
309+ ),
310+ parameters = {"schema_name" : ApplicationSchema .schema },
311+ ).scalar ()
312+
313+ if not schema_exists :
314+ schema_creation_query = sa .text (
315+ f"CREATE SCHEMA { ApplicationSchema .schema } "
316+ )
317+ conn .execute (schema_creation_query )
318+
319+ inspector = inspect (self .engine )
320+ if not inspector .has_table (
321+ "transaction_outputs" , schema = ApplicationSchema .schema
322+ ):
323+ ApplicationSchema .metadata_obj .create_all (self .engine )
324+ else :
325+ columns = inspector .get_columns (
326+ "transaction_outputs" , schema = ApplicationSchema .schema
327+ )
328+ column_names = [col ["name" ] for col in columns ]
329+
330+ if "function_name" not in column_names :
331+ # Column missing, alter table to add it
332+ with self .engine .connect () as conn :
333+ conn .execute (
334+ text (
335+ f"""
336+ ALTER TABLE { ApplicationSchema .schema } .transaction_outputs
337+ ADD COLUMN function_name TEXT NOT NULL DEFAULT '';
338+ """
339+ )
340+ )
341+ conn .commit ()
342+
343+ def _is_unique_constraint_violation (self , dbapi_error : DBAPIError ) -> bool :
344+ """Check if the error is a unique constraint violation in PostgreSQL."""
345+ return dbapi_error .orig .sqlstate == "23505" # type: ignore
346+
347+ def _is_serialization_error (self , dbapi_error : DBAPIError ) -> bool :
348+ """Check if the error is a serialization/concurrency error in PostgreSQL."""
349+ # 40001: serialization_failure (MVCC conflict)
350+ # 40P01: deadlock_detected
351+ driver_error = dbapi_error .orig
352+ return (
353+ driver_error is not None
354+ and isinstance (driver_error , psycopg .OperationalError )
355+ and driver_error .sqlstate in ("40001" , "40P01" )
356+ )
357+
358+
359+ class SQLiteApplicationDatabase (ApplicationDatabase ):
360+ """SQLite-specific implementation of ApplicationDatabase."""
361+
362+ def _create_engine (
363+ self , database_url : str , engine_kwargs : Dict [str , Any ]
364+ ) -> sa .Engine :
365+ """Create a SQLite engine."""
366+ # TODO: Make the schema dynamic so this isn't needed
367+ ApplicationSchema .transaction_outputs .schema = None
368+ return sa .create_engine (database_url )
369+
370+ def run_migrations (self ) -> None :
371+ if self .debug_mode :
372+ dbos_logger .warning (
373+ "Application database migrations are skipped in debug mode."
374+ )
375+ return
376+
377+ with self .engine .begin () as conn :
378+ # Check if table exists
379+ result = conn .execute (
380+ sa .text (
381+ "SELECT name FROM sqlite_master WHERE type='table' AND name='transaction_outputs'"
382+ )
383+ ).fetchone ()
384+
385+ if result is None :
386+ # Create the table with proper SQLite syntax
387+ conn .execute (
388+ sa .text (
389+ f"""
390+ CREATE TABLE transaction_outputs (
391+ workflow_uuid TEXT NOT NULL,
392+ function_id INTEGER NOT NULL,
393+ output TEXT,
394+ error TEXT,
395+ txn_id TEXT,
396+ txn_snapshot TEXT NOT NULL,
397+ executor_id TEXT,
398+ function_name TEXT NOT NULL DEFAULT '',
399+ created_at BIGINT NOT NULL DEFAULT { get_sqlite_timestamp_expr ()} ,
400+ PRIMARY KEY (workflow_uuid, function_id)
401+ )
402+ """
403+ )
404+ )
405+ # Create the index
406+ conn .execute (
407+ sa .text (
408+ "CREATE INDEX transaction_outputs_created_at_index ON transaction_outputs (created_at)"
409+ )
410+ )
411+
412+ def _is_unique_constraint_violation (self , dbapi_error : DBAPIError ) -> bool :
413+ """Check if the error is a unique constraint violation in SQLite."""
414+ return "UNIQUE constraint failed" in str (dbapi_error .orig )
415+
416+ def _is_serialization_error (self , dbapi_error : DBAPIError ) -> bool :
417+ """Check if the error is a serialization/concurrency error in SQLite."""
418+ # SQLite database is locked or busy errors
419+ error_msg = str (dbapi_error .orig ).lower ()
420+ return (
421+ "database is locked" in error_msg or "database table is locked" in error_msg
422+ )
0 commit comments