11import datetime
22import logging
33
4- import simcore_postgres_database .aiopg_errors as db_errors
54import sqlalchemy as sa
5+ import sqlalchemy .exc
66from aiohttp import web
7- from aiopg .sa .result import ResultProxy
87from models_library .api_schemas_webserver .wallets import PaymentMethodID
98from models_library .users import UserID
109from models_library .wallets import WalletID
1312 InitPromptAckFlowState ,
1413 payments_methods ,
1514)
16- from sqlalchemy import literal_column
17- from sqlalchemy .sql import func
15+ from simcore_postgres_database .utils_repos import (
16+ pass_or_acquire_connection ,
17+ transaction_context ,
18+ )
19+ from sqlalchemy .ext .asyncio import AsyncConnection
1820
19- from ..db .plugin import get_database_engine_legacy
21+ from ..db .plugin import get_asyncpg_engine
2022from .errors import (
2123 PaymentMethodAlreadyAckedError ,
2224 PaymentMethodNotFoundError ,
2628_logger = logging .getLogger (__name__ )
2729
2830
29- class PaymentsMethodsDB (BaseModel ):
31+ class PaymentsMethodsGetDB (BaseModel ):
3032 payment_method_id : PaymentMethodID
3133 user_id : UserID
3234 wallet_id : WalletID
@@ -40,13 +42,14 @@ class PaymentsMethodsDB(BaseModel):
4042
4143async def insert_init_payment_method (
4244 app : web .Application ,
45+ connection : AsyncConnection | None = None ,
4346 * ,
4447 payment_method_id : str ,
4548 user_id : UserID ,
4649 wallet_id : WalletID ,
4750 initiated_at : datetime .datetime ,
4851) -> None :
49- async with get_database_engine_legacy ( app ). acquire ( ) as conn :
52+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
5053 try :
5154 await conn .execute (
5255 payments_methods .insert ().values (
@@ -56,65 +59,68 @@ async def insert_init_payment_method(
5659 initiated_at = initiated_at ,
5760 )
5861 )
59- except db_errors . UniqueViolation as err :
62+ except sqlalchemy . exc . IntegrityError as err :
6063 raise PaymentMethodUniqueViolationError (
6164 payment_method_id = payment_method_id
6265 ) from err
6366
6467
6568async def list_successful_payment_methods (
66- app ,
69+ app : web .Application ,
70+ connection : AsyncConnection | None = None ,
6771 * ,
6872 user_id : UserID ,
6973 wallet_id : WalletID ,
70- ) -> list [PaymentsMethodsDB ]:
71- async with get_database_engine_legacy ( app ). acquire ( ) as conn :
72- result : ResultProxy = await conn .execute (
73- payments_methods .select ()
74+ ) -> list [PaymentsMethodsGetDB ]:
75+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
76+ result = await conn .execute (
77+ sa .select (payments_methods )
7478 .where (
7579 (payments_methods .c .user_id == user_id )
7680 & (payments_methods .c .wallet_id == wallet_id )
7781 & (payments_methods .c .state == InitPromptAckFlowState .SUCCESS )
7882 )
7983 .order_by (payments_methods .c .created .desc ())
8084 ) # newest first
81- rows = await result .fetchall () or []
82- return TypeAdapter (list [PaymentsMethodsDB ]).validate_python (rows )
85+ rows = result .fetchall ()
86+ return TypeAdapter (list [PaymentsMethodsGetDB ]).validate_python (rows )
8387
8488
8589async def get_successful_payment_method (
86- app ,
90+ app : web .Application ,
91+ connection : AsyncConnection | None = None ,
8792 * ,
8893 user_id : UserID ,
8994 wallet_id : WalletID ,
9095 payment_method_id : PaymentMethodID ,
91- ) -> PaymentsMethodsDB :
92- async with get_database_engine_legacy ( app ). acquire ( ) as conn :
93- result : ResultProxy = await conn .execute (
94- payments_methods .select ().where (
96+ ) -> PaymentsMethodsGetDB :
97+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
98+ result = await conn .execute (
99+ sa .select (payments_methods ).where (
95100 (payments_methods .c .user_id == user_id )
96101 & (payments_methods .c .wallet_id == wallet_id )
97102 & (payments_methods .c .payment_method_id == payment_method_id )
98103 & (payments_methods .c .state == InitPromptAckFlowState .SUCCESS )
99104 )
100105 )
101- row = await result .first ()
106+ row = result .one_or_none ()
102107 if row is None :
103108 raise PaymentMethodNotFoundError (payment_method_id = payment_method_id )
104109
105- return PaymentsMethodsDB .model_validate (row )
110+ return PaymentsMethodsGetDB .model_validate (row )
106111
107112
108113async def get_pending_payment_methods_ids (
109114 app : web .Application ,
115+ connection : AsyncConnection | None = None ,
110116) -> list [PaymentMethodID ]:
111- async with get_database_engine_legacy ( app ). acquire ( ) as conn :
117+ async with pass_or_acquire_connection ( get_asyncpg_engine ( app ), connection ) as conn :
112118 result = await conn .execute (
113119 sa .select (payments_methods .c .payment_method_id )
114120 .where (payments_methods .c .completed_at .is_ (None ))
115121 .order_by (payments_methods .c .initiated_at .asc ()) # oldest first
116122 )
117- rows = await result .fetchall () or []
123+ rows = result .fetchall ()
118124 return [
119125 TypeAdapter (PaymentMethodID ).validate_python (row .payment_method_id )
120126 for row in rows
@@ -124,10 +130,11 @@ async def get_pending_payment_methods_ids(
124130async def udpate_payment_method (
125131 app : web .Application ,
126132 payment_method_id : PaymentMethodID ,
133+ connection : AsyncConnection | None = None ,
127134 * ,
128135 state : InitPromptAckFlowState ,
129136 state_message : str | None ,
130- ) -> PaymentsMethodsDB :
137+ ) -> PaymentsMethodsGetDB :
131138 """
132139
133140 Raises:
@@ -142,17 +149,16 @@ async def udpate_payment_method(
142149 if state_message :
143150 optional ["state_message" ] = state_message
144151
145- async with get_database_engine_legacy (app ).acquire () as conn , conn .begin ():
146- row = await (
147- await conn .execute (
148- sa .select (
149- payments_methods .c .initiated_at ,
150- payments_methods .c .completed_at ,
151- )
152- .where (payments_methods .c .payment_method_id == payment_method_id )
153- .with_for_update ()
152+ async with transaction_context (get_asyncpg_engine (app ), connection ) as conn :
153+ result = await conn .execute (
154+ sa .select (
155+ payments_methods .c .initiated_at ,
156+ payments_methods .c .completed_at ,
154157 )
155- ).fetchone ()
158+ .where (payments_methods .c .payment_method_id == payment_method_id )
159+ .with_for_update ()
160+ )
161+ row = result .one_or_none ()
156162
157163 if row is None :
158164 raise PaymentMethodNotFoundError (payment_method_id = payment_method_id )
@@ -162,24 +168,24 @@ async def udpate_payment_method(
162168
163169 result = await conn .execute (
164170 payments_methods .update ()
165- .values (completed_at = func .now (), state = state , ** optional )
171+ .values (completed_at = sa . func .now (), state = state , ** optional )
166172 .where (payments_methods .c .payment_method_id == payment_method_id )
167- .returning (literal_column ( "*" ) )
173+ .returning (payments_methods )
168174 )
169- row = await result .first ()
170- assert row , "execute above should have caught this" # nosec
175+ row = result .one ()
171176
172- return PaymentsMethodsDB .model_validate (row )
177+ return PaymentsMethodsGetDB .model_validate (row )
173178
174179
175180async def delete_payment_method (
176181 app : web .Application ,
182+ connection : AsyncConnection | None = None ,
177183 * ,
178184 user_id : UserID ,
179185 wallet_id : WalletID ,
180186 payment_method_id : PaymentMethodID ,
181- ):
182- async with get_database_engine_legacy ( app ). acquire ( ) as conn :
187+ ) -> None :
188+ async with transaction_context ( get_asyncpg_engine ( app ), connection ) as conn :
183189 await conn .execute (
184190 payments_methods .delete ().where (
185191 (payments_methods .c .user_id == user_id )
0 commit comments