Skip to content

Commit 023c748

Browse files
authored
Add a test for SQLAlchemy ORM (#199)
1 parent 75c5c57 commit 023c748

File tree

1 file changed

+113
-0
lines changed

1 file changed

+113
-0
lines changed

tests/test_sqlalchemy.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import uuid
2+
3+
import pytest
4+
import sqlalchemy as sa
5+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
6+
7+
# Public API
8+
from dbos import DBOS, SetWorkflowID
9+
10+
11+
# Declare a SQLAlchemy ORM base class
12+
class Base(DeclarativeBase):
13+
pass
14+
15+
16+
# Declare a SQLAlchemy ORM class for accessing the database table.
17+
class Hello(Base):
18+
__tablename__ = "dbos_hello"
19+
greet_count: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
20+
name: Mapped[str] = mapped_column(nullable=False)
21+
22+
def __repr__(self) -> str:
23+
return f"Hello(greet_count={self.greet_count!r}, name={self.name!r})"
24+
25+
26+
def test_simple_transaction(dbos: DBOS, postgres_db_engine: sa.Engine) -> None:
27+
txn_counter: int = 0
28+
assert dbos._app_db_field is not None
29+
Base.metadata.drop_all(dbos._app_db_field.engine)
30+
Base.metadata.create_all(dbos._app_db_field.engine)
31+
32+
@DBOS.transaction()
33+
def test_transaction(name: str) -> str:
34+
new_greeting = Hello(name=name)
35+
DBOS.sql_session.add(new_greeting)
36+
stmt = (
37+
sa.select(Hello)
38+
.where(Hello.name == name)
39+
.order_by(Hello.greet_count.desc())
40+
.limit(1)
41+
)
42+
row = DBOS.sql_session.scalar(stmt)
43+
assert row is not None
44+
greet_count = row.greet_count
45+
nonlocal txn_counter
46+
txn_counter += 1
47+
return name + str(greet_count)
48+
49+
assert test_transaction("alice") == "alice1"
50+
assert test_transaction("alice") == "alice2"
51+
assert txn_counter == 2
52+
53+
# Test OAOO
54+
wfuuid = str(uuid.uuid4())
55+
with SetWorkflowID(wfuuid):
56+
assert test_transaction("alice") == "alice3"
57+
with SetWorkflowID(wfuuid):
58+
assert test_transaction("alice") == "alice3"
59+
assert txn_counter == 3 # Only increment once
60+
61+
Base.metadata.drop_all(dbos._app_db_field.engine)
62+
63+
# Make sure no transactions are left open
64+
with postgres_db_engine.begin() as conn:
65+
result = conn.execute(
66+
sa.text(
67+
"select * from pg_stat_activity where state = 'idle in transaction'"
68+
)
69+
).fetchall()
70+
assert len(result) == 0
71+
72+
73+
def test_error_transaction(dbos: DBOS, postgres_db_engine: sa.Engine) -> None:
74+
txn_counter: int = 0
75+
assert dbos._app_db_field is not None
76+
# Drop the database but don't re-create. Should fail.
77+
Base.metadata.drop_all(dbos._app_db_field.engine)
78+
79+
@DBOS.transaction()
80+
def test_transaction(name: str) -> str:
81+
nonlocal txn_counter
82+
txn_counter += 1
83+
new_greeting = Hello(name=name)
84+
DBOS.sql_session.add(new_greeting)
85+
return name
86+
87+
with pytest.raises(Exception) as exc_info:
88+
test_transaction("alice")
89+
assert 'relation "dbos_hello" does not exist' in str(exc_info.value)
90+
assert txn_counter == 1
91+
92+
# Test OAOO
93+
wfuuid = str(uuid.uuid4())
94+
with SetWorkflowID(wfuuid):
95+
with pytest.raises(Exception) as exc_info:
96+
test_transaction("alice")
97+
assert 'relation "dbos_hello" does not exist' in str(exc_info.value)
98+
assert txn_counter == 2
99+
100+
with SetWorkflowID(wfuuid):
101+
with pytest.raises(Exception) as exc_info:
102+
test_transaction("alice")
103+
assert 'relation "dbos_hello" does not exist' in str(exc_info.value)
104+
assert txn_counter == 2
105+
106+
# Make sure no transactions are left open
107+
with postgres_db_engine.begin() as conn:
108+
result = conn.execute(
109+
sa.text(
110+
"select * from pg_stat_activity where state = 'idle in transaction'"
111+
)
112+
).fetchall()
113+
assert len(result) == 0

0 commit comments

Comments
 (0)