Skip to content

Commit 71e38b5

Browse files
committed
🐛 Fix: refactor payment transaction state updates to use asyncpg engine and improve connection handling in tests and payment completion logic
1 parent 0b3f782 commit 71e38b5

File tree

3 files changed

+100
-85
lines changed

3 files changed

+100
-85
lines changed

packages/postgres-database/src/simcore_postgres_database/utils_payments.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ async def update_payment_transaction_state(
125125
.where(payments_transactions.c.payment_id == payment_id)
126126
.returning(payments_transactions)
127127
)
128-
row = result.one()
129-
return row
128+
return result.one()
130129

131130

132131
async def get_user_payments_transactions(

packages/postgres-database/tests/test_models_payments_transactions.py

Lines changed: 97 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,25 @@
2424
insert_init_payment_transaction,
2525
update_payment_transaction_state,
2626
)
27-
from sqlalchemy.ext.asyncio import AsyncConnection
27+
from simcore_postgres_database.utils_repos import transaction_context
28+
from sqlalchemy.ext.asyncio import AsyncEngine
2829

2930

30-
async def test_numerics_precission_and_scale(connection: AsyncConnection):
31+
async def test_numerics_precission_and_scale(asyncpg_engine: AsyncEngine):
3132
# https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Numeric
3233
# precision: This parameter specifies the total number of digits that can be stored, both before and after the decimal point.
3334
# scale: This parameter specifies the number of digits that can be stored to the right of the decimal point.
3435

35-
for order_of_magnitude in range(8):
36-
expected = 10**order_of_magnitude + 0.123
37-
got = await connection.scalar(
38-
payments_transactions.insert()
39-
.values(**random_payment_transaction(price_dollars=expected))
40-
.returning(payments_transactions.c.price_dollars)
41-
)
42-
assert isinstance(got, decimal.Decimal)
43-
assert float(got) == expected
36+
async with asyncpg_engine.begin() as connection:
37+
for order_of_magnitude in range(8):
38+
expected = 10**order_of_magnitude + 0.123
39+
got = await connection.scalar(
40+
payments_transactions.insert()
41+
.values(**random_payment_transaction(price_dollars=expected))
42+
.returning(payments_transactions.c.price_dollars)
43+
)
44+
assert isinstance(got, decimal.Decimal)
45+
assert float(got) == expected
4446

4547

4648
def _remove_not_required(data: dict[str, Any]) -> dict[str, Any]:
@@ -57,7 +59,7 @@ def _remove_not_required(data: dict[str, Any]) -> dict[str, Any]:
5759

5860

5961
@pytest.fixture
60-
def init_transaction(connection: AsyncConnection):
62+
def init_transaction(asyncpg_engine: AsyncEngine):
6163
async def _init(payment_id: str):
6264
# get payment_id from payment-gateway
6365
values = _remove_not_required(random_payment_transaction(payment_id=payment_id))
@@ -66,7 +68,8 @@ async def _init(payment_id: str):
6668
values["initiated_at"] = utcnow()
6769

6870
# insert
69-
ok = await insert_init_payment_transaction(connection, **values)
71+
async with asyncpg_engine.begin() as connection:
72+
ok = await insert_init_payment_transaction(connection, **values)
7073
assert ok
7174

7275
return values
@@ -80,19 +83,20 @@ def payment_id() -> str:
8083

8184

8285
async def test_init_transaction_sets_it_as_pending(
83-
connection: AsyncConnection, init_transaction: Callable, payment_id: str
86+
asyncpg_engine: AsyncEngine, init_transaction: Callable, payment_id: str
8487
):
8588
values = await init_transaction(payment_id)
8689
assert values["payment_id"] == payment_id
8790

8891
# check init-ed but not completed!
89-
result = await connection.execute(
90-
sa.select(
91-
payments_transactions.c.completed_at,
92-
payments_transactions.c.state,
93-
payments_transactions.c.state_message,
94-
).where(payments_transactions.c.payment_id == payment_id)
95-
)
92+
async with asyncpg_engine.connect() as connection:
93+
result = await connection.execute(
94+
sa.select(
95+
payments_transactions.c.completed_at,
96+
payments_transactions.c.state,
97+
payments_transactions.c.state_message,
98+
).where(payments_transactions.c.payment_id == payment_id)
99+
)
96100
row = result.one_or_none()
97101
assert row is not None
98102

