Skip to content

Commit e5c7458

Browse files
committed
refactor ruther
1 parent 4b037df commit e5c7458

File tree

5 files changed

+73
-68
lines changed

5 files changed

+73
-68
lines changed

packages/postgres-database/src/simcore_postgres_database/utils_users.py

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@
99
from typing import Any, Final
1010

1111
import sqlalchemy as sa
12-
from common_library.async_tools import maybe_await
1312
from sqlalchemy import Column
1413
from sqlalchemy.exc import IntegrityError
1514
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
1615

17-
from ._protocols import DBConnection
1816
from .models.users import UserRole, UserStatus, users
1917
from .models.users_details import users_pre_registration_details
2018
from .utils_repos import pass_or_acquire_connection, transaction_context
@@ -183,64 +181,76 @@ def get_billing_details_query(user_id: int):
183181
.where(users.c.id == user_id)
184182
)
185183

186-
@staticmethod
187-
async def get_billing_details(conn: DBConnection, user_id: int) -> Any | None:
188-
result = await conn.execute(
189-
UsersRepo.get_billing_details_query(user_id=user_id)
190-
)
191-
return await maybe_await(result.fetchone())
184+
async def get_billing_details(
185+
self, connection: AsyncConnection | None = None, *, user_id: int
186+
) -> Any | None:
187+
async with pass_or_acquire_connection(self._engine, connection) as conn:
188+
result = await conn.execute(self.get_billing_details_query(user_id=user_id))
189+
return result.one_or_none()
192190

193-
@staticmethod
194-
async def get_role(conn: DBConnection, user_id: int) -> UserRole:
195-
value: UserRole | None = await conn.scalar(
196-
sa.select(users.c.role).where(users.c.id == user_id)
197-
)
198-
if value:
199-
assert isinstance(value, UserRole) # nosec
200-
return UserRole(value)
191+
async def get_role(
192+
self, connection: AsyncConnection | None = None, *, user_id: int
193+
) -> UserRole:
194+
async with pass_or_acquire_connection(self._engine, connection) as conn:
201195

202-
raise UserNotFoundInRepoError
196+
value: UserRole | None = await conn.scalar(
197+
sa.select(users.c.role).where(users.c.id == user_id)
198+
)
199+
if value:
200+
assert isinstance(value, UserRole) # nosec
201+
return UserRole(value)
203202

204-
@staticmethod
205-
async def get_email(conn: DBConnection, user_id: int) -> str:
206-
value: str | None = await conn.scalar(
207-
sa.select(users.c.email).where(users.c.id == user_id)
208-
)
209-
if value:
210-
assert isinstance(value, str) # nosec
211-
return value
203+
raise UserNotFoundInRepoError
212204

213-
raise UserNotFoundInRepoError
205+
async def get_email(
206+
self, connection: AsyncConnection | None = None, *, user_id: int
207+
) -> str:
208+
async with pass_or_acquire_connection(self._engine, connection) as conn:
214209

