Skip to content

Commit 796ceea

Browse files
committed
refactor ruther
1 parent e5c7458 commit 796ceea

File tree

1 file changed

+33
-31
lines changed
  • packages/postgres-database/src/simcore_postgres_database

1 file changed

+33
-31
lines changed

packages/postgres-database/src/simcore_postgres_database/utils_users.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlalchemy import Column
1313
from sqlalchemy.exc import IntegrityError
1414
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
15+
from sqlalchemy.sql import Select
1516

1617
from .models.users import UserRole, UserStatus, users
1718
from .models.users_details import users_pre_registration_details
@@ -56,6 +57,18 @@ class UsersRepo:
5657
def __init__(self, engine: AsyncEngine):
5758
self._engine = engine
5859

60+
async def _get_scalar_or_raise(
61+
self,
62+
query: Select,
63+
connection: AsyncConnection | None = None,
64+
) -> Any:
65+
"""Execute a scalar query and raise UserNotFoundInRepoError if no value found."""
66+
async with pass_or_acquire_connection(self._engine, connection) as conn:
67+
value = await conn.scalar(query)
68+
if value is not None:
69+
return value
70+
raise UserNotFoundInRepoError
71+
5972
async def new_user(
6073
self,
6174
connection: AsyncConnection | None = None,
@@ -191,45 +204,34 @@ async def get_billing_details(
191204
async def get_role(
192205
self, connection: AsyncConnection | None = None, *, user_id: int
193206
) -> UserRole:
194-
async with pass_or_acquire_connection(self._engine, connection) as conn:
195-
196-
value: UserRole | None = await conn.scalar(
197-
sa.select(users.c.role).where(users.c.id == user_id)
198-
)
199-
if value:
200-
assert isinstance(value, UserRole) # nosec
201-
return UserRole(value)
202-
203-
raise UserNotFoundInRepoError
207+
value = await self._get_scalar_or_raise(
208+
sa.select(users.c.role).where(users.c.id == user_id),
209+
connection=connection,
210+
)
211+
assert isinstance(value, UserRole) # nosec
212+
return UserRole(value)
204213

205214
async def get_email(
206215
self, connection: AsyncConnection | None = None, *, user_id: int
207216
) -> str:
208-
async with pass_or_acquire_connection(self._engine, connection) as conn:
209-
210-
value: str | None = await conn.scalar(
211-
sa.select(users.c.email).where(users.c.id == user_id)
212-
)
213-
if value:
214-
assert isinstance(value, str) # nosec
215-
return value
216-
217-
raise UserNotFoundInRepoError
217+
value = await self._get_scalar_or_raise(
218+
sa.select(users.c.email).where(users.c.id == user_id),
219+
connection=connection,
220+
)
221+
assert isinstance(value, str) # nosec
222+
return value
218223

219224
async def get_active_user_email(
220225
self, connection: AsyncConnection | None = None, *, user_id: int
221226
) -> str:
222-
async with pass_or_acquire_connection(self._engine, connection) as conn:
223-
value: str | None = await conn.scalar(
224-
sa.select(users.c.email).where(
225-
(users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id)
226-
)
227-
)
228-
if value is not None:
229-
assert isinstance(value, str) # nosec
230-
return value
231-
232-
raise UserNotFoundInRepoError
227+
value = await self._get_scalar_or_raise(
228+
sa.select(users.c.email).where(
229+
(users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id)
230+
),
231+
connection=connection,
232+
)
233+
assert isinstance(value, str) # nosec
234+
return value
233235

234236
async def is_email_used(
235237
self, connection: AsyncConnection | None = None, *, email: str

0 commit comments

Comments
 (0)