@@ -126,59 +130,64 @@ def invoice_url(faker: Faker, expected_state: PaymentTransactionState) -> str |
126130
],
127131
)
128132
async def test_complete_transaction(
129-
connection: AsyncConnection,
133+
asyncpg_engine: AsyncEngine,
130134
init_transaction: Callable,
131135
payment_id: str,
132136
expected_state: PaymentTransactionState,
133137
expected_message: str | None,
134138
invoice_url: str | None,
135139
):
140+
# init
136141
await init_transaction(payment_id)
137142

138-
payment_row = await update_payment_transaction_state(
139-
connection,
140-
payment_id=payment_id,
141-
completion_state=expected_state,
142-
state_message=expected_message,
143-
invoice_url=invoice_url,
144-
)
143+
async with asyncpg_engine.connect() as connection:
144+
# NOTE: internal function uses transaction
145+
payment_row = await update_payment_transaction_state(
146+
connection,
147+
payment_id=payment_id,
148+
completion_state=expected_state,
149+
state_message=expected_message,
150+
invoice_url=invoice_url,
151+
)
145152

146-
assert isinstance(payment_row, PaymentTransactionRow)
147-
assert payment_row.state_message == expected_message
148-
assert payment_row.state == expected_state
149-
assert payment_row.initiated_at < payment_row.completed_at
150-
assert PaymentTransactionState(payment_row.state).is_completed()
153+
assert isinstance(payment_row, PaymentTransactionRow)
154+
assert payment_row.state_message == expected_message
155+
assert payment_row.state == expected_state
156+
assert payment_row.initiated_at < payment_row.completed_at
157+
assert PaymentTransactionState(payment_row.state).is_completed()
151158

152159

153160
async def test_update_transaction_failures_and_exceptions(
154-
connection: AsyncConnection,
161+
asyncpg_engine: AsyncEngine,
155162
init_transaction: Callable,
156163
payment_id: str,
157164
):
158-
kwargs = {
159-
"connection": connection,
160-
"payment_id": payment_id,
161-
"completion_state": PaymentTransactionState.SUCCESS,
162-
}
163165

164-
ok = await update_payment_transaction_state(**kwargs)
165-
assert isinstance(ok, PaymentNotFound)
166-
assert not ok
167-
168-
# init & complete
169-
await init_transaction(payment_id)
170-
ok = await update_payment_transaction_state(**kwargs)
171-
assert isinstance(ok, PaymentTransactionRow)
172-
assert ok
166+
async with asyncpg_engine.connect() as connection:
167+
kwargs = {
168+
"connection": connection,
169+
"payment_id": payment_id,
170+
"completion_state": PaymentTransactionState.SUCCESS,
171+
}
172+
173+
ok = await update_payment_transaction_state(**kwargs)
174+
assert isinstance(ok, PaymentNotFound)
175+
assert not ok
176+
177+
# init & complete
178+
await init_transaction(payment_id)
179+
ok = await update_payment_transaction_state(**kwargs)
180+
assert isinstance(ok, PaymentTransactionRow)
181+
assert ok
173182

174-
# repeat -> fails
175-
ok = await update_payment_transaction_state(**kwargs)
176-
assert isinstance(ok, PaymentAlreadyAcked)
177-
assert not ok
183+
# repeat -> fails
184+
ok = await update_payment_transaction_state(**kwargs)
185+
assert isinstance(ok, PaymentAlreadyAcked)
186+
assert not ok
178187

179-
with pytest.raises(ValueError):
180-
kwargs.update(completion_state=PaymentTransactionState.PENDING)
181-
await update_payment_transaction_state(**kwargs)
188+
with pytest.raises(ValueError, match="cannot update state with"): # noqa: PT012
189+
kwargs.update(completion_state=PaymentTransactionState.PENDING)
190+
await update_payment_transaction_state(**kwargs)
182191

183192

184193
@pytest.fixture
@@ -188,14 +197,18 @@ def user_id() -> int:
188197