215-
@staticmethod
216-
async def get_active_user_email(conn: DBConnection, user_id: int) -> str:
217-
value: str | None = await conn.scalar(
218-
sa.select(users.c.email).where(
219-
(users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id)
210+
value: str | None = await conn.scalar(
211+
sa.select(users.c.email).where(users.c.id == user_id)
220212
)
221-
)
222-
if value is not None:
223-
assert isinstance(value, str) # nosec
224-
return value
213+
if value:
214+
assert isinstance(value, str) # nosec
215+
return value
225216

226-
raise UserNotFoundInRepoError
217+
raise UserNotFoundInRepoError
227218

228-
@staticmethod
229-
async def is_email_used(conn: DBConnection, email: str) -> bool:
230-
email = email.lower()
219+
async def get_active_user_email(
220+
self, connection: AsyncConnection | None = None, *, user_id: int
221+
) -> str:
222+
async with pass_or_acquire_connection(self._engine, connection) as conn:
223+
value: str | None = await conn.scalar(
224+
sa.select(users.c.email).where(
225+
(users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id)
226+
)
227+
)
228+
if value is not None:
229+
assert isinstance(value, str) # nosec
230+
return value
231231

232-
registered = await conn.scalar(
233-
sa.select(users.c.id).where(users.c.email == email)
234-
)
235-
if registered:
236-
return True
232+
raise UserNotFoundInRepoError
233+
234+
async def is_email_used(
235+
self, connection: AsyncConnection | None = None, *, email: str
236+
) -> bool:
237+
238+
async with pass_or_acquire_connection(self._engine, connection) as conn:
239+
240+
email = email.lower()
237241

238-
pre_registered = await conn.scalar(
239-
sa.select(users_pre_registration_details.c.user_id).where(
240-
users_pre_registration_details.c.pre_email == email
242+
registered = await conn.scalar(
243+
sa.select(users.c.id).where(users.c.email == email)
241244
)
242-
)
243-
return bool(pre_registered)
245+
if registered:
246+
return True
247+
248+
pre_registered = await conn.scalar(
249+
sa.select(users_pre_registration_details.c.user_id).where(
250+
users_pre_registration_details.c.pre_email == email
251+
)
252+
)
253+
return bool(pre_registered)
244254

245255

246256
#

packages/postgres-database/tests/test_users.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,12 @@ async def test_new_user(
159159
assert other_user.name != new_user.name
160160

161161
async with pass_or_acquire_connection(asyncpg_engine) as connection:
162-
assert await UsersRepo.get_email(connection, other_user.id) == other_user.email
163-
assert await UsersRepo.get_role(connection, other_user.id) == other_user.role
164162
assert (
165-
await UsersRepo.get_active_user_email(connection, other_user.id)
163+
await repo.get_email(connection, user_id=other_user.id) == other_user.email
164+
)
165+
assert await repo.get_role(connection, user_id=other_user.id) == other_user.role
166+
assert (
167+
await repo.get_active_user_email(connection, user_id=other_user.id)
166168
== other_user.email
167169
)
168170

packages/postgres-database/tests/test_users_details.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,8 @@ async def test_get_billing_details_from_pre_registration(
309309
)
310310

311311
# Get billing details
312-
async with pass_or_acquire_connection(asyncpg_engine) as connection:
313-
invoice_data = await UsersRepo.get_billing_details(
314-
connection, user_id=new_user.id
315-
)
316-
assert invoice_data is not None
312+
invoice_data = await repo.get_billing_details(user_id=new_user.id)
313+
assert invoice_data is not None
317314

318315
# Test UserAddress model conversion
319316
user_address = UserAddress.create_from_db(invoice_data)

services/web/server/src/simcore_service_webserver/login/_controller/rest/change.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from servicelib.aiohttp.requests_validation import parse_request_body_as
66
from servicelib.logging_errors import create_troubleshootting_log_kwargs
77
from servicelib.request_keys import RQT_USERID_KEY
8-
from simcore_postgres_database.utils_repos import pass_or_acquire_connection
98
from simcore_postgres_database.utils_users import UsersRepo
109

1110
from ...._meta import API_VTAG
@@ -228,9 +227,9 @@ async def initiate_change_email(request: web.Request):
228227
if user["email"] == request_body.email:
229228
return flash_response("Email changed")
230229

231-
async with pass_or_acquire_connection(get_asyncpg_engine(request.app)) as conn:
232-
if await UsersRepo.is_email_used(conn, email=request_body.email):
233-
raise web.HTTPUnprocessableEntity(text="This email cannot be used")
230+
repo = UsersRepo(get_asyncpg_engine(request.app))
231+
if await repo.is_email_used(email=request_body.email):
232+
raise web.HTTPUnprocessableEntity(text="This email cannot be used")
234233

235234
# Reset if previously requested
236235
confirmation = await db.get_confirmation({"user": user, "action": CHANGE_EMAIL})

services/web/server/src/simcore_service_webserver/users/_users_repository.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -372,13 +372,10 @@ async def get_user_billing_details(
372372
Raises:
373373
BillingDetailsNotFoundError
374374
"""
375-
async with pass_or_acquire_connection(engine, connection) as conn:
376-
query = UsersRepo.get_billing_details_query(user_id=user_id)
377-
result = await conn.execute(query)
378-
row = result.first()
379-
if not row:
380-
raise BillingDetailsNotFoundError(user_id=user_id)
381-
return UserBillingDetails.model_validate(row)
375+
row = await UsersRepo(engine).get_billing_details(connection, user_id=user_id)
376+
if not row:
377+
raise BillingDetailsNotFoundError(user_id=user_id)
378+
return UserBillingDetails.model_validate(row)
382379

383380

384381
async def delete_user_by_id(

0 commit comments

Comments
 (0)