1212 get_or_create_connection ,
1313 transaction_context ,
1414)
15+ from simcore_postgres_database .models .tags import tags
1516from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
1617
1718
@@ -24,9 +25,13 @@ class OneResourceRepoDemo:
2425 # This is a PROTOTYPE of how one could implement a generic
2526 # repo that provides CRUD operations providing a given table
2627 def __init__ (self , engine : AsyncEngine , table : sa .Table ):
27- self .engine = engine
28+ if "id" not in table .columns :
29+ msg = "id column expected"
30+ raise ValueError (msg )
2831 self .table = table
2932
33+ self .engine = engine
34+
3035 async def create (self , connection : AsyncConnection | None = None , ** kwargs ) -> int :
3136 async with transaction_context (self .engine , connection ) as conn :
3237 result = await conn .execute (self .table .insert ().values (** kwargs ))
@@ -35,7 +40,10 @@ async def create(self, connection: AsyncConnection | None = None, **kwargs) -> i
3540 return result .inserted_primary_key [0 ]
3641
3742 async def get_by_id (
38- self , record_id : int , connection : AsyncConnection | None = None
43+ self ,
44+ connection : AsyncConnection | None = None ,
45+ * ,
46+ record_id : int ,
3947 ) -> dict [str , Any ] | None :
4048 async with get_or_create_connection (self .engine , connection ) as conn :
4149 result = await conn .execute (
@@ -45,7 +53,11 @@ async def get_by_id(
4553 return dict (record ) if record else None
4654
4755 async def get_page (
48- self , limit : int , offset : int , connection : AsyncConnection | None = None
56+ self ,
57+ connection : AsyncConnection | None = None ,
58+ * ,
59+ limit : int ,
60+ offset : int ,
4961 ) -> _PageDict :
5062 async with get_or_create_connection (self .engine , connection ) as conn :
5163 # Compute total count
@@ -61,7 +73,11 @@ async def get_page(
6173 return _PageDict (total_count = total_count or 0 , rows = records )
6274
6375 async def update (
64- self , record_id : int , connection : AsyncConnection | None = None , ** values
76+ self ,
77+ connection : AsyncConnection | None = None ,
78+ * ,
79+ record_id : int ,
80+ ** values ,
6581 ) -> bool :
6682 async with transaction_context (self .engine , connection ) as conn :
6783 result = await conn .execute (
@@ -71,7 +87,10 @@ async def update(
7187 return result .rowcount > 0
7288
7389 async def delete (
74- self , record_id : int , connection : AsyncConnection | None = None
90+ self ,
91+ connection : AsyncConnection | None = None ,
92+ * ,
93+ record_id : int ,
7594 ) -> bool :
7695 async with transaction_context (self .engine , connection ) as conn :
7796 result = await conn .execute (
@@ -81,30 +100,48 @@ async def delete(
81100 return result .rowcount > 0
82101
83102
84- @pytest .mark .skip ()
85- async def test_sqlachemy_asyncio_example (asyncpg_engine : AsyncEngine ):
103+ async def test_transaction_context (asyncpg_engine : AsyncEngine ):
86104 #
87- # Same example as in https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
88- # but using `t1_repo`
89- #
90- meta = sa .MetaData ()
91- t1 = sa .Table ("t1" , meta , sa .Column ("name" , sa .String (50 ), primary_key = True ))
105+ # Similar to example in https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#synopsis-core
106+ # using tags
92107
93- t1_repo = OneResourceRepoDemo (engine = asyncpg_engine , table = t1 )
108+ tags_repo = OneResourceRepoDemo (engine = asyncpg_engine , table = tags )
94109
95- async with transaction_context (asyncpg_engine ) as conn :
110+ # (1) Using transaction_context and fails
111+ fake_error_msg = "some error"
96112
97- await conn . run_sync ( meta . drop_all )
98- await conn . run_sync ( meta . create_all )
113+ def _something_raises_here ():
114+ raise RuntimeError ( fake_error_msg )
99115
100- await t1_repo .create (conn , name = "some name 1" )
101- await t1_repo .create (conn , name = "some name 2" )
116+ async def _create_blue_like_tags (conn ):
117+ async with conn .begin (): # NOTE: embedded transaction here!
118+ await tags_repo .create (conn , name = "cyan tag" , color = "cyan" )
119+ _something_raises_here ()
120+ await tags_repo .create (conn , name = "violet tag" , color = "violet" )
102121
103- async with transaction_context (asyncpg_engine ) as conn :
104- page = await t1_repo .get_page (limit = 50 , offset = 0 , connection = conn )
122+ async def _create_four_tags (conn ):
123+ await tags_repo .create (conn , name = "red tag" , color = "red" )
124+ await _create_blue_like_tags (conn )
125+ await tags_repo .create (conn , name = "green tag" , color = "green" )
126+
127+ with pytest .raises (RuntimeError , match = fake_error_msg ):
128+ async with transaction_context (asyncpg_engine ) as conn :
129+ await _create_four_tags (conn )
130+
131+ page = await tags_repo .get_page (limit = 50 , offset = 0 )
132+ assert page ["total_count" ] == 0 , "Transaction did not happen"
105133
134+ # (2) using internal connections
135+ await tags_repo .create (name = "blue tag" , color = "blue" )
136+ await tags_repo .create (conn , name = "red tag" , color = "red" )
137+ page = await tags_repo .get_page (limit = 50 , offset = 0 )
138+ assert page ["total_count" ] == 2
139+
140+ # (3) using external embedded
141+ async with transaction_context (asyncpg_engine ) as conn :
142+ page = await tags_repo .get_page (conn , limit = 50 , offset = 0 )
106143 assert page ["total_count" ] == 2
107144
108145 # select a Result, which will be delivered with buffered results
109- result = await conn .execute (sa .select (t1 ).where (t1 .c .name == "some name 1" ))
146+ result = await conn .execute (sa .select (tags ).where (tags .c .name == "some name 1" ))
110147 assert result .fetchall ()
0 commit comments