Skip to content

Commit be47d2a

Browse files
committed
fixes test_base_repo
1 parent e91ee16 commit be47d2a

File tree

2 files changed

+92
-81
lines changed

2 files changed

+92
-81
lines changed

packages/postgres-database/src/simcore_postgres_database/base_repo.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ async def transaction_context(
2727
):
2828
async with get_or_create_connection(engine, connection) as conn:
2929
if conn.in_transaction():
30-
# FIXME: should not extend nested? async with conn.begin_nested():
31-
# depends on the analysis of test_sa_transactions
32-
# might need another function that produces a transaction ONLY
33-
yield conn
30+
async with conn.begin_nested(): # savepoint
31+
yield conn
3432
else:
3533
try:
3634
async with conn.begin():

packages/postgres-database/tests/test_base_repo.py

Lines changed: 90 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# pylint: disable=too-many-arguments
55

66

7-
from typing import Any, TypedDict
7+
from typing import Any, NamedTuple
88

99
import pytest
1010
import sqlalchemy as sa
@@ -13,6 +13,7 @@
1313
transaction_context,
1414
)
1515
from simcore_postgres_database.models.tags import tags
16+
from sqlalchemy.exc import IntegrityError
1617
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
1718

1819

@@ -21,55 +22,57 @@ async def test_sa_transactions(asyncpg_engine: AsyncEngine):
2122
# SEE https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
2223
#
2324

25+
# READ query
2426
total_count_query = sa.select(sa.func.count()).select_from(tags)
2527

2628
# WRITE queries
27-
query1 = tags.insert().values(name="query1", color="blue")
28-
query11 = tags.insert().values(name="query11", color="blue")
29-
query111 = tags.insert().values(name="query111", color="blue")
30-
query1111 = tags.insert().values(name="query1111", color="blue")
31-
query112 = tags.insert().values(name="query112", color="blue")
32-
query12 = tags.insert().values(name="query12", color="blue")
33-
query2 = tags.insert().values(name="query2", color="blue")
29+
query1 = (
30+
tags.insert().values(id=2, name="query1", color="blue").returning(tags.c.id)
31+
)
32+
query11 = (
33+
tags.insert().values(id=3, name="query11", color="blue").returning(tags.c.id)
34+
)
35+
query12 = (
36+
tags.insert().values(id=5, name="query12", color="blue").returning(tags.c.id)
37+
)
38+
query2 = (
39+
tags.insert().values(id=6, name="query2", color="blue").returning(tags.c.id)
40+
)
41+
query2 = (
42+
tags.insert().values(id=7, name="query2", color="blue").returning(tags.c.id)
43+
)
44+
45+
async with asyncpg_engine.connect() as conn, conn.begin(): # starts transaction (savepoint)
46+
47+
result = await conn.execute(query1)
48+
assert result.scalar() == 2
3449

35-
# to make it fail, just repeat query since `id` is unique
36-
# TODO: query1111 = tags.insert().values(id=1, name="query1111", color="blue")
37-
# TODO: await conn.commit() # explicit commit !
38-
39-
# TODO: if this is true, then the order of execution is NOT preserved?
40-
async with asyncpg_engine.connect() as conn:
50+
total_count = (await conn.execute(total_count_query)).scalar()
51+
assert total_count == 1
4152

42-
await conn.execute(query1)
53+
rows = (await conn.execute(tags.select().where(tags.c.id == 2))).fetchall()
54+
assert rows
55+
assert rows[0].id == 2
4356

44-
async with conn.begin(): # savepoint
57+
async with conn.begin_nested(): # savepoint
4558
await conn.execute(query11)
4659

47-
async with conn.begin_nested(): # savepoint
48-
await conn.execute(query111)
49-
60+
with pytest.raises(IntegrityError):
5061
async with conn.begin_nested(): # savepoint
51-
await conn.execute(query1111)
52-
53-
await conn.execute(query112)
54-
55-
total_count = (await conn.execute(total_count_query)).scalar()
56-
assert total_count == 1 # (query1111)
62+
await conn.execute(query11)
5763

