Skip to content

Commit 0b3f782

Browse files
committed
🐛 Fix: rename PaymentsTransactionsDB to PaymentsTransactionsGetDB for consistency and update database connection handling to use asyncpg engine
1 parent 1dc3725 commit 0b3f782

File tree

4 files changed

+58
-51
lines changed

4 files changed

+58
-51
lines changed

packages/postgres-database/src/simcore_postgres_database/utils_payments.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,16 @@
55
from typing import Final, TypeAlias
66

77
import sqlalchemy as sa
8-
from aiopg.sa.connection import SAConnection
9-
from aiopg.sa.result import ResultProxy, RowProxy
8+
import sqlalchemy.exc
9+
from sqlalchemy.ext.asyncio import AsyncConnection
1010

11-
from . import aiopg_errors
1211
from .models.payments_transactions import PaymentTransactionState, payments_transactions
1312

1413
_logger = logging.getLogger(__name__)
1514

1615

1716
PaymentID: TypeAlias = str
18-
PaymentTransactionRow: TypeAlias = RowProxy
17+
PaymentTransactionRow: TypeAlias = sa.Row
1918

2019

2120
UNSET: Final[str] = "__UNSET__"
@@ -39,7 +38,7 @@ class PaymentAlreadyAcked(PaymentFailure): ...
3938

4039

4140
async def insert_init_payment_transaction(
42-
connection: SAConnection,
41+
connection: AsyncConnection,
4342
*,
4443
payment_id: str,
4544
price_dollars: Decimal,
@@ -66,14 +65,14 @@ async def insert_init_payment_transaction(
6665
initiated_at=initiated_at,
6766
)
6867
)
69-
except aiopg_errors.UniqueViolation:
68+
except sqlalchemy.exc.IntegrityError:
7069
return PaymentAlreadyExists(payment_id)
7170

7271
return payment_id
7372

7473

