Skip to content

Commit 3ede142

Browse files
committed
tests on new utils
1 parent 0fc59fa commit 3ede142

File tree

1 file changed

+213
-0
lines changed

1 file changed

+213
-0
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# pylint: disable=redefined-outer-name
2+
# pylint: disable=unused-argument
3+
# pylint: disable=unused-variable
4+
# pylint: disable=too-many-arguments
5+
6+
7+
from typing import Any, NamedTuple
8+
9+
import pytest
10+
import sqlalchemy as sa
11+
from simcore_postgres_database.models.tags import tags
12+
from simcore_postgres_database.utils_repos import (
13+
get_or_create_connection,
14+
transaction_context,
15+
)
16+
from sqlalchemy.exc import IntegrityError
17+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
18+
19+
20+
async def test_sa_transactions(asyncpg_engine: AsyncEngine):
21+
#
22+
# SEE https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
23+
#
24+
25+
# READ query
26+
total_count_query = sa.select(sa.func.count()).select_from(tags)
27+
28+
# WRITE queries
29+
query1 = (
30+
tags.insert().values(id=2, name="query1", color="blue").returning(tags.c.id)
31+
)
32+
query11 = (
33+
tags.insert().values(id=3, name="query11", color="blue").returning(tags.c.id)
34+
)
35+
query12 = (
36+
tags.insert().values(id=5, name="query12", color="blue").returning(tags.c.id)
37+
)
38+
query2 = (
39+
tags.insert().values(id=6, name="query2", color="blue").returning(tags.c.id)
40+
)
41+
query2 = (
42+
tags.insert().values(id=7, name="query2", color="blue").returning(tags.c.id)
43+
)
44+
45+
async with asyncpg_engine.connect() as conn, conn.begin(): # starts transaction (savepoint)
46+
47+
result = await conn.execute(query1)
48+
assert result.scalar() == 2
49+
50+
total_count = (await conn.execute(total_count_query)).scalar()
51+
assert total_count == 1
52+
53+
rows = (await conn.execute(tags.select().where(tags.c.id == 2))).fetchall()
54+
assert rows
55+
assert rows[0].id == 2
56+
57+
async with conn.begin_nested(): # savepoint
58+
await conn.execute(query11)
59+
60+
with pytest.raises(IntegrityError):
61+
async with conn.begin_nested(): # savepoint
62+
await conn.execute(query11)
63+
64+
await conn.execute(query12)
65+
66+
total_count = (await conn.execute(total_count_query)).scalar()
67+
assert total_count == 3 # since query11 (second time) reverted!
68+
69+
await conn.execute(query2)
70+
71+
total_count = (await conn.execute(total_count_query)).scalar()
72+
assert total_count == 4
73+
74+
75+
class _PageTuple(NamedTuple):
76+
total_count: int
77+
rows: list[dict[str, Any]]
78+
79+
80+
class OneResourceRepoDemo:
81+
# This is a PROTOTYPE of how one could implement a generic
82+
# repo that provides CRUD operations providing a given table
83+
def __init__(self, engine: AsyncEngine, table: sa.Table):
84+
if "id" not in table.columns:
85+
msg = "id column expected"
86+
raise ValueError(msg)
87+
self.table = table
88+
89+
self.engine = engine
90+
91+
async def create(self, connection: AsyncConnection | None = None, **kwargs) -> int:
92+
async with transaction_context(self.engine, connection) as conn:
93+
result = await conn.execute(self.table.insert().values(**kwargs))
94+
assert result # nosec
95+
return result.inserted_primary_key[0]
96+
97+
async def get_by_id(
98+
self,
99+
connection: AsyncConnection | None = None,
100+
*,
101+
row_id: int,
102+
) -> dict[str, Any] | None:
103+
async with get_or_create_connection(self.engine, connection) as conn:
104+
result = await conn.execute(
105+
sa.select(self.table).where(self.table.c.id == row_id)
106+
)
107+
row = result.mappings().fetchone()
108+
return dict(row) if row else None
109+
110+
async def get_page(
111+
self,
112+
connection: AsyncConnection | None = None,
113+
*,
114+
limit: int,
115+
offset: int = 0,
116+
) -> _PageTuple:
117+
async with get_or_create_connection(self.engine, connection) as conn:
118+
# Compute total count
119+
total_count_query = sa.select(sa.func.count()).select_from(self.table)
120+
total_count_result = await conn.execute(total_count_query)
121+
total_count = total_count_result.scalar()
122+
123+
# Fetch paginated results
124+
query = sa.select(self.table).limit(limit).offset(offset)
125+
result = await conn.execute(query)
126+
rows = [dict(row) for row in result.mappings().fetchall()]
127+
128+
return _PageTuple(total_count=total_count or 0, rows=rows)
129+
130+
async def update(
131+
self,
132+
connection: AsyncConnection | None = None,
133+
*,
134+
row_id: int,
135+
**values,
136+
) -> bool:
137+
async with transaction_context(self.engine, connection) as conn:
138+
result = await conn.execute(
139+
self.table.update().where(self.table.c.id == row_id).values(**values)
140+
)
141+
return result.rowcount > 0
142+
143+
async def delete(
144+
self,
145+
connection: AsyncConnection | None = None,
146+
*,
147+
row_id: int,
148+
) -> bool:
149+
async with transaction_context(self.engine, connection) as conn:
150+
result = await conn.execute(
151+
self.table.delete().where(self.table.c.id == row_id)
152+
)
153+
return result.rowcount > 0
154+
155+
156+
async def test_oneresourcerepodemo_prototype(asyncpg_engine: AsyncEngine):
157+
158+
tags_repo = OneResourceRepoDemo(engine=asyncpg_engine, table=tags)
159+
160+
# create
161+
tag_id = await tags_repo.create(name="cyan tag", color="cyan")
162+
assert tag_id > 0
163+
164+
# get, list
165+
tag = await tags_repo.get_by_id(row_id=tag_id)
166+
assert tag
167+
168+
page = await tags_repo.get_page(limit=10)
169+
assert page.total_count == 1
170+
assert page.rows == [tag]
171+
172+
# update
173+
ok = await tags_repo.update(row_id=tag_id, name="changed name")
174+
assert ok
175+
176+
updated_tag = await tags_repo.get_by_id(row_id=tag_id)
177+
assert updated_tag
178+
assert updated_tag["name"] != tag["name"]
179+
180+
# delete
181+
ok = await tags_repo.delete(row_id=tag_id)
182+
assert ok
183+
184+
assert not await tags_repo.get_by_id(row_id=tag_id)
185+
186+
187+
async def test_transaction_context(asyncpg_engine: AsyncEngine):
188+
# (1) Using transaction_context and fails
189+
fake_error_msg = "some error"
190+
191+
def _something_raises_here():
192+
raise RuntimeError(fake_error_msg)
193+
194+
tags_repo = OneResourceRepoDemo(engine=asyncpg_engine, table=tags)
195+
196+
# using external transaction_context: commits upon __aexit__
197+
async with transaction_context(asyncpg_engine) as conn:
198+
await tags_repo.create(conn, name="cyan tag", color="cyan")
199+
await tags_repo.create(conn, name="red tag", color="red")
200+
assert (await tags_repo.get_page(conn, limit=10, offset=0)).total_count == 2
201+
202+
# using internal: auto-commit
203+
await tags_repo.create(name="red tag", color="red")
204+
assert (await tags_repo.get_page(limit=10, offset=0)).total_count == 3
205+
206+
# auto-rollback
207+
with pytest.raises(RuntimeError, match=fake_error_msg): # noqa: PT012
208+
async with transaction_context(asyncpg_engine) as conn:
209+
await tags_repo.create(conn, name="violet tag", color="violet")
210+
assert (await tags_repo.get_page(conn, limit=10, offset=0)).total_count == 4
211+
_something_raises_here()
212+
213+
assert (await tags_repo.get_page(limit=10, offset=0)).total_count == 3

0 commit comments

Comments
 (0)