5864
await conn.execute(query12)
5965

6066
total_count = (await conn.execute(total_count_query)).scalar()
61-
assert total_count == 3 # query111, (query1111), query112
67+
assert total_count == 3 # since query11 (second time) reverted!
6268

6369
await conn.execute(query2)
6470

6571
total_count = (await conn.execute(total_count_query)).scalar()
66-
assert total_count == 5 # query11, (query111, (query1111), query112), query2
67-
68-
total_count = (await conn.execute(total_count_query)).scalar()
69-
assert total_count == 7 # includes query1, query2
72+
assert total_count == 4
7073

7174

72-
class _PageDict(TypedDict):
75+
class _PageTuple(NamedTuple):
7376
total_count: int
7477
rows: list[dict[str, Any]]
7578

@@ -95,22 +98,22 @@ async def get_by_id(
9598
self,
9699
connection: AsyncConnection | None = None,
97100
*,
98-
record_id: int,
101+
row_id: int,
99102
) -> dict[str, Any] | None:
100103
async with get_or_create_connection(self.engine, connection) as conn:
101104
result = await conn.execute(
102-
sa.select(self.table).where(self.table.c.id == record_id)
105+
sa.select(self.table).where(self.table.c.id == row_id)
103106
)
104-
record = result.fetchone()
105-
return dict(record) if record else None
107+
row = result.mappings().fetchone()
108+
return dict(row) if row else None
106109

