Skip to content

Commit 4926381

Browse files
committed
mv base to tests
1 parent 90c4be4 commit 4926381

File tree

2 files changed

+81
-87
lines changed

2 files changed

+81
-87
lines changed
Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from contextlib import asynccontextmanager
2-
from typing import Any, TypedDict
32

4-
import sqlalchemy as sa
53
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
64

75

@@ -31,67 +29,3 @@ async def transaction_context(
3129
else:
3230
async with conn.begin():
3331
yield conn
34-
35-
36-
class _PageDict(TypedDict):
37-
total_count: int
38-
rows: list[dict[str, Any]]
39-
40-
41-
class MinimalRepo:
42-
def __init__(self, engine: AsyncEngine, table: sa.Table):
43-
self.engine = engine
44-
self.table = table
45-
46-
async def create(self, connection: AsyncConnection | None = None, **kwargs) -> int:
47-
async with get_or_create_connection(self.engine, connection) as conn:
48-
result = await conn.execute(self.table.insert().values(**kwargs))
49-
await conn.commit()
50-
assert result # nosec
51-
return result.inserted_primary_key[0]
52-
53-
async def get_by_id(
54-
self, record_id: int, connection: AsyncConnection | None = None
55-
) -> dict[str, Any] | None:
56-
async with get_or_create_connection(self.engine, connection) as conn:
57-
result = await conn.execute(
58-
sa.select(self.table).where(self.table.c.id == record_id)
59-
)
60-
record = result.fetchone()
61-
return dict(record) if record else None
62-
63-
async def get_page(
64-
self, limit: int, offset: int, connection: AsyncConnection | None = None
65-
) -> _PageDict:
66-
async with get_or_create_connection(self.engine, connection) as conn:
67-
# Compute total count
68-
total_count_query = sa.select(sa.func.count()).select_from(self.table)
69-
total_count_result = await conn.execute(total_count_query)
70-
total_count = total_count_result.scalar()
71-
72-
# Fetch paginated results
73-
query = sa.select(self.table).limit(limit).offset(offset)
74-
result = await conn.execute(query)
75-
records = [dict(row) for row in result.fetchall()]
76-
77-
return _PageDict(total_count=total_count or 0, rows=records)
78-
79-
async def update(
80-
self, record_id: int, connection: AsyncConnection | None = None, **values
81-
) -> bool:
82-
async with get_or_create_connection(self.engine, connection) as conn:
83-
result = await conn.execute(
84-
self.table.update().where(self.table.c.id == record_id).values(**values)
85-
)
86-
await conn.commit()
87-
return result.rowcount > 0
88-
89-
async def delete(
90-
self, record_id: int, connection: AsyncConnection | None = None
91-
) -> bool:
92-
async with get_or_create_connection(self.engine, connection) as conn:
93-
result = await conn.execute(
94-
self.table.delete().where(self.table.c.id == record_id)
95-
)
96-
await conn.commit()
97-
return result.rowcount > 0

packages/postgres-database/tests/test_base_repo.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,106 @@
33
# pylint: disable=unused-variable
44
# pylint: disable=too-many-arguments
55

6-
from collections.abc import AsyncIterator, Callable
6+
7+
from typing import Any, TypedDict
78

89
import sqlalchemy as sa
9-
from simcore_postgres_database.base_repo import MinimalRepo, transaction_context
10-
from sqlalchemy.ext.asyncio import AsyncEngine
10+
from simcore_postgres_database.base_repo import (
11+
get_or_create_connection,
12+
transaction_context,
13+
)
14+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
15+
16+
17+
class _PageDict(TypedDict):
18+
total_count: int
19+
rows: list[dict[str, Any]]
20+
21+
22+
class OneResourceRepoDemo:
23+
# This is a PROTOTYPE of how one could implement a generic
24+
# repo that provides CRUD operations providing a given table
25+
def __init__(self, engine: AsyncEngine, table: sa.Table):
26+
self.engine = engine
27+
self.table = table
28+
29+
async def create(self, connection: AsyncConnection | None = None, **kwargs) -> int:
30+
async with get_or_create_connection(self.engine, connection) as conn:
31+
result = await conn.execute(self.table.insert().values(**kwargs))
32+
await conn.commit()
33+
assert result # nosec
34+
return result.inserted_primary_key[0]
1135

36+
async def get_by_id(
37+
self, record_id: int, connection: AsyncConnection | None = None
38+
) -> dict[str, Any] | None:
39+
async with get_or_create_connection(self.engine, connection) as conn:
40+
result = await conn.execute(
41+
sa.select(self.table).where(self.table.c.id == record_id)
42+
)
43+
record = result.fetchone()
44+
return dict(record) if record else None
1245

13-
async def asyncio_engine(
14-
make_asyncio_engine: Callable[[bool], AsyncEngine]
15-
) -> AsyncIterator[AsyncEngine]:
16-
engine = make_asyncio_engine(echo=True)
17-
try:
18-
yield engine
19-
except Exception:
20-
# for AsyncEngine created in function scope, close and
21-
# clean-up pooled connections
22-
await engine.dispose()
46+
async def get_page(
47+
self, limit: int, offset: int, connection: AsyncConnection | None = None
48+
) -> _PageDict:
49+
async with get_or_create_connection(self.engine, connection) as conn:
50+
# Compute total count
51+
total_count_query = sa.select(sa.func.count()).select_from(self.table)
52+
total_count_result = await conn.execute(total_count_query)
53+
total_count = total_count_result.scalar()
2354

55+
# Fetch paginated results
56+
query = sa.select(self.table).limit(limit).offset(offset)
57+
result = await conn.execute(query)
58+
records = [dict(row) for row in result.fetchall()]
2459

25-
async def test_it(asyncio_engine: AsyncEngine):
60+
return _PageDict(total_count=total_count or 0, rows=records)
2661

62+
async def update(
63+
self, record_id: int, connection: AsyncConnection | None = None, **values
64+
) -> bool:
65+
async with get_or_create_connection(self.engine, connection) as conn:
66+
result = await conn.execute(
67+
self.table.update().where(self.table.c.id == record_id).values(**values)
68+
)
69+
await conn.commit()
70+
return result.rowcount > 0
71+
72+
async def delete(
73+
self, record_id: int, connection: AsyncConnection | None = None
74+
) -> bool:
75+
async with get_or_create_connection(self.engine, connection) as conn:
76+
result = await conn.execute(
77+
self.table.delete().where(self.table.c.id == record_id)
78+
)
79+
await conn.commit()
80+
return result.rowcount > 0
81+
82+
83+
async def test_sqlachemy_asyncio_example(asyncpg_engine: AsyncEngine):
84+
#
85+
# Same example as in https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
86+
# but using `t1_repo`
87+
#
2788
meta = sa.MetaData()
2889
t1 = sa.Table("t1", meta, sa.Column("name", sa.String(50), primary_key=True))
2990

30-
t1_repo = MinimalRepo(engine=asyncio_engine, table=t1)
91+
t1_repo = OneResourceRepoDemo(engine=asyncpg_engine, table=t1)
92+
93+
async with transaction_context(asyncpg_engine) as conn:
3194

32-
async with transaction_context(asyncio_engine) as conn:
3395
await conn.run_sync(meta.drop_all)
3496
await conn.run_sync(meta.create_all)
3597

3698
await t1_repo.create(conn, name="some name 1")
3799
await t1_repo.create(conn, name="some name 2")
38100

39-
async with transaction_context(asyncio_engine) as conn:
40-
101+
async with transaction_context(asyncpg_engine) as conn:
41102
page = await t1_repo.get_page(limit=50, offset=0, connection=conn)
42103

43104
assert page["total_count"] == 2
44105

45-
# select a Result, which will be delivered with buffered
46-
# results
106+
# select a Result, which will be delivered with buffered results
47107
result = await conn.execute(sa.select(t1).where(t1.c.name == "some name 1"))
48-
print(result.fetchall())
108+
assert result.fetchall()

0 commit comments

Comments
 (0)