|
1 | | -import datetime |
2 | | -import math |
3 | | - |
4 | 1 | from sqlalchemy import select, update, func, or_ |
5 | 2 | from sqlalchemy.ext.asyncio import AsyncSession |
6 | 3 | from sqlalchemy.orm import Session |
7 | | - |
8 | 4 | import config |
9 | 5 | from callbacks import StatisticsTimeDelta |
10 | 6 | from db import session_execute, session_flush |
|
14 | 10 |
|
15 | 11 |
|
16 | 12 | class UserRepository: |
| 13 | + INT32_MAX = 2_147_483_647 |
| 14 | + INT32_MIN = -2_147_483_648 |
| 15 | + |
17 | 16 | @staticmethod |
18 | 17 | async def get_by_tgid(telegram_id: int, session: AsyncSession | Session) -> UserDTO | None: |
19 | 18 | stmt = select(User).where(User.telegram_id == telegram_id) |
@@ -53,25 +52,35 @@ async def get_all_count(session: Session | AsyncSession) -> int: |
53 | 52 | return users_count.scalar_one() |
54 | 53 |
|
55 | 54 | @staticmethod |
56 | | - async def get_user_entity(user_entity: int | str, session: Session | AsyncSession) -> UserDTO | None: |
| 55 | + async def get_user_entity( |
| 56 | + user_entity: int | str, |
| 57 | + session: Session | AsyncSession |
| 58 | + ) -> UserDTO | None: |
| 59 | + |
| 60 | + entity_int: int | None = None |
57 | 61 | try: |
58 | | - entity_like_int = int(user_entity) |
59 | | - except ValueError: |
60 | | - entity_like_int = None |
61 | | - |
62 | | - stmt = select(User).where( |
63 | | - or_( |
64 | | - User.telegram_id == entity_like_int if entity_like_int is not None else False, |
65 | | - User.telegram_username == user_entity if entity_like_int is None else False, |
66 | | - User.id == entity_like_int if entity_like_int is not None else False |
67 | | - ) |
68 | | - ) |
69 | | - user = await session_execute(stmt, session) |
70 | | - user = user.scalar() |
| 62 | + entity_int = int(user_entity) |
| 63 | + except (ValueError, TypeError): |
| 64 | + pass |
| 65 | + |
| 66 | + conditions = [] |
| 67 | + if entity_int is not None: |
| 68 | + conditions.append(User.telegram_id == entity_int) |
| 69 | + if UserRepository.INT32_MIN <= entity_int <= UserRepository.INT32_MAX: |
| 70 | + conditions.append(User.id == entity_int) |
| 71 | + if isinstance(user_entity, str): |
| 72 | + conditions.append(User.telegram_username == user_entity) |
| 73 | + |
| 74 | + if not conditions: |
| 75 | + return None |
| 76 | + stmt = select(User).where(or_(*conditions)) |
| 77 | + result = await session_execute(stmt, session) |
| 78 | + user = result.scalar_one_or_none() |
| 79 | + |
71 | 80 | if user is None: |
72 | | - return user |
73 | | - else: |
74 | | - return UserDTO.model_validate(user, from_attributes=True) |
| 81 | + return None |
| 82 | + |
| 83 | + return UserDTO.model_validate(user, from_attributes=True) |
75 | 84 |
|
76 | 85 | @staticmethod |
77 | 86 | async def get_by_timedelta(timedelta: StatisticsTimeDelta, session: Session | AsyncSession) -> list[UserDTO]: |
|
0 commit comments