107110
async def get_page(
108111
self,
109112
connection: AsyncConnection | None = None,
110113
*,
111114
limit: int,
112-
offset: int,
113-
) -> _PageDict:
115+
offset: int = 0,
116+
) -> _PageTuple:
114117
async with get_or_create_connection(self.engine, connection) as conn:
115118
# Compute total count
116119
total_count_query = sa.select(sa.func.count()).select_from(self.table)
@@ -120,81 +123,91 @@ async def get_page(
120123
# Fetch paginated results
121124
query = sa.select(self.table).limit(limit).offset(offset)
122125
result = await conn.execute(query)
123-
records = [dict(**row) for row in result.fetchall()]
126+
rows = [dict(row) for row in result.mappings().fetchall()]
124127

125-
return _PageDict(total_count=total_count or 0, rows=records)
128+
return _PageTuple(total_count=total_count or 0, rows=rows)
126129

127130
async def update(
128131
self,
129132
connection: AsyncConnection | None = None,
130133
*,
131-
record_id: int,
134+
row_id: int,
132135
**values,
133136
) -> bool:
134137
async with transaction_context(self.engine, connection) as conn:
135138
result = await conn.execute(
136-
self.table.update().where(self.table.c.id == record_id).values(**values)
139+
self.table.update().where(self.table.c.id == row_id).values(**values)
137140
)
138141
return result.rowcount > 0
139142

140143
async def delete(
141144
self,
142145
connection: AsyncConnection | None = None,
143146
*,
144-
record_id: int,
147+
row_id: int,
145148
) -> bool:
146149
async with transaction_context(self.engine, connection) as conn:
147150
result = await conn.execute(
148-
self.table.delete().where(self.table.c.id == record_id)
151+
self.table.delete().where(self.table.c.id == row_id)
149152
)
150153
return result.rowcount > 0
151154

152155

153-
async def test_transaction_context(asyncpg_engine: AsyncEngine):
156+
async def test_oneresourcerepodemo_prototype(asyncpg_engine: AsyncEngine):
154157

155158
tags_repo = OneResourceRepoDemo(engine=asyncpg_engine, table=tags)
156159

160+
# create
161+
tag_id = await tags_repo.create(name="cyan tag", color="cyan")
162+
assert tag_id > 0
163+
164+
# get, list
165+
tag = await tags_repo.get_by_id(row_id=tag_id)
166+
assert tag
167+
168+
page = await tags_repo.get_page(limit=10)
169+
assert page.total_count == 1
170+
assert page.rows == [tag]
171+
172+
# update
173+
ok = await tags_repo.update(row_id=tag_id, name="changed name")
174+
assert ok
175+
176+
updated_tag = await tags_repo.get_by_id(row_id=tag_id)
177+
assert updated_tag
178+
assert updated_tag["name"] != tag["name"]
179+
180+
# delete
181+
ok = await tags_repo.delete(row_id=tag_id)
182+
assert ok
183+
184+
assert not await tags_repo.get_by_id(row_id=tag_id)
185+
186+
187+
async def test_transaction_context(asyncpg_engine: AsyncEngine):
157188
# (1) Using transaction_context and fails
158189
fake_error_msg = "some error"
159190

160191
def _something_raises_here():
161192
raise RuntimeError(fake_error_msg)
162193

163-
async def _create_blue_like_tags(connection):
164-
# NOTE: embedded transaction here!!!
165-
async with transaction_context(asyncpg_engine, connection) as conn:
166-
await tags_repo.create(conn, name="cyan tag", color="cyan")
167-
_something_raises_here()
168-
await tags_repo.create(conn, name="violet tag", color="violet")
169-
170-
async def _create_four_tags(connection):
171-
await tags_repo.create(connection, name="red tag", color="red")
172-
await _create_blue_like_tags(connection)
173-
await tags_repo.create(connection, name="green tag", color="green")
174-
175-
with pytest.raises(RuntimeError, match=fake_error_msg):
176-
async with transaction_context(asyncpg_engine) as conn:
177-
await tags_repo.create(conn, name="red tag", color="red")
178-
_something_raises_here()
179-
await tags_repo.create(conn, name="green tag", color="green")
180-
181-
print(asyncpg_engine.pool.status())
182-
assert conn.closed
194+
tags_repo = OneResourceRepoDemo(engine=asyncpg_engine, table=tags)
183195

184-
page = await tags_repo.get_page(limit=50, offset=0)
185-
assert page["total_count"] == 0, "Transaction did not happen"
196+
# using external transaction_context: commits upon __aexit__
197+
async with transaction_context(asyncpg_engine) as conn:
198+
await tags_repo.create(conn, name="cyan tag", color="cyan")
199+
await tags_repo.create(conn, name="red tag", color="red")
200+
assert (await tags_repo.get_page(conn, limit=10, offset=0)).total_count == 2
186201

187-
# (2) using internal connections
188-
await tags_repo.create(name="blue tag", color="blue")
202+
# using internal: auto-commit
189203
await tags_repo.create(name="red tag", color="red")
190-
page = await tags_repo.get_page(limit=50, offset=0)
191-
assert page["total_count"] == 2
204+
assert (await tags_repo.get_page(limit=10, offset=0)).total_count == 3
192205

193-
# (3) using external embedded
194-
async with transaction_context(asyncpg_engine) as conn:
195-
page = await tags_repo.get_page(conn, limit=50, offset=0)
196-
assert page["total_count"] == 2
206+
# auto-rollback
207+
with pytest.raises(RuntimeError, match=fake_error_msg): # noqa: PT012
208+
async with transaction_context(asyncpg_engine) as conn:
209+
await tags_repo.create(conn, name="violet tag", color="violet")
210+
assert (await tags_repo.get_page(conn, limit=10, offset=0)).total_count == 4
211+
_something_raises_here()
197212

198-
# select a Result, which will be delivered with buffered results
199-
result = await conn.execute(sa.select(tags).where(tags.c.name == "blue tag"))
200-
assert result.fetchall()
213+
assert (await tags_repo.get_page(limit=10, offset=0)).total_count == 3

0 commit comments

Comments
 (0)