|
12 | 12 | from sqlalchemy import Column |
13 | 13 | from sqlalchemy.exc import IntegrityError |
14 | 14 | from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine |
| 15 | +from sqlalchemy.sql import Select |
15 | 16 |
|
16 | 17 | from .models.users import UserRole, UserStatus, users |
17 | 18 | from .models.users_details import users_pre_registration_details |
@@ -56,6 +57,18 @@ class UsersRepo: |
56 | 57 | def __init__(self, engine: AsyncEngine): |
57 | 58 | self._engine = engine |
58 | 59 |
|
| 60 | + async def _get_scalar_or_raise( |
| 61 | + self, |
| 62 | + query: Select, |
| 63 | + connection: AsyncConnection | None = None, |
| 64 | + ) -> Any: |
| 65 | + """Execute a scalar query and raise UserNotFoundInRepoError if no value found.""" |
| 66 | + async with pass_or_acquire_connection(self._engine, connection) as conn: |
| 67 | + value = await conn.scalar(query) |
| 68 | + if value is not None: |
| 69 | + return value |
| 70 | + raise UserNotFoundInRepoError |
| 71 | + |
59 | 72 | async def new_user( |
60 | 73 | self, |
61 | 74 | connection: AsyncConnection | None = None, |
@@ -191,45 +204,34 @@ async def get_billing_details( |
191 | 204 | async def get_role( |
192 | 205 | self, connection: AsyncConnection | None = None, *, user_id: int |
193 | 206 | ) -> UserRole: |
194 | | - async with pass_or_acquire_connection(self._engine, connection) as conn: |
195 | | - |
196 | | - value: UserRole | None = await conn.scalar( |
197 | | - sa.select(users.c.role).where(users.c.id == user_id) |
198 | | - ) |
199 | | - if value: |
200 | | - assert isinstance(value, UserRole) # nosec |
201 | | - return UserRole(value) |
202 | | - |
203 | | - raise UserNotFoundInRepoError |
| 207 | + value = await self._get_scalar_or_raise( |
| 208 | + sa.select(users.c.role).where(users.c.id == user_id), |
| 209 | + connection=connection, |
| 210 | + ) |
| 211 | + assert isinstance(value, UserRole) # nosec |
| 212 | + return UserRole(value) |
204 | 213 |
|
205 | 214 | async def get_email( |
206 | 215 | self, connection: AsyncConnection | None = None, *, user_id: int |
207 | 216 | ) -> str: |
208 | | - async with pass_or_acquire_connection(self._engine, connection) as conn: |
209 | | - |
210 | | - value: str | None = await conn.scalar( |
211 | | - sa.select(users.c.email).where(users.c.id == user_id) |
212 | | - ) |
213 | | - if value: |
214 | | - assert isinstance(value, str) # nosec |
215 | | - return value |
216 | | - |
217 | | - raise UserNotFoundInRepoError |
| 217 | + value = await self._get_scalar_or_raise( |
| 218 | + sa.select(users.c.email).where(users.c.id == user_id), |
| 219 | + connection=connection, |
| 220 | + ) |
| 221 | + assert isinstance(value, str) # nosec |
| 222 | + return value |
218 | 223 |
|
219 | 224 | async def get_active_user_email( |
220 | 225 | self, connection: AsyncConnection | None = None, *, user_id: int |
221 | 226 | ) -> str: |
222 | | - async with pass_or_acquire_connection(self._engine, connection) as conn: |
223 | | - value: str | None = await conn.scalar( |
224 | | - sa.select(users.c.email).where( |
225 | | - (users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id) |
226 | | - ) |
227 | | - ) |
228 | | - if value is not None: |
229 | | - assert isinstance(value, str) # nosec |
230 | | - return value |
231 | | - |
232 | | - raise UserNotFoundInRepoError |
| 227 | + value = await self._get_scalar_or_raise( |
| 228 | + sa.select(users.c.email).where( |
| 229 | + (users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id) |
| 230 | + ), |
| 231 | + connection=connection, |
| 232 | + ) |
| 233 | + assert isinstance(value, str) # nosec |
| 234 | + return value |
233 | 235 |
|
234 | 236 | async def is_email_used( |
235 | 237 | self, connection: AsyncConnection | None = None, *, email: str |
|
0 commit comments