Skip to content

Commit 5303c7f

Browse files
committed
minimal repo
1 parent 99a8e80 commit 5303c7f

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from contextlib import asynccontextmanager
2+
from typing import Any, TypedDict
3+
4+
import sqlalchemy as sa
5+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
6+
7+
8+
@asynccontextmanager
9+
async def get_or_create_connection(
10+
engine: AsyncEngine, connection: AsyncConnection | None = None
11+
):
12+
close_conn = False
13+
if connection is None:
14+
connection = await engine.connect()
15+
close_conn = True
16+
try:
17+
yield connection
18+
finally:
19+
if close_conn:
20+
await connection.close()
21+
22+
23+
@asynccontextmanager
24+
async def transaction_context(
25+
engine: AsyncEngine, connection: AsyncConnection | None = None
26+
):
27+
async with get_or_create_connection(engine, connection) as conn:
28+
if conn.in_transaction():
29+
async with conn.begin_nested():
30+
yield conn
31+
else:
32+
async with conn.begin():
33+
yield conn
34+
35+
36+
class _PageDict(TypedDict):
37+
total_count: int
38+
rows: list[dict[str, Any]]
39+
40+
41+
class MinimalRepo:
42+
def __init__(self, engine: AsyncEngine, table: sa.Table):
43+
self.engine = engine
44+
self.table = table
45+
46+
async def create(self, connection: AsyncConnection | None = None, **kwargs) -> int:
47+
async with get_or_create_connection(self.engine, connection) as conn:
48+
result = await conn.execute(self.table.insert().values(**kwargs))
49+
await conn.commit()
50+
assert result # nosec
51+
return result.inserted_primary_key[0]
52+
53+
async def get_by_id(
54+
self, record_id: int, connection: AsyncConnection | None = None
55+
) -> dict[str, Any] | None:
56+
async with get_or_create_connection(self.engine, connection) as conn:
57+
result = await conn.execute(
58+
sa.select(self.table).where(self.table.c.id == record_id)
59+
)
60+
record = result.fetchone()
61+
return dict(record) if record else None
62+
63+
async def get_page(
64+
self, limit: int, offset: int, connection: AsyncConnection | None = None
65+
) -> _PageDict:
66+
async with get_or_create_connection(self.engine, connection) as conn:
67+
# Compute total count
68+
total_count_query = sa.select(sa.func.count()).select_from(self.table)
69+
total_count_result = await conn.execute(total_count_query)
70+
total_count = total_count_result.scalar()
71+
72+
# Fetch paginated results
73+
query = sa.select(self.table).limit(limit).offset(offset)
74+
result = await conn.execute(query)
75+
records = [dict(row) for row in result.fetchall()]
76+
77+
return _PageDict(total_count=total_count or 0, rows=records)
78+
79+
async def update(
80+
self, record_id: int, connection: AsyncConnection | None = None, **values
81+
) -> bool:
82+
async with get_or_create_connection(self.engine, connection) as conn:
83+
result = await conn.execute(
84+
self.table.update().where(self.table.c.id == record_id).values(**values)
85+
)
86+
await conn.commit()
87+
return result.rowcount > 0
88+
89+
async def delete(
90+
self, record_id: int, connection: AsyncConnection | None = None
91+
) -> bool:
92+
async with get_or_create_connection(self.engine, connection) as conn:
93+
result = await conn.execute(
94+
self.table.delete().where(self.table.c.id == record_id)
95+
)
96+
await conn.commit()
97+
return result.rowcount > 0
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# pylint: disable=redefined-outer-name
2+
# pylint: disable=unused-argument
3+
# pylint: disable=unused-variable
4+
# pylint: disable=too-many-arguments
5+
6+
from collections.abc import AsyncIterator, Callable
7+
8+
import sqlalchemy as sa
9+
from simcore_postgres_database.base_repo import MinimalRepo, transaction_context
10+
from sqlalchemy.ext.asyncio import AsyncEngine
11+
12+
13+
async def asyncio_engine(
14+
make_asyncio_engine: Callable[[bool], AsyncEngine]
15+
) -> AsyncIterator[AsyncEngine]:
16+
engine = make_asyncio_engine(echo=True)
17+
try:
18+
yield engine
19+
except Exception:
20+
# for AsyncEngine created in function scope, close and
21+
# clean-up pooled connections
22+
await engine.dispose()
23+
24+
25+
async def test_it(asyncio_engine: AsyncEngine):
26+
27+
meta = sa.MetaData()
28+
t1 = sa.Table("t1", meta, sa.Column("name", sa.String(50), primary_key=True))
29+
30+
t1_repo = MinimalRepo(engine=asyncio_engine, table=t1)
31+
32+
async with transaction_context(asyncio_engine) as conn:
33+
await conn.run_sync(meta.drop_all)
34+
await conn.run_sync(meta.create_all)
35+
36+
await t1_repo.create(conn, name="some name 1")
37+
await t1_repo.create(conn, name="some name 2")
38+
39+
async with transaction_context(asyncio_engine) as conn:
40+
41+
page = await t1_repo.get_page(limit=50, offset=0, connection=conn)
42+
43+
assert page["total_count"] == 2
44+
45+
# select a Result, which will be delivered with buffered
46+
# results
47+
result = await conn.execute(sa.select(t1).where(t1.c.name == "some name 1"))
48+
print(result.fetchall())

0 commit comments

Comments
 (0)