|
9 | 9 | from typing import Any, Final |
10 | 10 |
|
11 | 11 | import sqlalchemy as sa |
12 | | -from common_library.async_tools import maybe_await |
13 | 12 | from sqlalchemy import Column |
14 | 13 | from sqlalchemy.exc import IntegrityError |
15 | 14 | from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine |
16 | 15 |
|
17 | | -from ._protocols import DBConnection |
18 | 16 | from .models.users import UserRole, UserStatus, users |
19 | 17 | from .models.users_details import users_pre_registration_details |
20 | 18 | from .utils_repos import pass_or_acquire_connection, transaction_context |
@@ -183,64 +181,76 @@ def get_billing_details_query(user_id: int): |
183 | 181 | .where(users.c.id == user_id) |
184 | 182 | ) |
185 | 183 |
|
186 | | - @staticmethod |
187 | | - async def get_billing_details(conn: DBConnection, user_id: int) -> Any | None: |
188 | | - result = await conn.execute( |
189 | | - UsersRepo.get_billing_details_query(user_id=user_id) |
190 | | - ) |
191 | | - return await maybe_await(result.fetchone()) |
| 184 | + async def get_billing_details( |
| 185 | + self, connection: AsyncConnection | None = None, *, user_id: int |
| 186 | + ) -> Any | None: |
| 187 | + async with pass_or_acquire_connection(self._engine, connection) as conn: |
| 188 | + result = await conn.execute(self.get_billing_details_query(user_id=user_id)) |
| 189 | + return result.one_or_none() |
192 | 190 |
|
193 | | - @staticmethod |
194 | | - async def get_role(conn: DBConnection, user_id: int) -> UserRole: |
195 | | - value: UserRole | None = await conn.scalar( |
196 | | - sa.select(users.c.role).where(users.c.id == user_id) |
197 | | - ) |
198 | | - if value: |
199 | | - assert isinstance(value, UserRole) # nosec |
200 | | - return UserRole(value) |
| 191 | + async def get_role( |
| 192 | + self, connection: AsyncConnection | None = None, *, user_id: int |
| 193 | + ) -> UserRole: |
| 194 | + async with pass_or_acquire_connection(self._engine, connection) as conn: |
201 | 195 |
|
202 | | - raise UserNotFoundInRepoError |
| 196 | + value: UserRole | None = await conn.scalar( |
| 197 | + sa.select(users.c.role).where(users.c.id == user_id) |
| 198 | + ) |
| 199 | + if value: |
| 200 | + assert isinstance(value, UserRole) # nosec |
| 201 | + return UserRole(value) |
203 | 202 |
|
204 | | - @staticmethod |
205 | | - async def get_email(conn: DBConnection, user_id: int) -> str: |
206 | | - value: str | None = await conn.scalar( |
207 | | - sa.select(users.c.email).where(users.c.id == user_id) |
208 | | - ) |
209 | | - if value: |
210 | | - assert isinstance(value, str) # nosec |
211 | | - return value |
| 203 | + raise UserNotFoundInRepoError |
212 | 204 |
|
213 | | - raise UserNotFoundInRepoError |
| 205 | + async def get_email( |
| 206 | + self, connection: AsyncConnection | None = None, *, user_id: int |
| 207 | + ) -> str: |
| 208 | + async with pass_or_acquire_connection(self._engine, connection) as conn: |
214 | 209 |
|
215 | | - @staticmethod |
216 | | - async def get_active_user_email(conn: DBConnection, user_id: int) -> str: |
217 | | - value: str | None = await conn.scalar( |
218 | | - sa.select(users.c.email).where( |
219 | | - (users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id) |
| 210 | + value: str | None = await conn.scalar( |
| 211 | + sa.select(users.c.email).where(users.c.id == user_id) |
220 | 212 | ) |
221 | | - ) |
222 | | - if value is not None: |
223 | | - assert isinstance(value, str) # nosec |
224 | | - return value |
| 213 | + if value: |
| 214 | + assert isinstance(value, str) # nosec |
| 215 | + return value |
225 | 216 |
|
226 | | - raise UserNotFoundInRepoError |
| 217 | + raise UserNotFoundInRepoError |
227 | 218 |
|
228 | | - @staticmethod |
229 | | - async def is_email_used(conn: DBConnection, email: str) -> bool: |
230 | | - email = email.lower() |
| 219 | + async def get_active_user_email( |
| 220 | + self, connection: AsyncConnection | None = None, *, user_id: int |
| 221 | + ) -> str: |
| 222 | + async with pass_or_acquire_connection(self._engine, connection) as conn: |
| 223 | + value: str | None = await conn.scalar( |
| 224 | + sa.select(users.c.email).where( |
| 225 | + (users.c.status == UserStatus.ACTIVE) & (users.c.id == user_id) |
| 226 | + ) |
| 227 | + ) |
| 228 | + if value is not None: |
| 229 | + assert isinstance(value, str) # nosec |
| 230 | + return value |
231 | 231 |
|
232 | | - registered = await conn.scalar( |
233 | | - sa.select(users.c.id).where(users.c.email == email) |
234 | | - ) |
235 | | - if registered: |
236 | | - return True |
| 232 | + raise UserNotFoundInRepoError |
| 233 | + |
| 234 | + async def is_email_used( |
| 235 | + self, connection: AsyncConnection | None = None, *, email: str |
| 236 | + ) -> bool: |
| 237 | + |
| 238 | + async with pass_or_acquire_connection(self._engine, connection) as conn: |
| 239 | + |
| 240 | + email = email.lower() |
237 | 241 |
|
238 | | - pre_registered = await conn.scalar( |
239 | | - sa.select(users_pre_registration_details.c.user_id).where( |
240 | | - users_pre_registration_details.c.pre_email == email |
| 242 | + registered = await conn.scalar( |
| 243 | + sa.select(users.c.id).where(users.c.email == email) |
241 | 244 | ) |
242 | | - ) |
243 | | - return bool(pre_registered) |
| 245 | + if registered: |
| 246 | + return True |
| 247 | + |
| 248 | + pre_registered = await conn.scalar( |
| 249 | + sa.select(users_pre_registration_details.c.user_id).where( |
| 250 | + users_pre_registration_details.c.pre_email == email |
| 251 | + ) |
| 252 | + ) |
| 253 | + return bool(pre_registered) |
244 | 254 |
|
245 | 255 |
|
246 | 256 | # |
|
0 commit comments