Skip to content

Commit 497dc84

Browse files
committed
first round
1 parent e6fc642 commit 497dc84

File tree

5 files changed

+114
-82
lines changed

5 files changed

+114
-82
lines changed

services/web/server/src/simcore_service_webserver/garbage_collector/_tasks_users.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88
from collections.abc import AsyncIterator, Callable
99

1010
from aiohttp import web
11-
from aiopg.sa.engine import Engine
1211
from models_library.users import UserID
1312
from servicelib.logging_utils import get_log_record_extra, log_context
1413
from tenacity import retry
1514
from tenacity.before_sleep import before_sleep_log
1615
from tenacity.wait import wait_exponential
1716

18-
from ..db.plugin import get_database_engine
1917
from ..login.utils import notify_user_logout
2018
from ..security.api import clean_auth_policy_cache
2119
from ..users.api import update_expired_users
@@ -60,10 +58,8 @@ async def _update_expired_users(app: web.Application):
6058
"""
6159
It is resilient, i.e. if update goes wrong, it waits a bit and retries
6260
"""
63-
engine: Engine = get_database_engine(app)
64-
assert engine # nosec
6561

66-
if updated := await update_expired_users(engine):
62+
if updated := await update_expired_users(app):
6763
# expired users might be cached in the auth. If so, any request
6864
# with this user-id will get thru producing unexpected side-effects
6965
await clean_auth_policy_cache(app)

services/web/server/src/simcore_service_webserver/users/_api.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from simcore_postgres_database.models.users import UserStatus
1111

1212
from ..db.plugin import get_database_engine
13-
from . import _db, _schemas
14-
from ._db import get_user_or_raise
15-
from ._db import list_user_permissions as db_list_of_permissions
16-
from ._db import update_user_status
13+
from . import _schemas, _users_repository
14+
from ._users_repository import get_user_or_raise
15+
from ._users_repository import list_user_permissions as db_list_of_permissions
16+
from ._users_repository import update_user_status
1717
from .exceptions import AlreadyPreRegisteredError
1818
from .schemas import Permission
1919

