Skip to content

Commit 51615a1

Browse files
committed
refactor pg utils
1 parent c0042de commit 51615a1

File tree

3 files changed

+38
-32
lines changed

3 files changed

+38
-32
lines changed

packages/postgres-database/src/simcore_postgres_database/utils_products_prices.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import NamedTuple, TypeAlias
33

44
import sqlalchemy as sa
5-
from aiopg.sa.connection import SAConnection
5+
from sqlalchemy.ext.asyncio import AsyncConnection
66

77
from .constants import QUANTIZE_EXP_ARG
88
from .models.products_prices import products_prices
@@ -17,9 +17,9 @@ class ProductPriceInfo(NamedTuple):
1717

1818

1919
async def get_product_latest_price_info_or_none(
20-
conn: SAConnection, product_name: str
20+
conn: AsyncConnection, product_name: str
2121
) -> ProductPriceInfo | None:
22-
"""None menans the product is not billable"""
22+
"""If the product is not billable, it returns None"""
2323
# newest price of a product
2424
result = await conn.execute(
2525
sa.select(
@@ -30,7 +30,7 @@ async def get_product_latest_price_info_or_none(
3030
.order_by(sa.desc(products_prices.c.valid_from))
3131
.limit(1)
3232
)
33-
row = await result.first()
33+
row = result.one_or_none()
3434

3535
if row and row.usd_per_credit is not None:
3636
assert row.min_payment_amount_usd is not None # nosec
@@ -44,26 +44,26 @@ async def get_product_latest_price_info_or_none(
4444

4545

4646
async def get_product_latest_stripe_info(
47-
conn: SAConnection, product_name: str
47+
conn: AsyncConnection, product_name: str
4848
) -> tuple[StripePriceID, StripeTaxRateID]:
4949
# Stripe info of a product for latest price
50-
row = await (
51-
await conn.execute(
52-
sa.select(
53-
products_prices.c.stripe_price_id,
54-
products_prices.c.stripe_tax_rate_id,
55-
)
56-
.where(products_prices.c.product_name == product_name)
57-
.order_by(sa.desc(products_prices.c.valid_from))
58-
.limit(1)
50+
result = await conn.execute(
51+
sa.select(
52+
products_prices.c.stripe_price_id,
53+
products_prices.c.stripe_tax_rate_id,
5954
)
60-
).fetchone()
55+
.where(products_prices.c.product_name == product_name)
56+
.order_by(sa.desc(products_prices.c.valid_from))
57+
.limit(1)
58+
)
59+
60+
row = result.one_or_none()
6161
if row is None:
6262
msg = f"Required Stripe information missing from product {product_name=}"
6363
raise ValueError(msg)
6464
return (row.stripe_price_id, row.stripe_tax_rate_id)
6565

6666

67-
async def is_payment_enabled(conn: SAConnection, product_name: str) -> bool:
67+
async def is_payment_enabled(conn: AsyncConnection, product_name: str) -> bool:
6868
p = await get_product_latest_price_info_or_none(conn, product_name=product_name)
6969
return bool(p) # zero or None is disabled

packages/postgres-database/tests/test_models_products_prices.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
# pylint: disable=too-many-arguments
55

66

7+
from collections.abc import AsyncIterator
8+
79
import pytest
810
import sqlalchemy as sa
9-
from aiopg.sa.connection import SAConnection
10-
from aiopg.sa.result import RowProxy
1111
from faker import Faker
1212
from pytest_simcore.helpers.faker_factories import random_product
1313
from simcore_postgres_database.errors import CheckViolation, ForeignKeyViolation
@@ -18,23 +18,30 @@
1818
get_product_latest_stripe_info,
1919
is_payment_enabled,
2020
)
21+
from sqlalchemy.engine.row import Row
22+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
2123

2224

2325
@pytest.fixture
24-
async def fake_product(connection: SAConnection) -> RowProxy:
26+
async def connection(asyncpg_engine: AsyncEngine) -> AsyncIterator[AsyncConnection]:
27+
async with asyncpg_engine.begin() as conn:
28+
yield conn
29+
30+
31+
@pytest.fixture
32+
async def fake_product(connection: AsyncConnection) -> Row:
2533
result = await connection.execute(
2634
products.insert()
2735
.values(random_product(group_id=None))
2836
.returning(sa.literal_column("*"))
2937
)
30-
product = await result.first()
31-
assert product is not None
32-
return product
38+
return result.one()
3339

3440

3541
async def test_creating_product_prices(
36-
connection: SAConnection, fake_product: RowProxy, faker: Faker
42+
connection: AsyncConnection, fake_product: Row, faker: Faker
3743
):
44+
3845
# a price per product
3946
result = await connection.execute(
4047
products_prices.insert()
@@ -47,12 +54,12 @@ async def test_creating_product_prices(
4754
)
4855
.returning(sa.literal_column("*"))
4956
)
50-
product_prices = await result.first()
57+
product_prices = result.one()
5158
assert product_prices
5259

5360

5461
async def test_non_negative_price_not_allowed(
55-
connection: SAConnection, fake_product: RowProxy, faker: Faker
62+
connection: AsyncConnection, fake_product: Row, faker: Faker
5663
):
5764
# negative price not allowed
5865
with pytest.raises(CheckViolation) as exc_info:
@@ -81,7 +88,7 @@ async def test_non_negative_price_not_allowed(
8188

8289

8390
async def test_delete_price_constraints(
84-
connection: SAConnection, fake_product: RowProxy, faker: Faker
91+
connection: AsyncConnection, fake_product: Row, faker: Faker
8592
):
8693
# products_prices
8794
await connection.execute(
@@ -106,7 +113,7 @@ async def test_delete_price_constraints(
106113

107114

108115
async def test_get_product_latest_price_or_none(
109-
connection: SAConnection, fake_product: RowProxy, faker: Faker
116+
connection: AsyncConnection, fake_product: Row, faker: Faker
110117
):
111118
# undefined product
112119
assert (
@@ -130,7 +137,7 @@ async def test_get_product_latest_price_or_none(
130137

131138

132139
async def test_price_history_of_a_product(
133-
connection: SAConnection, fake_product: RowProxy, faker: Faker
140+
connection: AsyncConnection, fake_product: Row, faker: Faker
134141
):
135142
# initial price
136143
await connection.execute(
@@ -163,7 +170,7 @@ async def test_price_history_of_a_product(
163170

164171

165172
async def test_get_product_latest_stripe_info(
166-
connection: SAConnection, fake_product: RowProxy, faker: Faker
173+
connection: AsyncConnection, fake_product: Row, faker: Faker
167174
):
168175
stripe_price_id_value = faker.word()
169176
stripe_tax_rate_id_value = faker.word()
@@ -187,5 +194,5 @@ async def test_get_product_latest_stripe_info(
187194
assert product_stripe_info[1] == stripe_tax_rate_id_value
188195

189196
# undefined product
190-
with pytest.raises(ValueError) as exc_info:
197+
with pytest.raises(ValueError, match="undefined"):
191198
await get_product_latest_stripe_info(connection, product_name="undefined")

services/web/server/src/simcore_service_webserver/products/_repository.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any
44

55
import sqlalchemy as sa
6-
from aiopg.sa.connection import SAConnection
76
from models_library.groups import GroupID
87
from models_library.products import ProductName, ProductStripeInfoGet
98
from simcore_postgres_database.constants import QUANTIZE_EXP_ARG
@@ -57,7 +56,7 @@
5756

5857

5958
async def get_product_payment_fields(
60-
conn: SAConnection, product_name: ProductName
59+
conn: AsyncConnection, product_name: ProductName
6160
) -> PaymentFieldsTuple:
6261
price_info = await get_product_latest_price_info_or_none(
6362
conn, product_name=product_name

0 commit comments

Comments
 (0)