Skip to content

Commit 46d6dbc

Browse files
committed
adapting base tests
1 parent ac04592 commit 46d6dbc

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed
Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
1+
import logging
12
from contextlib import asynccontextmanager
23

34
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
45

6+
_logger = logging.getLogger(__name__)
7+
58

69
@asynccontextmanager
710
async def get_or_create_connection(
811
engine: AsyncEngine, connection: AsyncConnection | None = None
912
):
10-
close_conn = False
11-
if connection is None:
13+
# creator is responsible of closing connection
14+
is_connection_created = connection is None
15+
if is_connection_created:
1216
connection = await engine.connect()
13-
close_conn = True
1417
try:
1518
yield connection
1619
finally:
17-
if close_conn:
20+
if is_connection_created:
1821
await connection.close()
1922

2023

@@ -24,8 +27,12 @@ async def transaction_context(
2427
):
2528
async with get_or_create_connection(engine, connection) as conn:
2629
if conn.in_transaction():
27-
async with conn.begin_nested():
28-
yield conn
30+
# async with conn.begin_nested():
31+
yield conn
2932
else:
30-
async with conn.begin():
31-
yield conn
33+
try:
34+
async with conn.begin():
35+
yield conn
36+
finally:
37+
assert not conn.closed # nosec
38+
assert not conn.in_transaction() # nosec

packages/postgres-database/tests/test_base_repo.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def __init__(self, engine: AsyncEngine, table: sa.Table):
3535
async def create(self, connection: AsyncConnection | None = None, **kwargs) -> int:
3636
async with transaction_context(self.engine, connection) as conn:
3737
result = await conn.execute(self.table.insert().values(**kwargs))
38-
await conn.commit()
3938
assert result # nosec
4039
return result.inserted_primary_key[0]
4140

@@ -68,7 +67,7 @@ async def get_page(
6867
# Fetch paginated results
6968
query = sa.select(self.table).limit(limit).offset(offset)
7069
result = await conn.execute(query)
71-
records = [dict(row) for row in result.fetchall()]
70+
records = [dict(**row) for row in result.fetchall()]
7271

7372
return _PageDict(total_count=total_count or 0, rows=records)
7473

@@ -83,7 +82,6 @@ async def update(
8382
result = await conn.execute(
8483
self.table.update().where(self.table.c.id == record_id).values(**values)
8584
)
86-
await conn.commit()
8785
return result.rowcount > 0
8886

8987
async def delete(
@@ -96,10 +94,16 @@ async def delete(
9694
result = await conn.execute(
9795
self.table.delete().where(self.table.c.id == record_id)
9896
)
99-
await conn.commit()
10097
return result.rowcount > 0
10198

10299

100+
# async def test_it(asyncpg_engine: AsyncEngine):
101+
102+
# async with asyncpg_engine.connect() as conn:
103+
# async with conn.begin():
104+
# conn.execute()
105+
106+
103107
async def test_transaction_context(asyncpg_engine: AsyncEngine):
104108
#
105109
# Similar to example in https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
@@ -113,27 +117,33 @@ async def test_transaction_context(asyncpg_engine: AsyncEngine):
113117
def _something_raises_here():
114118
raise RuntimeError(fake_error_msg)
115119

116-
async def _create_blue_like_tags(conn):
117-
async with conn.begin(): # NOTE: embedded transaction here!
120+
async def _create_blue_like_tags(connection):
121+
# NOTE: embedded transaction here!!!
122+
async with transaction_context(asyncpg_engine, connection) as conn:
118123
await tags_repo.create(conn, name="cyan tag", color="cyan")
119124
_something_raises_here()
120125
await tags_repo.create(conn, name="violet tag", color="violet")
121126

122-
async def _create_four_tags(conn):
123-
await tags_repo.create(conn, name="red tag", color="red")
124-
await _create_blue_like_tags(conn)
125-
await tags_repo.create(conn, name="green tag", color="green")
127+
async def _create_four_tags(connection):
128+
await tags_repo.create(connection, name="red tag", color="red")
129+
await _create_blue_like_tags(connection)
130+
await tags_repo.create(connection, name="green tag", color="green")
126131

127132
with pytest.raises(RuntimeError, match=fake_error_msg):
128133
async with transaction_context(asyncpg_engine) as conn:
129-
await _create_four_tags(conn)
134+
await tags_repo.create(conn, name="red tag", color="red")
135+
_something_raises_here()
136+
await tags_repo.create(conn, name="green tag", color="green")
137+
138+
print(asyncpg_engine.pool.status())
139+
assert conn.closed
130140

131141
page = await tags_repo.get_page(limit=50, offset=0)
132142
assert page["total_count"] == 0, "Transaction did not happen"
133143

134144
# (2) using internal connections
135145
await tags_repo.create(name="blue tag", color="blue")
136-
await tags_repo.create(conn, name="red tag", color="red")
146+
await tags_repo.create(name="red tag", color="red")
137147
page = await tags_repo.get_page(limit=50, offset=0)
138148
assert page["total_count"] == 2
139149

@@ -143,5 +153,5 @@ async def _create_four_tags(conn):
143153
assert page["total_count"] == 2
144154

145155
# select a Result, which will be delivered with buffered results
146-
result = await conn.execute(sa.select(tags).where(tags.c.name == "some name 1"))
156+
result = await conn.execute(sa.select(tags).where(tags.c.name == "blue tag"))
147157
assert result.fetchall()

0 commit comments

Comments
 (0)