189198
@pytest.fixture
190199
def create_fake_user_transactions(
191-
connection: AsyncConnection, user_id: int
200+
asyncpg_engine: AsyncEngine, user_id: int
192201
) -> Callable:
202+
203+
assert asyncpg_engine
204+
193205
async def _go(expected_total=5):
194206
payment_ids = []
195207
for _ in range(expected_total):
196208
values = _remove_not_required(random_payment_transaction(user_id=user_id))
197209

198-
payment_id = await insert_init_payment_transaction(connection, **values)
210+
async with transaction_context(asyncpg_engine) as connection:
211+
payment_id = await insert_init_payment_transaction(connection, **values)
199212
assert payment_id
200213
payment_ids.append(payment_id)
201214

@@ -205,19 +218,21 @@ async def _go(expected_total=5):
205218

206219

207220
async def test_get_user_payments_transactions(
208-
connection: AsyncConnection, create_fake_user_transactions: Callable, user_id: int
221+
asyncpg_engine: AsyncEngine, create_fake_user_transactions: Callable, user_id: int
209222
):
210223
expected_payments_ids = await create_fake_user_transactions()
211224
expected_total = len(expected_payments_ids)
212225

213226
# test offset and limit defaults
214-
total, rows = await get_user_payments_transactions(connection, user_id=user_id)
227+
async with asyncpg_engine.connect() as connection:
228+
total, rows = await get_user_payments_transactions(connection, user_id=user_id)
229+
215230
assert total == expected_total
216231
assert [r.payment_id for r in rows] == expected_payments_ids[::-1], "newest first"
217232

218233

219234
async def test_get_user_payments_transactions_with_pagination_options(
220-
connection: AsyncConnection, create_fake_user_transactions: Callable, user_id: int
235+
asyncpg_engine: AsyncEngine, create_fake_user_transactions: Callable, user_id: int
221236
):
222237
expected_payments_ids = await create_fake_user_transactions()
223238
expected_total = len(expected_payments_ids)
@@ -226,23 +241,24 @@ async def test_get_user_payments_transactions_with_pagination_options(
226241
offset = int(expected_total / 4)
227242
limit = int(expected_total / 2)
228243

229-
total, rows = await get_user_payments_transactions(
230-
connection, user_id=user_id, limit=limit, offset=offset
231-
)
232-
assert total == expected_total
233-
assert [r.payment_id for r in rows] == expected_payments_ids[::-1][
234-
offset : (offset + limit)
235-
], "newest first"
236-
237-
# test offset>=expected_total?
238-
total, rows = await get_user_payments_transactions(
239-
connection, user_id=user_id, offset=expected_total
240-
)
241-
assert not rows
242-
243-
# test limit==0?
244-
total, rows = await get_user_payments_transactions(
245-
connection, user_id=user_id, limit=0
246-
)
247-
assert not rows
248-
assert not rows
244+
async with asyncpg_engine.connect() as connection:
245+
total, rows = await get_user_payments_transactions(
246+
connection, user_id=user_id, limit=limit, offset=offset
247+
)
248+
assert total == expected_total
249+
assert [r.payment_id for r in rows] == expected_payments_ids[::-1][
250+
offset : (offset + limit)
251+
], "newest first"
252+
253+
# test offset>=expected_total?
254+
total, rows = await get_user_payments_transactions(
255+
connection, user_id=user_id, offset=expected_total
256+
)
257+
assert not rows
258+
259+
# test limit==0?
260+
total, rows = await get_user_payments_transactions(
261+
connection, user_id=user_id, limit=0
262+
)
263+
assert not rows
264+
assert not rows

services/web/server/src/simcore_service_webserver/payments/_onetime_db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
)
2323
from simcore_postgres_database.utils_repos import (
2424
pass_or_acquire_connection,
25-
transaction_context,
2625
)
2726
from sqlalchemy.ext.asyncio import AsyncConnection
2827

@@ -104,7 +103,8 @@ async def complete_payment_transaction(
104103
if invoice_url:
105104
optional_kwargs["invoice_url"] = invoice_url
106105

107-
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
106+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
107+
# NOTE: update_payment_transaction_state() uses a transaction internally, therefore we use pass_or_acquire_connection(...)
108108
row = await update_payment_transaction_state(
109109
conn,
110110
payment_id=payment_id,

0 commit comments

Comments
 (0)