Skip to content

Commit ac04592

Browse files
committed
tests
1 parent 759b3f7 commit ac04592

File tree

1 file changed

+58
-21
lines changed

1 file changed

+58
-21
lines changed

packages/postgres-database/tests/test_base_repo.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
get_or_create_connection,
1313
transaction_context,
1414
)
15+
from simcore_postgres_database.models.tags import tags
1516
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
1617

1718

@@ -24,9 +25,13 @@ class OneResourceRepoDemo:
2425
# This is a PROTOTYPE of how one could implement a generic
2526
# repo that provides CRUD operations providing a given table
2627
def __init__(self, engine: AsyncEngine, table: sa.Table):
27-
self.engine = engine
28+
if "id" not in table.columns:
29+
msg = "id column expected"
30+
raise ValueError(msg)
2831
self.table = table
2932

33+
self.engine = engine
34+
3035
async def create(self, connection: AsyncConnection | None = None, **kwargs) -> int:
3136
async with transaction_context(self.engine, connection) as conn:
3237
result = await conn.execute(self.table.insert().values(**kwargs))
@@ -35,7 +40,10 @@ async def create(self, connection: AsyncConnection | None = None, **kwargs) -> i
3540
return result.inserted_primary_key[0]
3641

3742
async def get_by_id(
38-
self, record_id: int, connection: AsyncConnection | None = None
43+
self,
44+
connection: AsyncConnection | None = None,
45+
*,
46+
record_id: int,
3947
) -> dict[str, Any] | None:
4048
async with get_or_create_connection(self.engine, connection) as conn:
4149
result = await conn.execute(
@@ -45,7 +53,11 @@ async def get_by_id(
4553
return dict(record) if record else None
4654

4755
async def get_page(
48-
self, limit: int, offset: int, connection: AsyncConnection | None = None
56+
self,
57+
connection: AsyncConnection | None = None,
58+
*,
59+
limit: int,
60+
offset: int,
4961
) -> _PageDict:
5062
async with get_or_create_connection(self.engine, connection) as conn:
5163
# Compute total count
@@ -61,7 +73,11 @@ async def get_page(
6173
return _PageDict(total_count=total_count or 0, rows=records)
6274

6375
async def update(
64-
self, record_id: int, connection: AsyncConnection | None = None, **values
76+
self,
77+
connection: AsyncConnection | None = None,
78+
*,
79+
record_id: int,
80+
**values,
6581
) -> bool:
6682
async with transaction_context(self.engine, connection) as conn:
6783
result = await conn.execute(
@@ -71,7 +87,10 @@ async def update(
7187
return result.rowcount > 0
7288

7389
async def delete(
74-
self, record_id: int, connection: AsyncConnection | None = None
90+
self,
91+
connection: AsyncConnection | None = None,
92+
*,
93+
record_id: int,
7594
) -> bool:
7695
async with transaction_context(self.engine, connection) as conn:
7796
result = await conn.execute(
@@ -81,30 +100,48 @@ async def delete(
81100
return result.rowcount > 0
82101

83102

84-
@pytest.mark.skip()
85-
async def test_sqlachemy_asyncio_example(asyncpg_engine: AsyncEngine):
103+
async def test_transaction_context(asyncpg_engine: AsyncEngine):
86104
#
87-
# Same example as in https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
88-
# but using `t1_repo`
89-
#
90-
meta = sa.MetaData()
91-
t1 = sa.Table("t1", meta, sa.Column("name", sa.String(50), primary_key=True))
105+
# Similar to example in https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
106+
# using tags
92107

93-
t1_repo = OneResourceRepoDemo(engine=asyncpg_engine, table=t1)
108+
tags_repo = OneResourceRepoDemo(engine=asyncpg_engine, table=tags)
94109

95-
async with transaction_context(asyncpg_engine) as conn:
110+
# (1) Using transaction_context and fails
111+
fake_error_msg = "some error"
96112

97-
await conn.run_sync(meta.drop_all)
98-
await conn.run_sync(meta.create_all)
113+
def _something_raises_here():
114+
raise RuntimeError(fake_error_msg)
99115

100-
await t1_repo.create(conn, name="some name 1")
101-
await t1_repo.create(conn, name="some name 2")
116+
async def _create_blue_like_tags(conn):
117+
async with conn.begin(): # NOTE: embedded transaction here!
118+
await tags_repo.create(conn, name="cyan tag", color="cyan")
119+
_something_raises_here()
120+
await tags_repo.create(conn, name="violet tag", color="violet")
102121

103-
async with transaction_context(asyncpg_engine) as conn:
104-
page = await t1_repo.get_page(limit=50, offset=0, connection=conn)
122+
async def _create_four_tags(conn):
123+
await tags_repo.create(conn, name="red tag", color="red")
124+
await _create_blue_like_tags(conn)
125+
await tags_repo.create(conn, name="green tag", color="green")
126+
127+
with pytest.raises(RuntimeError, match=fake_error_msg):
128+
async with transaction_context(asyncpg_engine) as conn:
129+
await _create_four_tags(conn)
130+
131+
page = await tags_repo.get_page(limit=50, offset=0)
132+
assert page["total_count"] == 0, "Transaction did not happen"
105133

134+
# (2) using internal connections
135+
await tags_repo.create(name="blue tag", color="blue")
136+
await tags_repo.create(conn, name="red tag", color="red")
137+
page = await tags_repo.get_page(limit=50, offset=0)
138+
assert page["total_count"] == 2
139+
140+
# (3) using external embedded
141+
async with transaction_context(asyncpg_engine) as conn:
142+
page = await tags_repo.get_page(conn, limit=50, offset=0)
106143
assert page["total_count"] == 2
107144

108145
# select a Result, which will be delivered with buffered results
109-
result = await conn.execute(sa.select(t1).where(t1.c.name == "some name 1"))
146+
result = await conn.execute(sa.select(tags).where(tags.c.name == "some name 1"))
110147
assert result.fetchall()

0 commit comments

Comments
 (0)