44# pylint: disable=too-many-arguments
55
66
7- from typing import Any , TypedDict
7+ from typing import Any , NamedTuple
88
99import pytest
1010import sqlalchemy as sa
1313 transaction_context ,
1414)
1515from simcore_postgres_database .models .tags import tags
16+ from sqlalchemy .exc import IntegrityError
1617from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
1718
1819
@@ -21,55 +22,57 @@ async def test_sa_transactions(asyncpg_engine: AsyncEngine):
2122 # SEE https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
2223 #
2324
25+ # READ query
2426 total_count_query = sa .select (sa .func .count ()).select_from (tags )
2527
2628 # WRITE queries
27- query1 = tags .insert ().values (name = "query1" , color = "blue" )
28- query11 = tags .insert ().values (name = "query11" , color = "blue" )
29- query111 = tags .insert ().values (name = "query111" , color = "blue" )
30- query1111 = tags .insert ().values (name = "query1111" , color = "blue" )
31- query112 = tags .insert ().values (name = "query112" , color = "blue" )
32- query12 = tags .insert ().values (name = "query12" , color = "blue" )
33- query2 = tags .insert ().values (name = "query2" , color = "blue" )
29+ query1 = (
30+ tags .insert ().values (id = 2 , name = "query1" , color = "blue" ).returning (tags .c .id )
31+ )
32+ query11 = (
33+ tags .insert ().values (id = 3 , name = "query11" , color = "blue" ).returning (tags .c .id )
34+ )
35+ query12 = (
36+ tags .insert ().values (id = 5 , name = "query12" , color = "blue" ).returning (tags .c .id )
37+ )
38+ query2 = (
39+ tags .insert ().values (id = 6 , name = "query2" , color = "blue" ).returning (tags .c .id )
40+ )
41+ query2 = (
42+ tags .insert ().values (id = 7 , name = "query2" , color = "blue" ).returning (tags .c .id )
43+ )
44+
45+ async with asyncpg_engine .connect () as conn , conn .begin (): # starts transaction (savepoint)
46+
47+ result = await conn .execute (query1 )
48+ assert result .scalar () == 2
3449
35- # to make it fail, just repeat query since `id` is unique
36- # TODO: query1111 = tags.insert().values(id=1, name="query1111", color="blue")
37- # TODO: await conn.commit() # explicit commit !
38-
39- # TODO: if this is true, then the order of execution is NOT preserved?
40- async with asyncpg_engine .connect () as conn :
50+ total_count = (await conn .execute (total_count_query )).scalar ()
51+ assert total_count == 1
4152
42- await conn .execute (query1 )
53+ rows = (await conn .execute (tags .select ().where (tags .c .id == 2 ))).fetchall ()
54+ assert rows
55+ assert rows [0 ].id == 2
4356
44- async with conn .begin (): # savepoint
57+ async with conn .begin_nested (): # savepoint
4558 await conn .execute (query11 )
4659
47- async with conn .begin_nested (): # savepoint
48- await conn .execute (query111 )
49-
60+ with pytest .raises (IntegrityError ):
5061 async with conn .begin_nested (): # savepoint
51- await conn .execute (query1111 )
52-
53- await conn .execute (query112 )
54-
55- total_count = (await conn .execute (total_count_query )).scalar ()
56- assert total_count == 1 # (query1111)
62+ await conn .execute (query11 )
5763
5864 await conn .execute (query12 )
5965
6066 total_count = (await conn .execute (total_count_query )).scalar ()
61- assert total_count == 3 # query111, (query1111), query112
67+ assert total_count == 3 # since query11 (second time) reverted!
6268
6369 await conn .execute (query2 )
6470
6571 total_count = (await conn .execute (total_count_query )).scalar ()
66- assert total_count == 5 # query11, (query111, (query1111), query112), query2
67-
68- total_count = (await conn .execute (total_count_query )).scalar ()
69- assert total_count == 7 # includes query1, query2
72+ assert total_count == 4
7073
7174
72- class _PageDict ( TypedDict ):
75+ class _PageTuple ( NamedTuple ):
7376 total_count : int
7477 rows : list [dict [str , Any ]]
7578
@@ -95,22 +98,22 @@ async def get_by_id(
9598 self ,
9699 connection : AsyncConnection | None = None ,
97100 * ,
98- record_id : int ,
101+ row_id : int ,
99102 ) -> dict [str , Any ] | None :
100103 async with get_or_create_connection (self .engine , connection ) as conn :
101104 result = await conn .execute (
102- sa .select (self .table ).where (self .table .c .id == record_id )
105+ sa .select (self .table ).where (self .table .c .id == row_id )
103106 )
104- record = result .fetchone ()
105- return dict (record ) if record else None
107+ row = result . mappings () .fetchone ()
108+ return dict (row ) if row else None
106109
107110 async def get_page (
108111 self ,
109112 connection : AsyncConnection | None = None ,
110113 * ,
111114 limit : int ,
112- offset : int ,
113- ) -> _PageDict :
115+ offset : int = 0 ,
116+ ) -> _PageTuple :
114117 async with get_or_create_connection (self .engine , connection ) as conn :
115118 # Compute total count
116119 total_count_query = sa .select (sa .func .count ()).select_from (self .table )
@@ -120,81 +123,91 @@ async def get_page(
120123 # Fetch paginated results
121124 query = sa .select (self .table ).limit (limit ).offset (offset )
122125 result = await conn .execute (query )
123- records = [dict (** row ) for row in result .fetchall ()]
126+ rows = [dict (row ) for row in result . mappings () .fetchall ()]
124127
125- return _PageDict (total_count = total_count or 0 , rows = records )
128+ return _PageTuple (total_count = total_count or 0 , rows = rows )
126129
127130 async def update (
128131 self ,
129132 connection : AsyncConnection | None = None ,
130133 * ,
131- record_id : int ,
134+ row_id : int ,
132135 ** values ,
133136 ) -> bool :
134137 async with transaction_context (self .engine , connection ) as conn :
135138 result = await conn .execute (
136- self .table .update ().where (self .table .c .id == record_id ).values (** values )
139+ self .table .update ().where (self .table .c .id == row_id ).values (** values )
137140 )
138141 return result .rowcount > 0
139142
140143 async def delete (
141144 self ,
142145 connection : AsyncConnection | None = None ,
143146 * ,
144- record_id : int ,
147+ row_id : int ,
145148 ) -> bool :
146149 async with transaction_context (self .engine , connection ) as conn :
147150 result = await conn .execute (
148- self .table .delete ().where (self .table .c .id == record_id )
151+ self .table .delete ().where (self .table .c .id == row_id )
149152 )
150153 return result .rowcount > 0
151154
152155
153- async def test_transaction_context (asyncpg_engine : AsyncEngine ):
156+ async def test_oneresourcerepodemo_prototype (asyncpg_engine : AsyncEngine ):
154157
155158 tags_repo = OneResourceRepoDemo (engine = asyncpg_engine , table = tags )
156159
160+ # create
161+ tag_id = await tags_repo .create (name = "cyan tag" , color = "cyan" )
162+ assert tag_id > 0
163+
164+ # get, list
165+ tag = await tags_repo .get_by_id (row_id = tag_id )
166+ assert tag
167+
168+ page = await tags_repo .get_page (limit = 10 )
169+ assert page .total_count == 1
170+ assert page .rows == [tag ]
171+
172+ # update
173+ ok = await tags_repo .update (row_id = tag_id , name = "changed name" )
174+ assert ok
175+
176+ updated_tag = await tags_repo .get_by_id (row_id = tag_id )
177+ assert updated_tag
178+ assert updated_tag ["name" ] != tag ["name" ]
179+
180+ # delete
181+ ok = await tags_repo .delete (row_id = tag_id )
182+ assert ok
183+
184+ assert not await tags_repo .get_by_id (row_id = tag_id )
185+
186+
187+ async def test_transaction_context (asyncpg_engine : AsyncEngine ):
157188 # (1) Using transaction_context and fails
158189 fake_error_msg = "some error"
159190
160191 def _something_raises_here ():
161192 raise RuntimeError (fake_error_msg )
162193
163- async def _create_blue_like_tags (connection ):
164- # NOTE: embedded transaction here!!!
165- async with transaction_context (asyncpg_engine , connection ) as conn :
166- await tags_repo .create (conn , name = "cyan tag" , color = "cyan" )
167- _something_raises_here ()
168- await tags_repo .create (conn , name = "violet tag" , color = "violet" )
169-
170- async def _create_four_tags (connection ):
171- await tags_repo .create (connection , name = "red tag" , color = "red" )
172- await _create_blue_like_tags (connection )
173- await tags_repo .create (connection , name = "green tag" , color = "green" )
174-
175- with pytest .raises (RuntimeError , match = fake_error_msg ):
176- async with transaction_context (asyncpg_engine ) as conn :
177- await tags_repo .create (conn , name = "red tag" , color = "red" )
178- _something_raises_here ()
179- await tags_repo .create (conn , name = "green tag" , color = "green" )
180-
181- print (asyncpg_engine .pool .status ())
182- assert conn .closed
194+ tags_repo = OneResourceRepoDemo (engine = asyncpg_engine , table = tags )
183195
184- page = await tags_repo .get_page (limit = 50 , offset = 0 )
185- assert page ["total_count" ] == 0 , "Transaction did not happen"
196+ # using external transaction_context: commits upon __aexit__
197+ async with transaction_context (asyncpg_engine ) as conn :
198+ await tags_repo .create (conn , name = "cyan tag" , color = "cyan" )
199+ await tags_repo .create (conn , name = "red tag" , color = "red" )
200+ assert (await tags_repo .get_page (conn , limit = 10 , offset = 0 )).total_count == 2
186201
187- # (2) using internal connections
188- await tags_repo .create (name = "blue tag" , color = "blue" )
202+ # using internal: auto-commit
189203 await tags_repo .create (name = "red tag" , color = "red" )
190- page = await tags_repo .get_page (limit = 50 , offset = 0 )
191- assert page ["total_count" ] == 2
204+ assert (await tags_repo .get_page (limit = 10 , offset = 0 )).total_count == 3
192205
193- # (3) using external embedded
194- async with transaction_context (asyncpg_engine ) as conn :
195- page = await tags_repo .get_page (conn , limit = 50 , offset = 0 )
196- assert page ["total_count" ] == 2
206+ # auto-rollback
207+ with pytest .raises (RuntimeError , match = fake_error_msg ): # noqa: PT012
208+ async with transaction_context (asyncpg_engine ) as conn :
209+ await tags_repo .create (conn , name = "violet tag" , color = "violet" )
210+ assert (await tags_repo .get_page (conn , limit = 10 , offset = 0 )).total_count == 4
211+ _something_raises_here ()
197212
198- # select a Result, which will be delivered with buffered results
199- result = await conn .execute (sa .select (tags ).where (tags .c .name == "blue tag" ))
200- assert result .fetchall ()
213+ assert (await tags_repo .get_page (limit = 10 , offset = 0 )).total_count == 3
0 commit comments