|
| 1 | +import uuid |
1 | 2 | from typing import Any |
2 | 3 |
|
3 | | -from sqlmodel import Session |
| 4 | +from sqlalchemy import func |
| 5 | +from sqlmodel import Session, select |
4 | 6 |
|
5 | | -from app.core.security import verify_password |
6 | | -from app.model.user_model import UserModel |
7 | | -from app.models import User, UserCreate, UserUpdate, UserUpdateMe |
| 7 | +from app.core.security import get_password_hash |
| 8 | +from app.models import User, UserCreate, UserUpdate |
8 | 9 |
|
9 | 10 |
|
10 | | -class UserService: |
| 11 | +class UserModel: |
11 | 12 | def __init__(self, session: Session): |
12 | 13 | self.session = session |
13 | 14 |
|
14 | | - def create_user(cls, user_create: UserCreate) -> User: |
15 | | - return UserModel(cls.session).create(user_create) |
16 | | - |
17 | | - def update_user(cls, db_user: User, user_in: UserUpdate) -> Any: |
18 | | - return UserModel(cls.session).update(db_user, user_in) |
19 | | - |
20 | | - def get_user_by_email(cls, email: str) -> User | None: |
21 | | - return UserModel(cls.session).get_by_email(email) |
22 | | - |
23 | | - def get_user_by_id(cls, user_id: str) -> User | None: |
24 | | - return UserModel(cls.session).get_by_id(user_id) |
25 | | - |
26 | | - def get_users(cls, skip: int = 0, limit: int = 100) -> dict[str, Any]: |
27 | | - return UserModel(cls.session).get_users(skip, limit) |
28 | | - |
29 | | - def update_user_me(cls, current_user: User, user_in: UserUpdateMe) -> User: |
30 | | - if user_in.email: |
31 | | - existing_user = cls.get_user_by_email(email=user_in.email) |
32 | | - if existing_user and existing_user.id != current_user.id: |
33 | | - raise ValueError("User with this email already exists") |
34 | | - |
35 | | - # Convert UserUpdateMe to UserUpdate since model expects UserUpdate |
36 | | - update_data = UserUpdate(email=user_in.email, full_name=user_in.full_name) |
37 | | - return UserModel(cls.session).update(current_user, update_data) |
38 | | - |
39 | | - def update_password( |
40 | | - cls, current_user: User, current_password: str, new_password: str |
41 | | - ) -> None: |
42 | | - if not verify_password(current_password, current_user.hashed_password): |
43 | | - raise ValueError("Incorrect password") |
44 | | - if current_password == new_password: |
45 | | - raise ValueError("New password cannot be the same as the current one") |
46 | | - |
47 | | - # Create UserUpdate with new password |
48 | | - update_data = UserUpdate(password=new_password) |
49 | | - UserModel(cls.session).update(current_user, update_data) |
50 | | - |
51 | | - def delete_user( |
52 | | - cls, user_id: str, current_user_id: str, is_superuser: bool |
53 | | - ) -> None: |
54 | | - user = cls.get_user_by_id(user_id) |
55 | | - if not user: |
56 | | - raise ValueError("User not found") |
57 | | - if user.id == current_user_id and is_superuser: |
58 | | - raise ValueError("Super users are not allowed to delete themselves") |
59 | | - |
60 | | - UserModel(cls.session).delete_user(user_id) |
61 | | - |
62 | | - def authenticate(cls, email: str, password: str) -> User | None: |
63 | | - db_user = cls.get_user_by_email(email) |
64 | | - if not db_user: |
65 | | - return None |
66 | | - if not verify_password(password, db_user.hashed_password): |
67 | | - return None |
| 15 | + def create(cls, user_create: UserCreate) -> "User": |
| 16 | + db_obj = User.model_validate( |
| 17 | + user_create, |
| 18 | + update={"hashed_password": get_password_hash(user_create.password)}, |
| 19 | + ) |
| 20 | + cls.session.add(db_obj) |
| 21 | + cls.session.commit() |
| 22 | + cls.session.refresh(db_obj) |
| 23 | + return db_obj |
| 24 | + |
| 25 | + def update(cls, db_user: "User", user_in: UserUpdate) -> Any: |
| 26 | + user_data = user_in.model_dump(exclude_unset=True) |
| 27 | + extra_data = {} |
| 28 | + if "password" in user_data: |
| 29 | + password = user_data["password"] |
| 30 | + hashed_password = get_password_hash(password) |
| 31 | + extra_data["hashed_password"] = hashed_password |
| 32 | + db_user.sqlmodel_update(user_data, update=extra_data) |
| 33 | + cls.session.add(db_user) |
| 34 | + cls.session.commit() |
| 35 | + cls.session.refresh(db_user) |
68 | 36 | return db_user |
| 37 | + |
| 38 | + def get_by_email(cls, email: str) -> "User | None": |
| 39 | + statement = select(User).where(User.email == email) |
| 40 | + return cls.session.exec(statement).first() |
| 41 | + |
| 42 | + def get_by_id(cls, user_id: str) -> "User | None": |
| 43 | + statement = select(User).where(User.id == uuid.UUID(user_id)) |
| 44 | + return cls.session.exec(statement).first() |
| 45 | + |
| 46 | + def get_users(cls, skip: int = 0, limit: int = 100) -> dict: |
| 47 | + count_statement = select(func.count()).select_from(User) |
| 48 | + count = cls.session.exec(count_statement).one() |
| 49 | + statement = select(User).offset(skip).limit(limit) |
| 50 | + users = cls.session.exec(statement).all() |
| 51 | + return {"data": users, "count": count} |
| 52 | + |
| 53 | + def delete_user(cls, user_id: str) -> None: |
| 54 | + user = cls.get_by_id(user_id) |
| 55 | + if user: |
| 56 | + cls.session.delete(user) |
| 57 | + cls.session.commit() |
0 commit comments