Skip to content

Commit ed2a0d5

Browse files
committed
first round
1 parent bf07906 commit ed2a0d5

File tree

5 files changed

+113
-81
lines changed

5 files changed

+113
-81
lines changed

services/web/server/src/simcore_service_webserver/garbage_collector/_tasks_users.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88
from collections.abc import AsyncIterator, Callable
99

1010
from aiohttp import web
11-
from aiopg.sa.engine import Engine
1211
from models_library.users import UserID
1312
from servicelib.logging_utils import get_log_record_extra, log_context
1413
from tenacity import retry
1514
from tenacity.before_sleep import before_sleep_log
1615
from tenacity.wait import wait_exponential
1716

18-
from ..db.plugin import get_database_engine
1917
from ..login.utils import notify_user_logout
2018
from ..security.api import clean_auth_policy_cache
2119
from ..users.api import update_expired_users
@@ -60,10 +58,8 @@ async def _update_expired_users(app: web.Application):
6058
"""
6159
It is resilient, i.e. if update goes wrong, it waits a bit and retries
6260
"""
63-
engine: Engine = get_database_engine(app)
64-
assert engine # nosec
6561

66-
if updated := await update_expired_users(engine):
62+
if updated := await update_expired_users(app):
6763
# expired users might be cached in the auth. If so, any request
6864
# with this user-id will get thru producing unexpected side-effects
6965
await clean_auth_policy_cache(app)

services/web/server/src/simcore_service_webserver/users/_api.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from simcore_postgres_database.models.users import UserStatus
1111

1212
from ..db.plugin import get_database_engine
13-
from . import _db, _schemas
14-
from ._db import get_user_or_raise
15-
from ._db import list_user_permissions as db_list_of_permissions
16-
from ._db import update_user_status
13+
from . import _schemas, _users_repository
14+
from ._users_repository import get_user_or_raise
15+
from ._users_repository import list_user_permissions as db_list_of_permissions
16+
from ._users_repository import update_user_status
1717
from .exceptions import AlreadyPreRegisteredError
1818
from .schemas import Permission
1919