7574
async def update_payment_transaction_state(
76-
connection: SAConnection,
75+
connection: AsyncConnection,
7776
*,
7877
payment_id: str,
7978
completion_state: PaymentTransactionState,
@@ -101,16 +100,15 @@ async def update_payment_transaction_state(
101100
optional["invoice_url"] = invoice_url
102101

103102
async with connection.begin():
104-
row = await (
105-
await connection.execute(
106-
sa.select(
107-
payments_transactions.c.initiated_at,
108-
payments_transactions.c.completed_at,
109-
)
110-
.where(payments_transactions.c.payment_id == payment_id)
111-
.with_for_update()
103+
result = await connection.execute(
104+
sa.select(
105+
payments_transactions.c.initiated_at,
106+
payments_transactions.c.completed_at,
112107
)
113-
).fetchone()
108+
.where(payments_transactions.c.payment_id == payment_id)
109+
.with_for_update()
110+
)
111+
row = result.one_or_none()
114112

115113
if row is None:
116114
return PaymentNotFound(payment_id=payment_id)
@@ -125,16 +123,14 @@ async def update_payment_transaction_state(
125123
payments_transactions.update()
126124
.values(completed_at=sa.func.now(), state=completion_state, **optional)
127125
.where(payments_transactions.c.payment_id == payment_id)
128-
.returning(sa.literal_column("*"))
126+
.returning(payments_transactions)
129127
)
130-
row = await result.first()
131-
assert row, "execute above should have caught this" # nosec
132-
assert isinstance(row, RowProxy) # nosec
128+
row = result.one()
133129
return row
134130

135131

136132
async def get_user_payments_transactions(
137-
connection: SAConnection,
133+
connection: AsyncConnection,
138134
*,
139135
user_id: int,
140136
offset: int | None = None,
@@ -149,7 +145,7 @@ async def get_user_payments_transactions(
149145

150146
# NOTE: what if between these two calls there are new rows? can we get this in an atomic call?å
151147
stmt = (
152-
payments_transactions.select()
148+
sa.select(payments_transactions)
153149
.where(payments_transactions.c.user_id == user_id)
154150
.order_by(payments_transactions.c.created.desc())
155151
) # newest first
@@ -162,6 +158,6 @@ async def get_user_payments_transactions(
162158
# InvalidRowCountInLimitClause: LIMIT must not be negative
163159
stmt = stmt.limit(limit)
164160

165-
result: ResultProxy = await connection.execute(stmt)
166-
rows = await result.fetchall() or []
161+
result = await connection.execute(stmt)
162+
rows = result.fetchall()
167163
return total_number_of_items, rows

packages/postgres-database/tests/test_models_payments_transactions.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
import pytest
1212
import sqlalchemy as sa
13-
from aiopg.sa.connection import SAConnection
14-
from aiopg.sa.result import RowProxy
1513
from faker import Faker
1614
from pytest_simcore.helpers.faker_factories import random_payment_transaction, utcnow
1715
from simcore_postgres_database.models.payments_transactions import (
@@ -26,9 +24,10 @@
2624
insert_init_payment_transaction,
2725
update_payment_transaction_state,
2826
)
27+
from sqlalchemy.ext.asyncio import AsyncConnection
2928

3029

31-
async def test_numerics_precission_and_scale(connection: SAConnection):
30+
async def test_numerics_precission_and_scale(connection: AsyncConnection):
3231
# https://docs.sqlalchemy.org/en/20/core/type_basics.html#sqlalchemy.types.Numeric
3332
# precision: This parameter specifies the total number of digits that can be stored, both before and after the decimal point.
3433
# scale: This parameter specifies the number of digits that can be stored to the right of the decimal point.
@@ -58,7 +57,7 @@ def _remove_not_required(data: dict[str, Any]) -> dict[str, Any]:
5857

5958

6059
@pytest.fixture
61-
def init_transaction(connection: SAConnection):
60+
def init_transaction(connection: AsyncConnection):
6261
async def _init(payment_id: str):
6362
# get payment_id from payment-gateway
6463
values = _remove_not_required(random_payment_transaction(payment_id=payment_id))
@@ -81,7 +80,7 @@ def payment_id() -> str:
8180

8281

8382
async def test_init_transaction_sets_it_as_pending(
84-
connection: SAConnection, init_transaction: Callable, payment_id: str
83+
connection: AsyncConnection, init_transaction: Callable, payment_id: str
8584
):
8685
values = await init_transaction(payment_id)
8786
assert values["payment_id"] == payment_id
@@ -94,11 +93,11 @@ async def test_init_transaction_sets_it_as_pending(
9493
payments_transactions.c.state_message,
9594
).where(payments_transactions.c.payment_id == payment_id)
9695
)
97-
row: RowProxy | None = await result.fetchone()
96+
row = result.one_or_none()
9897
assert row is not None
9998

10099
# tests that defaults are right?
101-
assert dict(row.items()) == {
100+
assert dict(row._mapping.items()) == {
102101
"completed_at": None,
103102
"state": PaymentTransactionState.PENDING,
104103
"state_message": None,
@@ -127,7 +126,7 @@ def invoice_url(faker: Faker, expected_state: PaymentTransactionState) -> str |
127126
],
128127
)
129128
async def test_complete_transaction(
130-
connection: SAConnection,
129+
connection: AsyncConnection,
131130
init_transaction: Callable,
132131
payment_id: str,
133132
expected_state: PaymentTransactionState,
@@ -152,7 +151,7 @@ async def test_complete_transaction(
152151

153152

154153
async def test_update_transaction_failures_and_exceptions(
155-
connection: SAConnection,
154+
connection: AsyncConnection,
156155
init_transaction: Callable,
157156
payment_id: str,
158157
):
@@ -188,7 +187,9 @@ def user_id() -> int:
188187

189188

190189
@pytest.fixture
191-
def create_fake_user_transactions(connection: SAConnection, user_id: int) -> Callable:
190+
def create_fake_user_transactions(
191+
connection: AsyncConnection, user_id: int
192+
) -> Callable:
192193
async def _go(expected_total=5):
193194
payment_ids = []
194195
for _ in range(expected_total):
@@ -204,7 +205,7 @@ async def _go(expected_total=5):
204205

205206

206207
async def test_get_user_payments_transactions(
207-
connection: SAConnection, create_fake_user_transactions: Callable, user_id: int
208+
connection: AsyncConnection, create_fake_user_transactions: Callable, user_id: int
208209
):
209210
expected_payments_ids = await create_fake_user_transactions()
210211
expected_total = len(expected_payments_ids)
@@ -216,7 +217,7 @@ async def test_get_user_payments_transactions(
216217

217218

218219
async def test_get_user_payments_transactions_with_pagination_options(
219-
connection: SAConnection, create_fake_user_transactions: Callable, user_id: int
220+
connection: AsyncConnection, create_fake_user_transactions: Callable, user_id: int
220221
):
221222
expected_payments_ids = await create_fake_user_transactions()
222223
expected_total = len(expected_payments_ids)
@@ -244,3 +245,4 @@ async def test_get_user_payments_transactions_with_pagination_options(
244245
connection, user_id=user_id, limit=0
245246
)
246247
assert not rows
248+
assert not rows

services/web/server/src/simcore_service_webserver/payments/_onetime_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from simcore_postgres_database.utils_payments import insert_init_payment_transaction
2525
from yarl import URL
2626

27-
from ..db.plugin import get_database_engine_legacy
27+
from ..db.plugin import get_asyncpg_engine
2828
from ..products import products_service
2929
from ..resource_usage.service import add_credits_to_wallet
3030
from ..users import users_service
@@ -46,7 +46,7 @@
4646

4747

4848
def _to_api_model(
49-
transaction: _onetime_db.PaymentsTransactionsDB,
49+
transaction: _onetime_db.PaymentsTransactionsGetDB,
5050
) -> PaymentTransaction:
5151
data: dict[str, Any] = {
5252
"payment_id": transaction.payment_id,
@@ -90,7 +90,7 @@ async def _fake_init_payment(
9090
.with_query(id=payment_id)
9191
)
9292
# (2) Annotate INIT transaction
93-
async with get_database_engine_legacy(app).acquire() as conn:
93+
async with get_asyncpg_engine(app).begin() as conn:
9494
await insert_init_payment_transaction(
9595
conn,
9696
payment_id=payment_id,

services/web/server/src/simcore_service_webserver/payments/_onetime_db.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@
2020
get_user_payments_transactions,
2121
update_payment_transaction_state,
2222
)
23+
from simcore_postgres_database.utils_repos import (
24+
pass_or_acquire_connection,
25+
transaction_context,
26+
)
27+
from sqlalchemy.ext.asyncio import AsyncConnection
2328

24-
from ..db.plugin import get_database_engine_legacy
29+
from ..db.plugin import get_asyncpg_engine
2530
from .errors import PaymentCompletedError, PaymentNotFoundError
2631

2732
_logger = logging.getLogger(__name__)
@@ -30,7 +35,7 @@
3035
#
3136
# NOTE: this will be moved to the payments service
3237
# NOTE: with https://sqlmodel.tiangolo.com/ we would only define this once!
33-
class PaymentsTransactionsDB(BaseModel):
38+
class PaymentsTransactionsGetDB(BaseModel):
3439
payment_id: PaymentID
3540
price_dollars: Decimal # accepts negatives
3641
osparc_credits: Decimal # accepts negatives
@@ -48,43 +53,47 @@ class PaymentsTransactionsDB(BaseModel):
4853

4954

5055
async def list_user_payment_transactions(
51-
app,
56+
app: web.Application,
57+
connection: AsyncConnection | None = None,
5258
*,
5359
user_id: UserID,
5460
offset: PositiveInt,
5561
limit: PositiveInt,
56-
) -> tuple[int, list[PaymentsTransactionsDB]]:
62+
) -> tuple[int, list[PaymentsTransactionsGetDB]]:
5763
"""List payments done by a give user (any wallet)
5864
5965
Sorted by newest-first
6066
"""
61-
async with get_database_engine_legacy(app).acquire() as conn:
67+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
6268
total_number_of_items, rows = await get_user_payments_transactions(
6369
conn, user_id=user_id, offset=offset, limit=limit
6470
)
65-
page = TypeAdapter(list[PaymentsTransactionsDB]).validate_python(rows)
71+
page = TypeAdapter(list[PaymentsTransactionsGetDB]).validate_python(rows)
6672
return total_number_of_items, page
6773

6874

69-
async def get_pending_payment_transactions_ids(app: web.Application) -> list[PaymentID]:
70-
async with get_database_engine_legacy(app).acquire() as conn:
75+
async def get_pending_payment_transactions_ids(
76+
app: web.Application, connection: AsyncConnection | None = None
77+
) -> list[PaymentID]:
78+
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
7179
result = await conn.execute(
7280
sa.select(payments_transactions.c.payment_id)
7381
.where(payments_transactions.c.completed_at == None) # noqa: E711
7482
.order_by(payments_transactions.c.initiated_at.asc()) # oldest first
7583
)
76-
rows = await result.fetchall() or []
84+
rows = result.fetchall()
7785
return [TypeAdapter(PaymentID).validate_python(row.payment_id) for row in rows]
7886

7987

8088
async def complete_payment_transaction(
8189
app: web.Application,
90+
connection: AsyncConnection | None = None,
8291
*,
8392
payment_id: PaymentID,
8493
completion_state: PaymentTransactionState,
8594
state_message: str | None,
8695
invoice_url: HttpUrl | None = None,
87-
) -> PaymentsTransactionsDB:
96+
) -> PaymentsTransactionsGetDB:
8897
"""
8998
9099
Raises:
@@ -95,7 +104,7 @@ async def complete_payment_transaction(
95104
if invoice_url:
96105
optional_kwargs["invoice_url"] = invoice_url
97106

98-
async with get_database_engine_legacy(app).acquire() as conn:
107+
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
99108
row = await update_payment_transaction_state(
100109
conn,
101110
payment_id=payment_id,
@@ -111,4 +120,4 @@ async def complete_payment_transaction(
111120
raise PaymentCompletedError(payment_id=row.payment_id)
112121

113122
assert row # nosec
114-
return PaymentsTransactionsDB.model_validate(row)
123+
return PaymentsTransactionsGetDB.model_validate(row)

0 commit comments

Comments
 (0)