@@ -73,13 +73,13 @@ async def search_users(
7373
app: web.Application, email_glob: str, *, include_products: bool = False
7474
) -> list[_schemas.UserProfile]:
7575
# NOTE: this search is deploy-wide i.e. independent of the product!
76-
rows = await _db.search_users_and_get_profile(
76+
rows = await _users_repository.search_users_and_get_profile(
7777
get_database_engine(app), email_like=_glob_to_sql_like(email_glob)
7878
)
7979

8080
async def _list_products_or_none(user_id):
8181
if user_id is not None and include_products:
82-
products = await _db.get_user_products(
82+
products = await _users_repository.get_user_products(
8383
get_database_engine(app), user_id=user_id
8484
)
8585
return [_.product_name for _ in products]
@@ -136,7 +136,7 @@ async def pre_register_user(
136136
if key in details:
137137
details[f"pre_{key}"] = details.pop(key)
138138

139-
await _db.new_user_details(
139+
await _users_repository.new_user_details(
140140
get_database_engine(app),
141141
email=profile.email,
142142
created_by=creator_user_id,
@@ -152,8 +152,10 @@ async def pre_register_user(
152152
async def get_user_invoice_address(
153153
app: web.Application, user_id: UserID
154154
) -> UserInvoiceAddress:
155-
user_billing_details: UserBillingDetails = await _db.get_user_billing_details(
156-
get_database_engine(app), user_id=user_id
155+
user_billing_details: UserBillingDetails = (
156+
await _users_repository.get_user_billing_details(
157+
get_database_engine(app), user_id=user_id
158+
)
157159
)
158160
_user_billing_country = pycountry.countries.lookup(user_billing_details.country)
159161
_user_billing_country_alpha_2_format = _user_billing_country.alpha_2

services/web/server/src/simcore_service_webserver/users/_db.py renamed to services/web/server/src/simcore_service_webserver/users/_users_repository.py

Lines changed: 87 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,7 @@
22

33
import sqlalchemy as sa
44
from aiohttp import web
5-
from aiopg.sa.connection import SAConnection
6-
from aiopg.sa.engine import Engine
7-
from aiopg.sa.result import ResultProxy, RowProxy
8-
from models_library.groups import GroupID
9-
from models_library.users import UserBillingDetails, UserID
5+
from models_library.users import GroupID, UserBillingDetails, UserID
106
from simcore_postgres_database.models.groups import groups, user_to_groups
117
from simcore_postgres_database.models.products import products
128
from simcore_postgres_database.models.users import UserStatus, users
@@ -17,59 +13,78 @@
1713
GroupExtraPropertiesNotFoundError,
1814
GroupExtraPropertiesRepo,
1915
)
16+
from simcore_postgres_database.utils_repos import (
17+
pass_or_acquire_connection,
18+
transaction_context,
19+
)
2020
from simcore_postgres_database.utils_users import UsersRepo
2121
from simcore_service_webserver.users.exceptions import UserNotFoundError
22+
from sqlalchemy.engine.row import Row
23+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
2224

2325
from ..db.models import user_to_groups
24-
from ..db.plugin import get_database_engine
26+
from ..db.plugin import get_asyncpg_engine
2527
from .exceptions import BillingDetailsNotFoundError
2628
from .schemas import Permission
2729

2830
_ALL = None
2931

3032

3133
async def get_user_or_raise(
32-
engine: Engine, *, user_id: UserID, return_column_names: list[str] | None = _ALL
33-
) -> RowProxy:
34+
engine: AsyncEngine,
35+
connection: AsyncConnection | None = None,
36+
*,
37+
user_id: UserID,
38+
return_column_names: list[str] | None = _ALL,
39+
) -> Row:
3440
if return_column_names == _ALL:
3541
return_column_names = list(users.columns.keys())
3642

3743
assert return_column_names is not None # nosec
3844
assert set(return_column_names).issubset(users.columns.keys()) # nosec
3945

40-
async with engine.acquire() as conn:
41-
row: RowProxy | None = await (
42-
await conn.execute(
43-
sa.select(*(users.columns[name] for name in return_column_names)).where(
44-
users.c.id == user_id
45-
)
46+
async with pass_or_acquire_connection(engine, connection) as conn:
47+
result = await conn.stream(
48+
sa.select(*(users.columns[name] for name in return_column_names)).where(
49+
users.c.id == user_id
4650
)
47-
).first()
51+
)
52+
row = await result.first()
4853
if row is None:
4954
raise UserNotFoundError(uid=user_id)
5055
return row
5156

5257

53-
async def get_users_ids_in_group(conn: SAConnection, gid: GroupID) -> set[UserID]:
54-
result: set[UserID] = set()
55-
query_result = await conn.execute(
56-
sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == gid)
57-
)
58-
async for entry in query_result:
59-
result.add(entry[0])
60-
return result
58+
async def get_users_ids_in_group(
59+
engine: AsyncEngine,
60+
connection: AsyncConnection | None = None,
61+
*,
62+
group_id: GroupID,
63+
) -> set[UserID]:
64+
async with pass_or_acquire_connection(engine, connection) as conn:
65+
result = await conn.stream(
66+
sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == group_id)
67+
)
68+
return {row.uid async for row in result}
6169

6270

6371
async def list_user_permissions(
64-
app: web.Application, *, user_id: UserID, product_name: str
72+
app: web.Application,
73+
connection: AsyncConnection | None = None,
74+
*,
75+
user_id: UserID,
76+
product_name: str,
6577
) -> list[Permission]:
6678
override_services_specifications = Permission(
6779
name="override_services_specifications",
6880
allowed=False,
6981
)
7082
with contextlib.suppress(GroupExtraPropertiesNotFoundError):
71-
async with get_database_engine(app).acquire() as conn:
83+
async with pass_or_acquire_connection(
84+
get_asyncpg_engine(app), connection
85+
) as conn:
7286
user_group_extra_properties = (
87+
# TODO: adapt to asyncpg
7388
await GroupExtraPropertiesRepo.get_aggregated_properties_for_user(
7489
conn, user_id=user_id, product_name=product_name
7590
)
@@ -81,34 +96,43 @@ async def list_user_permissions(
8196
return [override_services_specifications]
8297

8398

84-
async def do_update_expired_users(conn: SAConnection) -> list[UserID]:
85-
result: ResultProxy = await conn.execute(
86-
users.update()
87-
.values(status=UserStatus.EXPIRED)
88-
.where(
89-
(users.c.expires_at.is_not(None))
90-
& (users.c.status == UserStatus.ACTIVE)
91-
& (users.c.expires_at < sa.sql.func.now())
99+
async def do_update_expired_users(
100+
engine: AsyncEngine,
101+
connection: AsyncConnection | None = None,
102+
) -> list[UserID]:
103+
async with transaction_context(engine, connection) as conn:
104+
result = await conn.stream(
105+
users.update()
106+
.values(status=UserStatus.EXPIRED)
107+
.where(
108+
(users.c.expires_at.is_not(None))
109+
& (users.c.status == UserStatus.ACTIVE)
110+
& (users.c.expires_at < sa.sql.func.now())
111+
)
112+
.returning(users.c.id)
92113
)
93-
.returning(users.c.id)
94-
)
95-
if rows := await result.fetchall():
96-
return [r.id for r in rows]
97-
return []
114+
return [row.id async for row in result]
98115

99116

100117
async def update_user_status(
101-
engine: Engine, *, user_id: UserID, new_status: UserStatus
118+
engine: AsyncEngine,
119+
connection: AsyncConnection | None = None,
120+
*,
121+
user_id: UserID,
122+
new_status: UserStatus,
102123
):
103-
async with engine.acquire() as conn:
124+
async with transaction_context(engine, connection) as conn:
104125
await conn.execute(
105126
users.update().values(status=new_status).where(users.c.id == user_id)
106127
)
107128

108129

109130
async def search_users_and_get_profile(
110-
engine: Engine, *, email_like: str
111-
) -> list[RowProxy]:
131+
engine: AsyncEngine,
132+
connection: AsyncConnection | None = None,
133+
*,
134+
email_like: str,
135+
) -> list[Row]:
112136

113137
users_alias = sa.alias(users, name="users_alias")
114138

@@ -118,7 +142,7 @@ async def search_users_and_get_profile(
118142
.label("invited_by")
119143
)
120144

121-
async with engine.acquire() as conn:
145+
async with pass_or_acquire_connection(engine, connection) as conn:
122146
columns = (
123147
users.c.first_name,
124148
users.c.last_name,
@@ -160,12 +184,17 @@ async def search_users_and_get_profile(
160184
.where(users.c.email.like(email_like))
161185
)
162186

163-
result = await conn.execute(sa.union(left_outer_join, right_outer_join))
164-
return await result.fetchall() or []
187+
result = await conn.stream(sa.union(left_outer_join, right_outer_join))
188+
return [row async for row in result]
165189

166190

167-
async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
168-
async with engine.acquire() as conn:
191+
async def get_user_products(
192+
engine: AsyncEngine,
193+
connection: AsyncConnection | None = None,
194+
*,
195+
user_id: UserID,
196+
) -> list[Row]:
197+
async with pass_or_acquire_connection(engine, connection) as conn:
169198
product_name_subq = (
170199
sa.select(products.c.name)
171200
.where(products.c.group_id == groups.c.gid)
@@ -187,14 +216,19 @@ async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
187216
.where(users.c.id == user_id)
188217
.order_by(groups.c.gid)
189218
)
190-
result = await conn.execute(query)
191-
return await result.fetchall() or []
219+
result = await conn.stream(query)
220+
return [row async for row in result]
192221

193222

194223
async def new_user_details(
195-
engine: Engine, email: str, created_by: UserID, **other_values
224+
engine: AsyncEngine,
225+
connection: AsyncConnection | None = None,
226+
*,
227+
email: str,
228+
created_by: UserID,
229+
**other_values,
196230
) -> None:
197-
async with engine.acquire() as conn:
231+
async with transaction_context(engine, connection) as conn:
198232
await conn.execute(
199233
sa.insert(users_pre_registration_details).values(
200234
created_by=created_by, pre_email=email, **other_values
@@ -203,13 +237,13 @@ async def new_user_details(
203237

204238

205239
async def get_user_billing_details(
206-
engine: Engine, user_id: UserID
240+
engine: AsyncEngine, connection: AsyncConnection | None = None, *, user_id: UserID
207241
) -> UserBillingDetails:
208242
"""
209243
Raises:
210244
BillingDetailsNotFoundError
211245
"""
212-
async with engine.acquire() as conn:
246+
async with pass_or_acquire_connection(engine, connection) as conn:
213247
user_billing_details = await UsersRepo.get_billing_details(conn, user_id)
214248
if not user_billing_details:
215249
raise BillingDetailsNotFoundError(user_id=user_id)

services/web/server/src/simcore_service_webserver/users/api.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import simcore_postgres_database.errors as db_errors
1313
import sqlalchemy as sa
1414
from aiohttp import web
15-
from aiopg.sa.engine import Engine
1615
from aiopg.sa.result import RowProxy
1716
from models_library.api_schemas_webserver.users import (
1817
MyProfileGet,
@@ -32,9 +31,10 @@
3231
from simcore_postgres_database.utils_users import generate_alternative_username
3332

3433
from ..db.plugin import get_database_engine
34+
from ..db.plugin import get_asyncpg_engine, get_database_engine
3535
from ..login.storage import AsyncpgStorage, get_plugin_storage
3636
from ..security.api import clean_auth_policy_cache
37-
from . import _db
37+
from . import _users_repository
3838
from ._api import get_user_credentials, get_user_invoice_address, set_user_as_deleted
3939
from ._models import ToUserUpdateDB
4040
from ._preferences_api import get_frontend_user_preferences_aggregation
@@ -245,8 +245,8 @@ async def get_user_name_and_email(
245245
Returns:
246246
(user, email)
247247
"""
248-
row = await _db.get_user_or_raise(
249-
get_database_engine(app),
248+
row = await _users_repository.get_user_or_raise(
249+
get_asyncpg_engine(app),
250250
user_id=_parse_as_user(user_id),
251251
return_column_names=["name", "email"],
252252
)
@@ -271,8 +271,8 @@ async def get_user_display_and_id_names(
271271
Raises:
272272
UserNotFoundError
273273
"""
274-
row = await _db.get_user_or_raise(
275-
get_database_engine(app),
274+
row = await _users_repository.get_user_or_raise(
275+
get_asyncpg_engine(app),
276276
user_id=_parse_as_user(user_id),
277277
return_column_names=["name", "email", "first_name", "last_name"],
278278
)
@@ -347,7 +347,9 @@ async def get_user(app: web.Application, user_id: UserID) -> dict[str, Any]:
347347
"""
348348
:raises UserNotFoundError:
349349
"""
350-
row = await _db.get_user_or_raise(engine=get_database_engine(app), user_id=user_id)
350+
row = await _users_repository.get_user_or_raise(
351+
engine=get_asyncpg_engine(app), user_id=user_id
352+
)
351353
return dict(row)
352354

353355

@@ -361,14 +363,13 @@ async def get_user_id_from_gid(app: web.Application, primary_gid: int) -> UserID
361363

362364

363365
async def get_users_in_group(app: web.Application, gid: GroupID) -> set[UserID]:
364-
engine = get_database_engine(app)
365-
async with engine.acquire() as conn:
366-
return await _db.get_users_ids_in_group(conn, gid)
366+
return await _users_repository.get_users_ids_in_group(
367+
get_asyncpg_engine(app), group_id=gid
368+
)
367369

368370

369-
async def update_expired_users(engine: Engine) -> list[UserID]:
370-
async with engine.acquire() as conn:
371-
return await _db.do_update_expired_users(conn)
371+
async def update_expired_users(app: web.Application) -> list[UserID]:
372+
return await _users_repository.do_update_expired_users(get_asyncpg_engine(app))
372373

373374

374375
assert set_user_as_deleted # nosec

0 commit comments

Comments
 (0)