@@ -73,13 +73,13 @@ async def search_users(
7373
app: web.Application, email_glob: str, *, include_products: bool = False
7474
) -> list[_schemas.UserProfile]:
7575
# NOTE: this search is deploy-wide i.e. independent of the product!
76-
rows = await _db.search_users_and_get_profile(
76+
rows = await _users_repository.search_users_and_get_profile(
7777
get_database_engine(app), email_like=_glob_to_sql_like(email_glob)
7878
)
7979

8080
async def _list_products_or_none(user_id):
8181
if user_id is not None and include_products:
82-
products = await _db.get_user_products(
82+
products = await _users_repository.get_user_products(
8383
get_database_engine(app), user_id=user_id
8484
)
8585
return [_.product_name for _ in products]
@@ -136,7 +136,7 @@ async def pre_register_user(
136136
if key in details:
137137
details[f"pre_{key}"] = details.pop(key)
138138

139-
await _db.new_user_details(
139+
await _users_repository.new_user_details(
140140
get_database_engine(app),
141141
email=profile.email,
142142
created_by=creator_user_id,
@@ -152,8 +152,10 @@ async def pre_register_user(
152152
async def get_user_invoice_address(
153153
app: web.Application, user_id: UserID
154154
) -> UserInvoiceAddress:
155-
user_billing_details: UserBillingDetails = await _db.get_user_billing_details(
156-
get_database_engine(app), user_id=user_id
155+
user_billing_details: UserBillingDetails = (
156+
await _users_repository.get_user_billing_details(
157+
get_database_engine(app), user_id=user_id
158+
)
157159
)
158160
_user_billing_country = pycountry.countries.lookup(user_billing_details.country)
159161
_user_billing_country_alpha_2_format = _user_billing_country.alpha_2

services/web/server/src/simcore_service_webserver/users/_db.py renamed to services/web/server/src/simcore_service_webserver/users/_users_repository.py

Lines changed: 86 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
import sqlalchemy as sa
44
from aiohttp import web
5-
from aiopg.sa.connection import SAConnection
6-
from aiopg.sa.engine import Engine
7-
from aiopg.sa.result import ResultProxy, RowProxy
85
from models_library.users import GroupID, UserBillingDetails, UserID
96
from simcore_postgres_database.models.groups import groups, user_to_groups
107
from simcore_postgres_database.models.products import products
@@ -16,59 +13,78 @@
1613
GroupExtraPropertiesNotFoundError,
1714
GroupExtraPropertiesRepo,
1815
)
16+
from simcore_postgres_database.utils_repos import (
17+
pass_or_acquire_connection,
18+
transaction_context,
19+
)
1920
from simcore_postgres_database.utils_users import UsersRepo
2021
from simcore_service_webserver.users.exceptions import UserNotFoundError
22+
from sqlalchemy.engine.row import Row
23+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
2124

2225
from ..db.models import user_to_groups
23-
from ..db.plugin import get_database_engine
26+
from ..db.plugin import get_asyncpg_engine
2427
from .exceptions import BillingDetailsNotFoundError
2528
from .schemas import Permission
2629

2730
_ALL = None
2831

2932

3033
async def get_user_or_raise(
31-
engine: Engine, *, user_id: UserID, return_column_names: list[str] | None = _ALL
32-
) -> RowProxy:
34+
engine: AsyncEngine,
35+
connection: AsyncConnection | None = None,
36+
*,
37+
user_id: UserID,
38+
return_column_names: list[str] | None = _ALL,
39+
) -> Row:
3340
if return_column_names == _ALL:
3441
return_column_names = list(users.columns.keys())
3542

3643
assert return_column_names is not None # nosec
3744
assert set(return_column_names).issubset(users.columns.keys()) # nosec
3845

39-
async with engine.acquire() as conn:
40-
row: RowProxy | None = await (
41-
await conn.execute(
42-
sa.select(*(users.columns[name] for name in return_column_names)).where(
43-
users.c.id == user_id
44-
)
46+
async with pass_or_acquire_connection(engine, connection) as conn:
47+
result = await conn.stream(
48+
sa.select(*(users.columns[name] for name in return_column_names)).where(
49+
users.c.id == user_id
4550
)
46-
).first()
51+
)
52+
row = await result.first()
4753
if row is None:
4854
raise UserNotFoundError(uid=user_id)
4955
return row
5056

5157

52-
async def get_users_ids_in_group(conn: SAConnection, gid: GroupID) -> set[UserID]:
53-
result: set[UserID] = set()
54-
query_result = await conn.execute(
55-
sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == gid)
56-
)
57-
async for entry in query_result:
58-
result.add(entry[0])
59-
return result
58+
async def get_users_ids_in_group(
59+
engine: AsyncEngine,
60+
connection: AsyncConnection | None = None,
61+
*,
62+
group_id: GroupID,
63+
) -> set[UserID]:
64+
async with pass_or_acquire_connection(engine, connection) as conn:
65+
result = await conn.stream(
66+
sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == group_id)
67+
)
68+
return {row.uid async for row in result}
6069

6170

6271
async def list_user_permissions(
63-
app: web.Application, *, user_id: UserID, product_name: str
72+
app: web.Application,
73+
connection: AsyncConnection | None = None,
74+
*,
75+
user_id: UserID,
76+
product_name: str,
6477
) -> list[Permission]:
6578
override_services_specifications = Permission(
6679
name="override_services_specifications",
6780
allowed=False,
6881
)
6982
with contextlib.suppress(GroupExtraPropertiesNotFoundError):
70-
async with get_database_engine(app).acquire() as conn:
83+
async with pass_or_acquire_connection(
84+
get_asyncpg_engine(app), connection
85+
) as conn:
7186
user_group_extra_properties = (
87+
# TODO: adapt to asyncpg
7288
await GroupExtraPropertiesRepo.get_aggregated_properties_for_user(
7389
conn, user_id=user_id, product_name=product_name
7490
)
@@ -80,34 +96,43 @@ async def list_user_permissions(
8096
return [override_services_specifications]
8197

8298

83-
async def do_update_expired_users(conn: SAConnection) -> list[UserID]:
84-
result: ResultProxy = await conn.execute(
85-
users.update()
86-
.values(status=UserStatus.EXPIRED)
87-
.where(
88-
(users.c.expires_at.is_not(None))
89-
& (users.c.status == UserStatus.ACTIVE)
90-
& (users.c.expires_at < sa.sql.func.now())
99+
async def do_update_expired_users(
100+
engine: AsyncEngine,
101+
connection: AsyncConnection | None = None,
102+
) -> list[UserID]:
103+
async with transaction_context(engine, connection) as conn:
104+
result = await conn.stream(
105+
users.update()
106+
.values(status=UserStatus.EXPIRED)
107+
.where(
108+
(users.c.expires_at.is_not(None))
109+
& (users.c.status == UserStatus.ACTIVE)
110+
& (users.c.expires_at < sa.sql.func.now())
111+
)
112+
.returning(users.c.id)
91113
)
92-
.returning(users.c.id)
93-
)
94-
if rows := await result.fetchall():
95-
return [r.id for r in rows]
96-
return []
114+
return [row.id async for row in result]
97115

98116

99117
async def update_user_status(
100-
engine: Engine, *, user_id: UserID, new_status: UserStatus
118+
engine: AsyncEngine,
119+
connection: AsyncConnection | None = None,
120+
*,
121+
user_id: UserID,
122+
new_status: UserStatus,
101123
):
102-
async with engine.acquire() as conn:
124+
async with transaction_context(engine, connection) as conn:
103125
await conn.execute(
104126
users.update().values(status=new_status).where(users.c.id == user_id)
105127
)
106128

107129

108130
async def search_users_and_get_profile(
109-
engine: Engine, *, email_like: str
110-
) -> list[RowProxy]:
131+
engine: AsyncEngine,
132+
connection: AsyncConnection | None = None,
133+
*,
134+
email_like: str,
135+
) -> list[Row]:
111136

112137
users_alias = sa.alias(users, name="users_alias")
113138

@@ -117,7 +142,7 @@ async def search_users_and_get_profile(
117142
.label("invited_by")
118143
)
119144

120-
async with engine.acquire() as conn:
145+
async with pass_or_acquire_connection(engine, connection) as conn:
121146
columns = (
122147
users.c.first_name,
123148
users.c.last_name,
@@ -159,12 +184,17 @@ async def search_users_and_get_profile(
159184
.where(users.c.email.like(email_like))
160185
)
161186

162-
result = await conn.execute(sa.union(left_outer_join, right_outer_join))
163-
return await result.fetchall() or []
187+
result = await conn.stream(sa.union(left_outer_join, right_outer_join))
188+
return [row async for row in result]
164189

165190

166-
async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
167-
async with engine.acquire() as conn:
191+
async def get_user_products(
192+
engine: AsyncEngine,
193+
connection: AsyncConnection | None = None,
194+
*,
195+
user_id: UserID,
196+
) -> list[Row]:
197+
async with pass_or_acquire_connection(engine, connection) as conn:
168198
product_name_subq = (
169199
sa.select(products.c.name)
170200
.where(products.c.group_id == groups.c.gid)
@@ -186,14 +216,19 @@ async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
186216
.where(users.c.id == user_id)
187217
.order_by(groups.c.gid)
188218
)
189-
result = await conn.execute(query)
190-
return await result.fetchall() or []
219+
result = await conn.stream(query)
220+
return [row async for row in result]
191221

192222

193223
async def new_user_details(
194-
engine: Engine, email: str, created_by: UserID, **other_values
224+
engine: AsyncEngine,
225+
connection: AsyncConnection | None = None,
226+
*,
227+
email: str,
228+
created_by: UserID,
229+
**other_values,
195230
) -> None:
196-
async with engine.acquire() as conn:
231+
async with transaction_context(engine, connection) as conn:
197232
await conn.execute(
198233
sa.insert(users_pre_registration_details).values(
199234
created_by=created_by, pre_email=email, **other_values
@@ -202,13 +237,13 @@ async def new_user_details(
202237

203238

204239
async def get_user_billing_details(
205-
engine: Engine, user_id: UserID
240+
engine: AsyncEngine, connection: AsyncConnection | None = None, *, user_id: UserID
206241
) -> UserBillingDetails:
207242
"""
208243
Raises:
209244
BillingDetailsNotFoundError
210245
"""
211-
async with engine.acquire() as conn:
246+
async with pass_or_acquire_connection(engine, connection) as conn:
212247
user_billing_details = await UsersRepo.get_billing_details(conn, user_id)
213248
if not user_billing_details:
214249
raise BillingDetailsNotFoundError(user_id=user_id)

services/web/server/src/simcore_service_webserver/users/api.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import simcore_postgres_database.errors as db_errors
1313
import sqlalchemy as sa
1414
from aiohttp import web
15-
from aiopg.sa.engine import Engine
1615
from aiopg.sa.result import RowProxy
1716
from models_library.api_schemas_webserver.users import (
1817
ProfileGet,
@@ -29,11 +28,11 @@
2928
GroupExtraPropertiesNotFoundError,
3029
)
3130

32-
from ..db.plugin import get_database_engine
31+
from ..db.plugin import get_asyncpg_engine, get_database_engine
3332
from ..groups.models import convert_groups_db_to_schema
3433
from ..login.storage import AsyncpgStorage, get_plugin_storage
3534
from ..security.api import clean_auth_policy_cache
36-
from . import _db
35+
from . import _users_repository
3736
from ._api import get_user_credentials, get_user_invoice_address, set_user_as_deleted
3837
from ._models import ToUserUpdateDB
3938
from ._preferences_api import get_frontend_user_preferences_aggregation
@@ -216,8 +215,8 @@ async def get_user_name_and_email(
216215
Returns:
217216
(user, email)
218217
"""
219-
row = await _db.get_user_or_raise(
220-
get_database_engine(app),
218+
row = await _users_repository.get_user_or_raise(
219+
get_asyncpg_engine(app),
221220
user_id=_parse_as_user(user_id),
222221
return_column_names=["name", "email"],
223222
)
@@ -242,8 +241,8 @@ async def get_user_display_and_id_names(
242241
Raises:
243242
UserNotFoundError
244243
"""
245-
row = await _db.get_user_or_raise(
246-
get_database_engine(app),
244+
row = await _users_repository.get_user_or_raise(
245+
get_asyncpg_engine(app),
247246
user_id=_parse_as_user(user_id),
248247
return_column_names=["name", "email", "first_name", "last_name"],
249248
)
@@ -318,7 +317,9 @@ async def get_user(app: web.Application, user_id: UserID) -> dict[str, Any]:
318317
"""
319318
:raises UserNotFoundError:
320319
"""
321-
row = await _db.get_user_or_raise(engine=get_database_engine(app), user_id=user_id)
320+
row = await _users_repository.get_user_or_raise(
321+
engine=get_asyncpg_engine(app), user_id=user_id
322+
)
322323
return dict(row)
323324

324325

@@ -332,14 +333,13 @@ async def get_user_id_from_gid(app: web.Application, primary_gid: int) -> UserID
332333

333334

334335
async def get_users_in_group(app: web.Application, gid: GroupID) -> set[UserID]:
335-
engine = get_database_engine(app)
336-
async with engine.acquire() as conn:
337-
return await _db.get_users_ids_in_group(conn, gid)
336+
return await _users_repository.get_users_ids_in_group(
337+
get_asyncpg_engine(app), group_id=gid
338+
)
338339

339340

340-
async def update_expired_users(engine: Engine) -> list[UserID]:
341-
async with engine.acquire() as conn:
342-
return await _db.do_update_expired_users(conn)
341+
async def update_expired_users(app: web.Application) -> list[UserID]:
342+
return await _users_repository.do_update_expired_users(get_asyncpg_engine(app))
343343

344344

345345
assert set_user_as_deleted # nosec

0 commit comments

Comments
 (0)