diff --git a/.env b/.env index 1d44286e25..951bb7a53f 100644 --- a/.env +++ b/.env @@ -18,7 +18,7 @@ STACK_NAME=full-stack-fastapi-project # Backend BACKEND_CORS_ORIGINS="http://localhost,http://localhost:5173,https://localhost,https://localhost:5173,http://localhost.tiangolo.com" -SECRET_KEY=changethis +SECRET_KEY=changethis_changethis_changethis FIRST_SUPERUSER=admin@example.com FIRST_SUPERUSER_PASSWORD=changethis diff --git a/backend/.coveragerc b/backend/.coveragerc new file mode 100644 index 0000000000..f0086b67a8 --- /dev/null +++ b/backend/.coveragerc @@ -0,0 +1,6 @@ +[run] +omit = + */models/* + */test_* + */tests/* + */__init__.py diff --git a/backend/Dockerfile b/backend/Dockerfile index 9d6e699f30..c403b45de1 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,7 +1,5 @@ FROM python:3.10 -ENV PYTHONUNBUFFERED=1 - WORKDIR /app/ # Install uv @@ -31,7 +29,7 @@ ENV PYTHONPATH=/app COPY ./scripts /app/scripts -COPY ./pyproject.toml ./uv.lock ./alembic.ini /app/ +COPY ./pyproject.toml ./uv.lock ./alembic.ini ./pytest.ini ./.coveragerc /app/ COPY ./app /app/app diff --git a/backend/README.md b/backend/README.md index 17210a2f2c..728b42f8cd 100644 --- a/backend/README.md +++ b/backend/README.md @@ -27,7 +27,7 @@ $ source .venv/bin/activate Make sure your editor is using the correct Python virtual environment, with the interpreter at `backend/.venv/bin/python`. -Modify or add SQLModel models for data and SQL tables in `./backend/app/models.py`, API endpoints in `./backend/app/api/`, CRUD (Create, Read, Update, Delete) utils in `./backend/app/crud.py`. +Modify or add SQLModel models for data and SQL tables in `./backend/app/models/*`, API endpoints in `./backend/app/api/`, CRUD (Create, Read, Update, Delete) utils in `./backend/app/crud/*`. ## VS Code diff --git a/backend/app/alembic/env.py b/backend/app/alembic/env.py index 7f29c04680..3329fb97a3 100755 --- a/backend/app/alembic/env.py +++ b/backend/app/alembic/env.py @@ -1,4 +1,3 @@ -import os from logging.config import fileConfig from alembic import context @@ -18,7 +17,8 @@ # target_metadata = mymodel.Base.metadata # target_metadata = None -from app.models import SQLModel # noqa +from sqlmodel import SQLModel # noqa +from app.models import * # noqa from app.core.config import settings # noqa target_metadata = SQLModel.metadata diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index c2b83c841d..1ef370fae4 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -13,44 +13,114 @@ from app.core.db import engine from app.models import TokenPayload, User +# Define OAuth2 scheme for token authentication reusable_oauth2 = OAuth2PasswordBearer( tokenUrl=f"{settings.API_V1_STR}/login/access-token" ) def get_db() -> Generator[Session, None, None]: + """ + Get a database session. + + This function creates a new session for the database and yields it. + The session is automatically closed after the function is executed. + It's not protected and can be used by any part of the application. + + Args: + None + + Returns: + Generator[Session, None, None]: A generator that yields the database session. + + Raises: + None + + Notes: + This function uses a context manager to ensure proper closure of the session. + """ with Session(engine) as session: yield session +# Define dependencies for database session and token SessionDep = Annotated[Session, Depends(get_db)] TokenDep = Annotated[str, Depends(reusable_oauth2)] +# Function to get the current user from the token def get_current_user(session: SessionDep, token: TokenDep) -> User: + """ + Get the current user from the token. + + This function decodes the JWT token and retrieves the corresponding user from the database. + It's protected and used as a dependency in routes that require authentication. + + Args: + session (SessionDep): The database session. + token (TokenDep): The JWT token. + + Returns: + User: The current authenticated user. + + Raises: + HTTPException: + - 403: If the token is invalid or expired. + - 404: If the user is not found in the database. + - 400: If the user account is inactive. + + Notes: + This function performs token validation and user authentication. + """ try: + # Decode the JWT token payload = jwt.decode( token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] ) token_data = TokenPayload(**payload) except (InvalidTokenError, ValidationError): + # Raise exception if token is invalid raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Could not validate credentials", ) + # Get the user from the database user = session.get(User, token_data.sub) if not user: + # Raise exception if user not found raise HTTPException(status_code=404, detail="User not found") if not user.is_active: + # Raise exception if user is inactive raise HTTPException(status_code=400, detail="Inactive user") return user +# Define dependency for getting the current user CurrentUser = Annotated[User, Depends(get_current_user)] def get_current_active_superuser(current_user: CurrentUser) -> User: + """ + Get the current active superuser. + + This function checks if the current user is a superuser. + It's protected and used as a dependency in routes that require superuser privileges. + + Args: + current_user (CurrentUser): The current authenticated user. + + Returns: + User: The current active superuser. + + Raises: + HTTPException: + - 403: If the user is not a superuser. + + Notes: + This function is typically used to restrict access to administrative endpoints. + """ if not current_user.is_superuser: + # Raise exception if user is not a superuser raise HTTPException( status_code=403, detail="The user doesn't have enough privileges" ) diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 09e0663fc3..6f645b4194 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -2,8 +2,16 @@ from app.api.routes import items, login, users, utils +# Create a main APIRouter instance api_router = APIRouter() + +# Include the login router without a prefix, tagged as "login" api_router.include_router(login.router, tags=["login"]) + +# Include the users router with "/users" prefix, tagged as "users" api_router.include_router(users.router, prefix="/users", tags=["users"]) + +# Include the utils router with "/utils" prefix, tagged as "utils" api_router.include_router(utils.router, prefix="/utils", tags=["utils"]) + api_router.include_router(items.router, prefix="/items", tags=["items"]) diff --git a/backend/app/api/routes/items.py b/backend/app/api/routes/items.py index 67196c2366..8a515fa381 100644 --- a/backend/app/api/routes/items.py +++ b/backend/app/api/routes/items.py @@ -2,10 +2,10 @@ from typing import Any from fastapi import APIRouter, HTTPException -from sqlmodel import func, select +from app import crud from app.api.deps import CurrentUser, SessionDep -from app.models import Item, ItemCreate, ItemPublic, ItemsPublic, ItemUpdate, Message +from app.models import ItemCreate, ItemPublic, ItemsPublic, ItemUpdate, Message router = APIRouter() @@ -19,24 +19,13 @@ def read_items( """ if current_user.is_superuser: - count_statement = select(func.count()).select_from(Item) - count = session.exec(count_statement).one() - statement = select(Item).offset(skip).limit(limit) - items = session.exec(statement).all() + count = crud.get_item_count(session=session) + items = crud.get_items(session=session, skip=skip, limit=limit) else: - count_statement = ( - select(func.count()) - .select_from(Item) - .where(Item.owner_id == current_user.id) + count = crud.get_item_count_by_owner(session=session, owner_id=current_user.id) + items = crud.get_items_by_owner( + session=session, owner_id=current_user.id, skip=skip, limit=limit ) - count = session.exec(count_statement).one() - statement = ( - select(Item) - .where(Item.owner_id == current_user.id) - .offset(skip) - .limit(limit) - ) - items = session.exec(statement).all() return ItemsPublic(data=items, count=count) @@ -46,7 +35,7 @@ def read_item(session: SessionDep, current_user: CurrentUser, id: uuid.UUID) -> """ Get item by ID. """ - item = session.get(Item, id) + item = crud.get_item(session=session, item_id=id) if not item: raise HTTPException(status_code=404, detail="Item not found") if not current_user.is_superuser and (item.owner_id != current_user.id): @@ -61,10 +50,7 @@ def create_item( """ Create new item. """ - item = Item.model_validate(item_in, update={"owner_id": current_user.id}) - session.add(item) - session.commit() - session.refresh(item) + item = crud.create_item(session=session, item_in=item_in, owner_id=current_user.id) return item @@ -79,16 +65,12 @@ def update_item( """ Update an item. """ - item = session.get(Item, id) + item = crud.get_item(session=session, item_id=id) if not item: raise HTTPException(status_code=404, detail="Item not found") if not current_user.is_superuser and (item.owner_id != current_user.id): raise HTTPException(status_code=400, detail="Not enough permissions") - update_dict = item_in.model_dump(exclude_unset=True) - item.sqlmodel_update(update_dict) - session.add(item) - session.commit() - session.refresh(item) + item = crud.update_item(session=session, db_item=item, item_in=item_in) return item @@ -99,11 +81,10 @@ def delete_item( """ Delete an item. """ - item = session.get(Item, id) + item = crud.get_item(session=session, item_id=id) if not item: raise HTTPException(status_code=404, detail="Item not found") if not current_user.is_superuser and (item.owner_id != current_user.id): raise HTTPException(status_code=400, detail="Not enough permissions") - session.delete(item) - session.commit() + crud.delete_item(session=session, item_id=id) return Message(message="Item deleted successfully") diff --git a/backend/app/api/routes/login.py b/backend/app/api/routes/login.py index fe7e94d5c1..87bb15228c 100644 --- a/backend/app/api/routes/login.py +++ b/backend/app/api/routes/login.py @@ -18,24 +18,51 @@ verify_password_reset_token, ) +# Create a new APIRouter instance for login-related routes router = APIRouter() @router.post("/login/access-token") def login_access_token( - session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()] + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ) -> Token: """ OAuth2 compatible token login, get an access token for future requests + + This endpoint allows users to obtain an access token for authentication. + It's rate-limited to 5 requests per minute to prevent abuse. + + Args: + request (Request): The incoming request object (required for rate limiting). + response (Response): The outgoing response object (required for rate limiting). + session (SessionDep): The database session dependency. + form_data (OAuth2PasswordRequestForm): The form data containing username and password. + + Returns: + Token: An object containing the access token. + + Raises: + HTTPException: + - 400: If the email or password is incorrect. + - 400: If the user account is inactive. + + Notes: + This function authenticates the user, checks if the account is active, + and then generates and returns an access token with a specified expiration time. """ + # Authenticate the user using the provided credentials user = crud.authenticate( session=session, email=form_data.username, password=form_data.password ) + # If authentication fails, raise an HTTPException if not user: raise HTTPException(status_code=400, detail="Incorrect email or password") elif not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") + # Set the access token expiration time access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + # Create and return a new access token return Token( access_token=security.create_access_token( user.id, expires_delta=access_token_expires @@ -47,42 +74,116 @@ def login_access_token( def test_token(current_user: CurrentUser) -> Any: """ Test access token + + This endpoint allows testing the validity of an access token. + It's protected and can only be accessed with a valid token. + + Args: + current_user (CurrentUser): The current authenticated user, injected by dependency. + + Returns: + Any: The current user's public information. + + Raises: + HTTPException: If the token is invalid or expired (handled by dependency). + + Notes: + This function is useful for verifying that a token is working correctly. + It simply returns the current user's information, which implicitly confirms + that the token is valid and the user is authenticated. """ + # Return the current user to verify the token is working return current_user @router.post("/password-recovery/{email}") -def recover_password(email: str, session: SessionDep) -> Message: +def recover_password( + email: str, + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] +) -> Message: """ Password Recovery + + This endpoint initiates the password recovery process for a user. + It's rate-limited to 3 requests per minute to prevent abuse. + + Args: + email (str): The email address of the user requesting password recovery. + session (SessionDep): The database session dependency. + request (Request): The incoming request object (required for rate limiting). + response (Response): The outgoing response object (required for rate limiting). + + Returns: + Message: A message indicating that the password recovery email was sent. + + Raises: + HTTPException: + - 404: If no user is found with the provided email address. + + Notes: + This function checks for the existence of the user, generates a password reset token, + creates a password reset email, and sends it to the user's email address. + It does not confirm or deny the existence of an account to prevent email enumeration. """ + # Get the user by email user = crud.get_user_by_email(session=session, email=email) + # If user not found, still act like we sent the email if not user: - raise HTTPException( - status_code=404, - detail="The user with this email does not exist in the system.", - ) + return Message(message="Password recovery email sent if the account exists.") + + # Generate a password reset token password_reset_token = generate_password_reset_token(email=email) + # Generate the reset password email content email_data = generate_reset_password_email( email_to=user.email, email=email, token=password_reset_token ) + # Send the password reset email send_email( email_to=user.email, subject=email_data.subject, html_content=email_data.html_content, ) - return Message(message="Password recovery email sent") + # Return a success message + return Message(message="Password recovery email sent if the account exists.") @router.post("/reset-password/") -def reset_password(session: SessionDep, body: NewPassword) -> Message: +def reset_password( + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + body: NewPassword, +) -> Message: """ Reset password + + This endpoint allows users to reset their password using a valid reset token. + It's rate-limited to 3 requests per minute to prevent abuse. + + Args: + session (SessionDep): The database session dependency. + body (NewPassword): The new password and reset token. + request (Request): The incoming request object (required for rate limiting). + response (Response): The outgoing response object (required for rate limiting). + + Returns: + Message: A message indicating that the password was successfully updated. + + Raises: + HTTPException: + - 400: If the reset token is invalid. + - 404: If no user is found with the email associated with the token. + - 400: If the user account is inactive. + + Notes: + This function verifies the reset token, retrieves the associated user, + checks if the user is active, hashes the new password, and updates it in the database. + It's the final step in the password recovery process. """ + # Verify the password reset token email = verify_password_reset_token(token=body.token) if not email: raise HTTPException(status_code=400, detail="Invalid token") + # Get the user by email user = crud.get_user_by_email(session=session, email=email) if not user: raise HTTPException( @@ -91,10 +192,13 @@ def reset_password(session: SessionDep, body: NewPassword) -> Message: ) elif not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") + # Hash the new password hashed_password = get_password_hash(password=body.new_password) + # Update the user's password user.hashed_password = hashed_password session.add(user) session.commit() + # Return a success message return Message(message="Password updated successfully") @@ -103,22 +207,50 @@ def reset_password(session: SessionDep, body: NewPassword) -> Message: dependencies=[Depends(get_current_active_superuser)], response_class=HTMLResponse, ) -def recover_password_html_content(email: str, session: SessionDep) -> Any: +def recover_password_html_content( + email: str, + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] +) -> Any: """ HTML Content for Password Recovery + + This endpoint generates and returns the HTML content for a password recovery email. + It's protected and can only be accessed by active superusers. + + Args: + email (str): The email address of the user for whom to generate the recovery email. + session (SessionDep): The database session dependency. + + Returns: + HTMLResponse: The HTML content of the password reset email, with the subject in the headers. + + Raises: + HTTPException: + - 404: If no user is found with the provided email address. + + Notes: + This function is primarily for testing and debugging purposes. It allows superusers + to view the content of password reset emails without actually sending them. + It generates a password reset token and creates the email content just like + the actual password recovery process. """ + # Get the user by email user = crud.get_user_by_email(session=session, email=email) + # If user not found, raise an HTTPException if not user: raise HTTPException( status_code=404, detail="The user with this username does not exist in the system.", ) + # Generate a password reset token password_reset_token = generate_password_reset_token(email=email) + # Generate the reset password email content email_data = generate_reset_password_email( email_to=user.email, email=email, token=password_reset_token ) + # Return the HTML content of the password reset email return HTMLResponse( content=email_data.html_content, headers={"subject:": email_data.subject} ) diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py index c636b094ee..1ceed426a7 100644 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -2,7 +2,6 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException -from sqlmodel import col, delete, func, select from app import crud from app.api.deps import ( @@ -11,12 +10,10 @@ get_current_active_superuser, ) from app.core.config import settings -from app.core.security import get_password_hash, verify_password +from app.core.security import verify_password from app.models import ( - Item, Message, UpdatePassword, - User, UserCreate, UserPublic, UserRegister, @@ -26,6 +23,7 @@ ) from app.utils import generate_new_account_email, send_email +# Create a new APIRouter instance for user-related routes router = APIRouter() @@ -34,27 +32,65 @@ dependencies=[Depends(get_current_active_superuser)], response_model=UsersPublic, ) -def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> Any: +def read_users( + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] [reportInvalidTypeForm] + skip: int = 0, + limit: int = 100, +) -> Any: """ Retrieve users. - """ - count_statement = select(func.count()).select_from(User) - count = session.exec(count_statement).one() + This endpoint allows retrieving a list of users with pagination. + It's protected and can only be accessed by superusers. - statement = select(User).offset(skip).limit(limit) - users = session.exec(statement).all() + Args: + session (SessionDep): The database session dependency. + skip (int): The number of users to skip (for pagination). + limit (int): The maximum number of users to return. + Returns: + UsersPublic: An object containing the list of users and the total count. + + Notes: + This endpoint is useful for administrative purposes to view all users in the system. + """ + # Get a list of users with pagination + users = crud.get_users(session=session, skip=skip, limit=limit) + # Get the total count of users + count = crud.get_user_count(session=session) + # Return the users and count in the UsersPublic model return UsersPublic(data=users, count=count) @router.post( "/", dependencies=[Depends(get_current_active_superuser)], response_model=UserPublic ) -def create_user(*, session: SessionDep, user_in: UserCreate) -> Any: +def create_user( + *, + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + user_in: UserCreate, +) -> Any: """ Create new user. + + This endpoint allows creating a new user in the system. + It's protected and can only be accessed by superusers. + + Args: + session (SessionDep): The database session dependency. + user_in (UserCreate): The user data to be created. + + Returns: + UserPublic: The created user's public information. + + Raises: + HTTPException: + - 400: If a user with the given email already exists. + + Notes: + If email sending is enabled, a new account email will be sent to the user. """ + # Check if a user with the given email already exists user = crud.get_user_by_email(session=session, email=user_in.email) if user: raise HTTPException( @@ -62,7 +98,9 @@ def create_user(*, session: SessionDep, user_in: UserCreate) -> Any: detail="The user with this email already exists in the system.", ) + # Create the new user user = crud.create_user(session=session, user_create=user_in) + # If email sending is enabled, send a new account email if settings.emails_enabled and user_in.email: email_data = generate_new_account_email( email_to=user_in.email, username=user_in.email, password=user_in.password @@ -77,43 +115,90 @@ def create_user(*, session: SessionDep, user_in: UserCreate) -> Any: @router.patch("/me", response_model=UserPublic) def update_user_me( - *, session: SessionDep, user_in: UserUpdateMe, current_user: CurrentUser + *, + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + user_in: UserUpdateMe, + current_user: CurrentUser, ) -> Any: """ Update own user. - """ + This endpoint allows users to update their own information. + It's protected and can be accessed by authenticated users. + + Args: + session (SessionDep): The database session dependency. + user_in (UserUpdateMe): The user data to be updated. + current_user (CurrentUser): The current authenticated user. + + Returns: + UserPublic: The updated user's public information. + + Raises: + HTTPException: + - 409: If the new email is already in use by another user. + + Notes: + Users can update their own information, but not their role or superuser status. + """ + # If the email is being updated, check if it's already in use if user_in.email: existing_user = crud.get_user_by_email(session=session, email=user_in.email) if existing_user and existing_user.id != current_user.id: raise HTTPException( status_code=409, detail="User with this email already exists" ) - user_data = user_in.model_dump(exclude_unset=True) - current_user.sqlmodel_update(user_data) - session.add(current_user) - session.commit() - session.refresh(current_user) - return current_user + # Update the user's attributes + updated_user = crud.update_user_attributes( + session=session, + db_user=current_user, + user_in=UserUpdate(**user_in.model_dump()), + ) + return updated_user @router.patch("/me/password", response_model=Message) def update_password_me( - *, session: SessionDep, body: UpdatePassword, current_user: CurrentUser + *, + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + body: UpdatePassword, + current_user: CurrentUser, ) -> Any: """ Update own password. + + This endpoint allows users to update their own password. + It's protected and can be accessed by authenticated users. + + Args: + session (SessionDep): The database session dependency. + body (UpdatePassword): The current and new password data. + current_user (CurrentUser): The current authenticated user. + + Returns: + Message: A message confirming the password update. + + Raises: + HTTPException: + - 400: If the current password is incorrect or if the new password is the same as the current one. + + Notes: + Users must provide their current password for security reasons. """ + # Verify the current password if not verify_password(body.current_password, current_user.hashed_password): raise HTTPException(status_code=400, detail="Incorrect password") + # Ensure the new password is different from the current one if body.current_password == body.new_password: raise HTTPException( status_code=400, detail="New password cannot be the same as the current one" ) - hashed_password = get_password_hash(body.new_password) - current_user.hashed_password = hashed_password - session.add(current_user) - session.commit() + # Update the user's password + crud.update_user_password( + session=session, + db_user=current_user, + new_password=body.new_password, + ) return Message(message="Password updated successfully") @@ -121,52 +206,131 @@ def update_password_me( def read_user_me(current_user: CurrentUser) -> Any: """ Get current user. + + This endpoint allows users to retrieve their own information. + It's protected and can be accessed by authenticated users. + + Args: + current_user (CurrentUser): The current authenticated user. + + Returns: + UserPublic: The current user's public information. + + Notes: + This endpoint is useful for clients to get the latest user information after login. """ + # Return the current user's information return current_user @router.delete("/me", response_model=Message) -def delete_user_me(session: SessionDep, current_user: CurrentUser) -> Any: +def delete_user_me( + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + current_user: CurrentUser, +) -> Any: """ Delete own user. + + This endpoint allows users to delete their own account. + It's protected and can be accessed by authenticated users. + + Args: + session (SessionDep): The database session dependency. + current_user (CurrentUser): The current authenticated user. + + Returns: + Message: A message confirming the user deletion. + + Raises: + HTTPException: + - 403: If the user is a superuser trying to delete their own account. + + Notes: + Superusers are not allowed to delete their own accounts for security reasons. """ + # Prevent superusers from deleting themselves if current_user.is_superuser: raise HTTPException( status_code=403, detail="Super users are not allowed to delete themselves" ) - statement = delete(Item).where(col(Item.owner_id) == current_user.id) - session.exec(statement) # type: ignore - session.delete(current_user) - session.commit() + # Delete the user + crud.delete_user(session=session, user_id=current_user.id) return Message(message="User deleted successfully") @router.post("/signup", response_model=UserPublic) -def register_user(session: SessionDep, user_in: UserRegister) -> Any: +def register_user( + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + user_in: UserRegister, +) -> Any: """ Create new user without the need to be logged in. + + This endpoint allows new users to register in the system. + It's public and can be accessed without authentication. + + Args: + session (SessionDep): The database session dependency. + user_in (UserRegister): The user registration data. + + Returns: + UserPublic: The created user's public information. + + Raises: + HTTPException: + - 400: If a user with the given email already exists. + + Notes: + This endpoint is typically used for user sign-up functionality. """ + # Check if a user with the given email already exists user = crud.get_user_by_email(session=session, email=user_in.email) if user: raise HTTPException( status_code=400, detail="The user with this email already exists in the system", ) + # Create a UserCreate model from the UserRegister input user_create = UserCreate.model_validate(user_in) + # Create the new user user = crud.create_user(session=session, user_create=user_create) return user @router.get("/{user_id}", response_model=UserPublic) def read_user_by_id( - user_id: uuid.UUID, session: SessionDep, current_user: CurrentUser + user_id: uuid.UUID, + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + current_user: CurrentUser, ) -> Any: """ Get a specific user by id. + + This endpoint allows retrieving a specific user's information by their ID. + It's protected and can be accessed by the user themselves or superusers. + + Args: + user_id (uuid.UUID): The ID of the user to retrieve. + session (SessionDep): The database session dependency. + current_user (CurrentUser): The current authenticated user. + + Returns: + UserPublic: The requested user's public information. + + Raises: + HTTPException: + - 403: If the current user doesn't have enough privileges to access the information. + - 404: If the user with the given ID is not found. + + Notes: + Regular users can only access their own information, while superusers can access any user's information. """ - user = session.get(User, user_id) + # Retrieve the user by ID + user = crud.get_user(session=session, user_id=user_id) + # Allow users to access their own information if user == current_user: return user + # Only superusers can access other users' information if not current_user.is_superuser: raise HTTPException( status_code=403, @@ -182,20 +346,40 @@ def read_user_by_id( ) def update_user( *, - session: SessionDep, + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] user_id: uuid.UUID, user_in: UserUpdate, ) -> Any: """ Update a user. - """ - db_user = session.get(User, user_id) + This endpoint allows updating a specific user's information. + It's protected and can only be accessed by superusers. + + Args: + session (SessionDep): The database session dependency. + user_id (uuid.UUID): The ID of the user to update. + user_in (UserUpdate): The user data to be updated. + + Returns: + UserPublic: The updated user's public information. + + Raises: + HTTPException: + - 404: If the user with the given ID is not found. + - 409: If the new email is already in use by another user. + + Notes: + This endpoint is typically used for administrative purposes to update any user's information. + """ + # Retrieve the user by ID + db_user = crud.get_user(session=session, user_id=user_id) if not db_user: raise HTTPException( status_code=404, detail="The user with this id does not exist in the system", ) + # If the email is being updated, check if it's already in use if user_in.email: existing_user = crud.get_user_by_email(session=session, email=user_in.email) if existing_user and existing_user.id != user_id: @@ -203,26 +387,50 @@ def update_user( status_code=409, detail="User with this email already exists" ) - db_user = crud.update_user(session=session, db_user=db_user, user_in=user_in) + # Update the user's attributes + db_user = crud.update_user_attributes( + session=session, db_user=db_user, user_in=user_in + ) return db_user @router.delete("/{user_id}", dependencies=[Depends(get_current_active_superuser)]) def delete_user( - session: SessionDep, current_user: CurrentUser, user_id: uuid.UUID + session: SessionDep, # pyright: ignore [reportInvalidTypeForm] + current_user: CurrentUser, + user_id: uuid.UUID, ) -> Message: """ Delete a user. + + This endpoint allows deleting a specific user from the system. + It's protected and can only be accessed by superusers. + + Args: + session (SessionDep): The database session dependency. + current_user (CurrentUser): The current authenticated superuser. + user_id (uuid.UUID): The ID of the user to delete. + + Returns: + Message: A message confirming the user deletion. + + Raises: + HTTPException: + - 403: If a superuser tries to delete their own account. + - 404: If the user with the given ID is not found. + + Notes: + Superusers are not allowed to delete their own accounts for security reasons. """ - user = session.get(User, user_id) + # Retrieve the user by ID + user = crud.get_user(session=session, user_id=user_id) if not user: raise HTTPException(status_code=404, detail="User not found") + # Prevent superusers from deleting themselves if user == current_user: raise HTTPException( status_code=403, detail="Super users are not allowed to delete themselves" ) - statement = delete(Item).where(col(Item.owner_id) == user_id) - session.exec(statement) # type: ignore - session.delete(user) - session.commit() + # Delete the user + crud.delete_user(session=session, user_id=user_id) return Message(message="User deleted successfully") diff --git a/backend/app/api/routes/utils.py b/backend/app/api/routes/utils.py index a73b80d761..73f156dbd9 100644 --- a/backend/app/api/routes/utils.py +++ b/backend/app/api/routes/utils.py @@ -5,6 +5,7 @@ from app.models import Message from app.utils import generate_test_email, send_email +# Create a new APIRouter instance router = APIRouter() @@ -16,16 +17,56 @@ def test_email(email_to: EmailStr) -> Message: """ Test emails. + + This endpoint allows sending a test email to a specified address. + It's protected and can only be accessed by active superusers. + + Args: + email_to (EmailStr): The email address to send the test email to. + + Returns: + Message: A message indicating that the test email was sent successfully. + + Raises: + HTTPException: If the user is not an active superuser. + + Notes: + This function is useful for verifying email functionality in the system. + It generates a test email and sends it to the specified address. """ + # Generate test email content email_data = generate_test_email(email_to=email_to) + + # Send the test email send_email( email_to=email_to, subject=email_data.subject, html_content=email_data.html_content, ) + + # Return a success message return Message(message="Test email sent") @router.get("/health-check/") async def health_check() -> bool: + """ + Perform a health check. + + This endpoint returns True, indicating that the API is up and running. + It can be used for monitoring and load balancer checks. + + Args: + None + + Returns: + bool: Always returns True if the API is functioning. + + Raises: + None + + Notes: + This is an asynchronous function that doesn't require any authentication. + It's typically used by monitoring systems to verify the API's availability. + """ return True diff --git a/backend/app/backend_pre_start.py b/backend/app/backend_pre_start.py deleted file mode 100644 index c2f8e29ae1..0000000000 --- a/backend/app/backend_pre_start.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging - -from sqlalchemy import Engine -from sqlmodel import Session, select -from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed - -from app.core.db import engine - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -max_tries = 60 * 5 # 5 minutes -wait_seconds = 1 - - -@retry( - stop=stop_after_attempt(max_tries), - wait=wait_fixed(wait_seconds), - before=before_log(logger, logging.INFO), - after=after_log(logger, logging.WARN), -) -def init(db_engine: Engine) -> None: - try: - with Session(db_engine) as session: - # Try to create session to check if DB is awake - session.exec(select(1)) - except Exception as e: - logger.error(e) - raise e - - -def main() -> None: - logger.info("Initializing service") - init(engine) - logger.info("Service finished initializing") - - -if __name__ == "__main__": - main() diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 2370469d7a..4328f8f004 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -16,6 +16,24 @@ def parse_cors(v: Any) -> list[str] | str: + """ + Parse CORS settings. + + This function parses the CORS settings from various input formats. + It's not protected and can be used by any part of the application. + + Args: + v (Any): The input value to parse. + + Returns: + list[str] | str: Parsed CORS settings. + + Raises: + ValueError: If the input cannot be parsed. + + Notes: + Accepts comma-separated string or list of strings. + """ if isinstance(v, str) and not v.startswith("["): return [i.strip() for i in v.split(",")] elif isinstance(v, list | str): @@ -24,19 +42,26 @@ def parse_cors(v: Any) -> list[str] | str: class Settings(BaseSettings): + # Configuration for the settings model model_config = SettingsConfigDict( # Use top level .env file (one level above ./backend/) env_file="../.env", env_ignore_empty=True, extra="ignore", ) + + # API version string API_V1_STR: str = "/api/v1" + # Secret key for security (default is a random URL-safe string) SECRET_KEY: str = secrets.token_urlsafe(32) - # 60 minutes * 24 hours * 8 days = 8 days + # Token expiration time (8 days) ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 + # Frontend host URL FRONTEND_HOST: str = "http://localhost:5173" + # Current environment (local, staging, production) ENVIRONMENT: Literal["local", "staging", "production"] = "local" + # CORS origins configuration BACKEND_CORS_ORIGINS: Annotated[ list[AnyUrl] | str, BeforeValidator(parse_cors) ] = [] @@ -44,12 +69,34 @@ class Settings(BaseSettings): @computed_field # type: ignore[prop-decorator] @property def all_cors_origins(self) -> list[str]: + """ + Get all CORS origins. + + This property computes and returns all CORS origins including the frontend host. + It's not protected and can be accessed by any part of the application. + + Args: + None + + Returns: + list[str]: List of all CORS origins. + + Raises: + None + + Notes: + Combines backend CORS origins with the frontend host. + """ return [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + [ self.FRONTEND_HOST ] + # Project name PROJECT_NAME: str + # Sentry DSN for error tracking SENTRY_DSN: HttpUrl | None = None + + # Database configuration POSTGRES_SERVER: str POSTGRES_PORT: int = 5432 POSTGRES_USER: str @@ -59,6 +106,24 @@ def all_cors_origins(self) -> list[str]: @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: + """ + Get SQLAlchemy database URI. + + This property computes and returns the SQLAlchemy database URI. + It's not protected and can be accessed by any part of the application. + + Args: + None + + Returns: + PostgresDsn: The SQLAlchemy database URI. + + Raises: + None + + Notes: + Constructs the URI using the PostgreSQL configuration settings. + """ return MultiHostUrl.build( scheme="postgresql+psycopg", username=self.POSTGRES_USER, @@ -68,6 +133,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: path=self.POSTGRES_DB, ) + # Email configuration SMTP_TLS: bool = True SMTP_SSL: bool = False SMTP_PORT: int = 587 @@ -80,15 +146,52 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: @model_validator(mode="after") def _set_default_emails_from(self) -> Self: + """ + Set default email sender name. + + This method sets the default email sender name if not provided. + It's not protected and is used internally by the Settings class. + + Args: + None + + Returns: + Self: The Settings instance. + + Raises: + None + + Notes: + Uses the project name as the default sender name. + """ if not self.EMAILS_FROM_NAME: self.EMAILS_FROM_NAME = self.PROJECT_NAME return self + # Password reset token expiration time EMAIL_RESET_TOKEN_EXPIRE_HOURS: int = 48 @computed_field # type: ignore[prop-decorator] @property def emails_enabled(self) -> bool: + """ + Check if emails are enabled. + + This property determines if the email functionality is enabled. + It's not protected and can be accessed by any part of the application. + + Args: + None + + Returns: + bool: True if emails are enabled, False otherwise. + + Raises: + None + + Notes: + Emails are considered enabled if SMTP host and sender email are set. + """ return bool(self.SMTP_HOST and self.EMAILS_FROM_EMAIL) # TODO: update type to EmailStr when sqlmodel supports it @@ -98,6 +201,25 @@ def emails_enabled(self) -> bool: FIRST_SUPERUSER_PASSWORD: str def _check_default_secret(self, var_name: str, value: str | None) -> None: + """ + Check for default secret values. + + This method checks if a secret value is set to its default and raises a warning or error. + It's not protected and is used internally by the Settings class. + + Args: + var_name (str): The name of the variable being checked. + value (str | None): The value of the variable. + + Returns: + None + + Raises: + ValueError: If the value is "changethis" in non-local environments. + + Notes: + Warns in local environment, raises error in other environments. + """ if value == "changethis": message = ( f'The value of {var_name} is "changethis", ' @@ -110,6 +232,24 @@ def _check_default_secret(self, var_name: str, value: str | None) -> None: @model_validator(mode="after") def _enforce_non_default_secrets(self) -> Self: + """ + Enforce non-default secrets. + + This method ensures that secret values are not left as their defaults. + It's not protected and is used internally by the Settings class. + + Args: + None + + Returns: + Self: The Settings instance. + + Raises: + ValueError: If any secret is left as its default value. + + Notes: + Checks SECRET_KEY, POSTGRES_PASSWORD, and FIRST_SUPERUSER_PASSWORD. + """ self._check_default_secret("SECRET_KEY", self.SECRET_KEY) self._check_default_secret("POSTGRES_PASSWORD", self.POSTGRES_PASSWORD) self._check_default_secret( @@ -119,4 +259,5 @@ def _enforce_non_default_secrets(self) -> Self: return self +# Create an instance of the Settings class settings = Settings() # type: ignore diff --git a/backend/app/core/db.py b/backend/app/core/db.py index ba991fb36d..f170cd7a68 100644 --- a/backend/app/core/db.py +++ b/backend/app/core/db.py @@ -1,9 +1,10 @@ -from sqlmodel import Session, create_engine, select +from sqlmodel import Session, create_engine from app import crud from app.core.config import settings -from app.models import User, UserCreate +from app.models import UserCreate +# Create a SQLAlchemy engine using the database URI from settings engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) @@ -13,21 +14,39 @@ def init_db(session: Session) -> None: - # Tables should be created with Alembic migrations - # But if you don't want to use migrations, create - # the tables un-commenting the next lines - # from sqlmodel import SQLModel + """ + Initialize the database. + + This function initializes the database by creating a superuser if one doesn't exist. + It's not protected and can be called during application startup. + + Args: + session (Session): The database session. - # This works because the models are already imported and registered from app.models + Returns: + None + + Raises: + None + + Notes: + - Tables should be created with Alembic migrations. + - If migrations are not used, uncomment the SQLModel.metadata.create_all(engine) line. This works because the models are already imported and registered from app.models + - Creates a superuser using settings if one doesn't exist. + """ + # from sqlmodel import SQLModel + # from app.core.engine import engine + # # SQLModel.metadata.create_all(engine) - user = session.exec( - select(User).where(User.email == settings.FIRST_SUPERUSER) - ).first() + # Check if a superuser already exists in the database + user = crud.get_user_by_email(session=session, email=settings.FIRST_SUPERUSER) if not user: + # If no superuser exists, create one using the settings user_in = UserCreate( email=settings.FIRST_SUPERUSER, password=settings.FIRST_SUPERUSER_PASSWORD, is_superuser=True, ) + # Add the new superuser to the database user = crud.create_user(session=session, user_create=user_in) diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 7aff7cfb32..4c9eac7912 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -6,22 +6,84 @@ from app.core.config import settings +# Create a password context for hashing and verifying passwords pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - +# Define the algorithm used for JWT encoding and decoding ALGORITHM = "HS256" def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: + """ + Create an access token for the given subject with an expiration time. + + This function creates a JWT access token for the given subject with the specified expiration time. + It's not protected and can be used by any part of the application that needs to generate access tokens. + + Args: + subject (str | Any): The subject of the token, typically a user identifier. + expires_delta (timedelta): The time delta after which the token will expire. + + Returns: + str: The encoded JWT access token. + + Raises: + None + + Notes: + Uses the SECRET_KEY from settings and the ALGORITHM constant for encoding. + """ + # Calculate the expiration time expire = datetime.now(timezone.utc) + expires_delta + # Create a payload with expiration time and subject to_encode = {"exp": expire, "sub": str(subject)} + # Encode the payload into a JWT encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) return encoded_jwt def verify_password(plain_password: str, hashed_password: str) -> bool: + """ + Verify if the plain password matches the hashed password. + + This function checks if a given plain text password matches a hashed password. + It's not protected and can be used by any part of the application that needs to verify passwords. + + Args: + plain_password (str): The plain text password to verify. + hashed_password (str): The hashed password to compare against. + + Returns: + bool: True if the passwords match, False otherwise. + + Raises: + None + + Notes: + Uses the pwd_context for password verification. + """ + # Use the password context to verify the plain password against the hashed password return pwd_context.verify(plain_password, hashed_password) def get_password_hash(password: str) -> str: + """ + Get the hash of the given password. + + This function generates a hash for a given plain text password. + It's not protected and can be used by any part of the application that needs to hash passwords. + + Args: + password (str): The plain text password to hash. + + Returns: + str: The hashed password. + + Raises: + None + + Notes: + Uses the pwd_context for password hashing. + """ + # Use the password context to hash the given password return pwd_context.hash(password) diff --git a/backend/app/crud.py b/backend/app/crud.py deleted file mode 100644 index 905bf48724..0000000000 --- a/backend/app/crud.py +++ /dev/null @@ -1,54 +0,0 @@ -import uuid -from typing import Any - -from sqlmodel import Session, select - -from app.core.security import get_password_hash, verify_password -from app.models import Item, ItemCreate, User, UserCreate, UserUpdate - - -def create_user(*, session: Session, user_create: UserCreate) -> User: - db_obj = User.model_validate( - user_create, update={"hashed_password": get_password_hash(user_create.password)} - ) - session.add(db_obj) - session.commit() - session.refresh(db_obj) - return db_obj - - -def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any: - user_data = user_in.model_dump(exclude_unset=True) - extra_data = {} - if "password" in user_data: - password = user_data["password"] - hashed_password = get_password_hash(password) - extra_data["hashed_password"] = hashed_password - db_user.sqlmodel_update(user_data, update=extra_data) - session.add(db_user) - session.commit() - session.refresh(db_user) - return db_user - - -def get_user_by_email(*, session: Session, email: str) -> User | None: - statement = select(User).where(User.email == email) - session_user = session.exec(statement).first() - return session_user - - -def authenticate(*, session: Session, email: str, password: str) -> User | None: - db_user = get_user_by_email(session=session, email=email) - if not db_user: - return None - if not verify_password(password, db_user.hashed_password): - return None - return db_user - - -def create_item(*, session: Session, item_in: ItemCreate, owner_id: uuid.UUID) -> Item: - db_item = Item.model_validate(item_in, update={"owner_id": owner_id}) - session.add(db_item) - session.commit() - session.refresh(db_item) - return db_item diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py new file mode 100644 index 0000000000..9d2c52d2bc --- /dev/null +++ b/backend/app/crud/__init__.py @@ -0,0 +1,41 @@ +from .items import ( + create_item, + delete_item, + get_item, + get_item_count, + get_item_count_by_owner, + get_items, + get_items_by_owner, + update_item, +) +from .users import ( + authenticate, + create_user, + delete_user, + get_user, + get_user_by_email, + get_user_count, + get_users, + update_user_attributes, + update_user_password, +) + +__all__ = [ + "create_user", + "get_user", + "get_users", + "update_user_password", + "delete_user", + "get_user_count", + "get_user_by_email", + "authenticate", + "update_user_attributes", + "create_item", + "get_item", + "get_items", + "get_items_by_owner", + "get_item_count", + "get_item_count_by_owner", + "update_item", + "delete_item", +] diff --git a/backend/app/crud/items.py b/backend/app/crud/items.py new file mode 100644 index 0000000000..39870dda94 --- /dev/null +++ b/backend/app/crud/items.py @@ -0,0 +1,189 @@ +import uuid + +from sqlmodel import Session, func, select + +from app.models import Item, ItemCreate, ItemUpdate + + +def create_item(*, session: Session, item_in: ItemCreate, owner_id: uuid.UUID) -> Item: + """ + Create a new item. + + Creates a new Item object and associates it with an owner. + + Args: + session: The database session. + item_in: The ItemCreate object containing item information. + owner_id: The UUID of the user who owns this item. + + Returns: + The newly created Item object. + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + db_item = Item.model_validate(item_in, update={"owner_id": owner_id}) + session.add(db_item) + session.commit() + session.refresh(db_item) + return db_item + + +def get_item(*, session: Session, item_id: uuid.UUID) -> Item | None: + """ + Get an item by ID. + + Retrieves an item from the database by its UUID. + + Args: + session: The database session. + item_id: The UUID of the item to retrieve. + + Returns: + The Item object if found, None otherwise. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + return session.get(Item, item_id) + + +def get_items(*, session: Session, skip: int = 0, limit: int = 100) -> list[Item]: + """ + Get a list of items. + + Retrieves a list of items from the database with pagination. + + Args: + session: The database session. + skip: The number of items to skip (for pagination). + limit: The maximum number of items to return (for pagination). + + Returns: + A list of Item objects. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + statement = select(Item).offset(skip).limit(limit) + return list(session.exec(statement).all()) + + +def get_items_by_owner( + *, session: Session, owner_id: uuid.UUID, skip: int = 0, limit: int = 100 +) -> list[Item]: + """ + Get items by owner. + + Retrieves a list of items from the database for a specific owner with pagination. + + Args: + session: The database session. + owner_id: The UUID of the owner. + skip: The number of items to skip (for pagination). + limit: The maximum number of items to return (for pagination). + + Returns: + A list of Item objects owned by the specified user. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + statement = select(Item).where(Item.owner_id == owner_id).offset(skip).limit(limit) + return list(session.exec(statement).all()) + + +def get_item_count(*, session: Session) -> int: + """ + Get the total number of items. + + Counts the total number of items in the database. + + Args: + session: The database session. + + Returns: + The total number of items as an integer. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + return session.exec(select(func.count()).select_from(Item)).one() + + +def get_item_count_by_owner(*, session: Session, owner_id: uuid.UUID) -> int: + """ + Get the number of items by owner. + + Counts the number of items in the database for a specific owner. + + Args: + session: The database session. + owner_id: The UUID of the owner. + + Returns: + The number of items owned by the specified user as an integer. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + return session.exec( + select(func.count()).where(Item.owner_id == owner_id).select_from(Item) + ).one() + + +def update_item(*, session: Session, db_item: Item, item_in: ItemUpdate) -> Item: + """ + Update an item. + + Updates the attributes of an item in the database. + + Args: + session: The database session. + db_item: The Item object to update. + item_in: The ItemUpdate object containing the new attribute values. + + Returns: + The updated Item object. + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + item_data = item_in.model_dump(exclude_unset=True) + db_item.sqlmodel_update(item_data) + session.add(db_item) + session.commit() + session.refresh(db_item) + return db_item + + +def delete_item(*, session: Session, item_id: uuid.UUID) -> Item | None: + """ + Delete an item. + + Deletes an item from the database by its UUID. + + Args: + session: The database session. + item_id: The UUID of the item to delete. + + Returns: + The deleted Item object if found and deleted, None otherwise. + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + item = session.get(Item, item_id) + if item: + session.delete(item) + session.commit() + return item diff --git a/backend/app/crud/users.py b/backend/app/crud/users.py new file mode 100644 index 0000000000..b87500cab8 --- /dev/null +++ b/backend/app/crud/users.py @@ -0,0 +1,240 @@ +import uuid + +from sqlmodel import Session, func, select + +from app.core.security import get_password_hash, verify_password +from app.models import User, UserCreate, UserUpdate + + +def create_user(*, session: Session, user_create: UserCreate) -> User: + """ + Create a new user. + + Creates a new User object from UserCreate, hashes the password, and adds the user to the database. + + Args: + session: The database session. + user_create: The UserCreate object containing user information. + + Returns: + The newly created User object. + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Create a new User object from UserCreate, hashing the password + db_obj = User.model_validate( + user_create, update={"hashed_password": get_password_hash(user_create.password)} + ) + # Add the new user to the session, commit changes, and refresh the object + session.add(db_obj) + session.commit() + session.refresh(db_obj) + return db_obj + + +def get_user(*, session: Session, user_id: uuid.UUID) -> User | None: + """ + Get a user by ID. + + Retrieves a user from the database by their UUID. + + Args: + session: The database session. + user_id: The UUID of the user to retrieve. + + Returns: + The User object if found, None otherwise. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Retrieve a user from the database by their UUID + return session.get(User, user_id) + + +def get_user_by_email(*, session: Session, email: str) -> User | None: + """ + Get a user by email. + + Retrieves a user from the database by their email address. + + Args: + session: The database session. + email: The email address of the user to retrieve. + + Returns: + The User object if found, None otherwise. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Create a SELECT statement to find a user by email + statement = select(User).where(User.email == email) + # Execute the statement and return the first result (or None if not found) + session_user = session.exec(statement).first() + return session_user + + +def get_users(*, session: Session, skip: int = 0, limit: int = 100) -> list[User]: + """ + Get a list of users. + + Retrieves a list of users from the database with pagination. + + Args: + session: The database session. + skip: The number of users to skip (for pagination). + limit: The maximum number of users to return (for pagination). + + Returns: + A list of User objects. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Create a SELECT statement to retrieve users with pagination + statement = select(User).offset(skip).limit(limit) + # Execute the statement and return the results as a list + return list(session.exec(statement).all()) + + +def get_user_count(*, session: Session) -> int: + """ + Get the total number of users. + + Counts the total number of users in the database. + + Args: + session: The database session. + + Returns: + The total number of users as an integer. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Count the total number of users in the database + return session.exec(select(func.count()).select_from(User)).one() + + +def delete_user(*, session: Session, user_id: uuid.UUID) -> User | None: + """ + Delete a user. + + Deletes a user from the database by their UUID. + + Args: + session: The database session. + user_id: The UUID of the user to delete. + + Returns: + The deleted User object if found and deleted, None otherwise. + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Retrieve the user by ID + user = session.get(User, user_id) + if user: + # If the user exists, delete them from the database + session.delete(user) + session.commit() + return user + + +def authenticate(*, session: Session, email: str, password: str) -> User | None: + """ + Authenticate a user. + + Verifies the user's email and password against the database. + + Args: + session: The database session. + email: The email address of the user. + password: The password to verify. + + Returns: + The authenticated User object if credentials are valid, None otherwise. + + Notes: + This function is not protected and can be called by any unauthenticated user. + """ + # Get the user by email + db_user = get_user_by_email(session=session, email=email) + if not db_user: + return None + # Verify the provided password against the stored hashed password + if not verify_password(password, db_user.hashed_password): + return None + return db_user + + +def update_user_attributes( + *, session: Session, db_user: User, user_in: UserUpdate +) -> User: + """ + Update a user's attributes. + + Updates the attributes of a user in the database. + + Args: + session: The database session. + db_user: The User object to update. + user_in: The UserUpdate object containing the new attribute values. + + Returns: + The updated User object. + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Convert the UserUpdate object to a dictionary, excluding unset fields + user_data = user_in.model_dump(exclude_unset=True) + # Update the user object with the new data + db_user.sqlmodel_update(user_data) + # Save the changes to the database + session.add(db_user) + session.commit() + session.refresh(db_user) + return db_user + + +def update_user_password(*, session: Session, db_user: User, new_password: str) -> User: + """ + Update a user's password. + + Updates the password of a user in the database. + + Args: + session: The database session. + db_user: The User object to update. + new_password: The new password to set. + + Returns: + The updated User object. + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function is not protected and can be called by any authenticated user. + """ + # Hash the new password + hashed_password = get_password_hash(new_password) + # Update the user's hashed password + db_user.hashed_password = hashed_password + # Save the changes to the database + session.add(db_user) + session.commit() + session.refresh(db_user) + return db_user diff --git a/backend/app/initial_data.py b/backend/app/initial_data.py deleted file mode 100644 index d806c3d381..0000000000 --- a/backend/app/initial_data.py +++ /dev/null @@ -1,23 +0,0 @@ -import logging - -from sqlmodel import Session - -from app.core.db import engine, init_db - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def init() -> None: - with Session(engine) as session: - init_db(session) - - -def main() -> None: - logger.info("Creating initial data") - init() - logger.info("Initial data created") - - -if __name__ == "__main__": - main() diff --git a/backend/app/main.py b/backend/app/main.py index 9a95801e74..d2a7c65cf3 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -8,19 +8,38 @@ def custom_generate_unique_id(route: APIRoute) -> str: + """ + Generate unique IDs for API routes. + + Creates a unique identifier for each API route by combining the route's tag and name. + + Args: + route: The APIRoute object for which to generate a unique ID. + + Returns: + A string representing the unique ID for the given route. + + Notes: + This function is used as a custom generator for FastAPI's route IDs. + """ return f"{route.tags[0]}-{route.name}" +# Initialize Sentry for error tracking if DSN is provided and not in local environment if settings.SENTRY_DSN and settings.ENVIRONMENT != "local": sentry_sdk.init(dsn=str(settings.SENTRY_DSN), enable_tracing=True) +# Create FastAPI application instance with custom settings app = FastAPI( title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json", + docs_url=f"{settings.API_V1_STR}/docs", + redoc_url=f"{settings.API_V1_STR}/redoc", + version="v1", generate_unique_id_function=custom_generate_unique_id, ) -# Set all CORS enabled origins +# Set up CORS middleware if origins are specified in settings if settings.all_cors_origins: app.add_middleware( CORSMiddleware, @@ -30,4 +49,5 @@ def custom_generate_unique_id(route: APIRoute) -> str: allow_headers=["*"], ) +# Include the main API router with the specified prefix app.include_router(api_router, prefix=settings.API_V1_STR) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 0000000000..71d684cd4c --- /dev/null +++ b/backend/app/models/__init__.py @@ -0,0 +1,35 @@ +from .item import Item, ItemBase, ItemCreate, ItemPublic, ItemsPublic, ItemUpdate +from .misc import Message, NewPassword, Token, TokenPayload +from .user import ( + UpdatePassword, + User, + UserBase, + UserCreate, + UserPublic, + UserRegister, + UsersPublic, + UserUpdate, + UserUpdateMe, +) + +__all__ = [ + "User", + "UserBase", + "UserCreate", + "UserRegister", + "UserUpdate", + "UserUpdateMe", + "UpdatePassword", + "UserPublic", + "UsersPublic", + "Message", + "Token", + "TokenPayload", + "NewPassword", + "Item", + "ItemBase", + "ItemCreate", + "ItemUpdate", + "ItemPublic", + "ItemsPublic", +] diff --git a/backend/app/models/item.py b/backend/app/models/item.py new file mode 100644 index 0000000000..32768759c8 --- /dev/null +++ b/backend/app/models/item.py @@ -0,0 +1,42 @@ +import uuid + +from sqlmodel import Field, Relationship, SQLModel + +from app.models.user import User + + +# Shared properties +class ItemBase(SQLModel): + title: str = Field(min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=255) + + +# Properties to receive on item creation +class ItemCreate(ItemBase): + pass + + +# Properties to receive on item update +class ItemUpdate(ItemBase): + title: str | None = Field(default=None, min_length=1, max_length=255) # type: ignore + + +# Database model, database table inferred from class name +class Item(ItemBase, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + title: str = Field(max_length=255) + owner_id: uuid.UUID = Field( + foreign_key="user.id", nullable=False, ondelete="CASCADE" + ) + owner: User | None = Relationship(back_populates="items") + + +# Properties to return via API, id is always required +class ItemPublic(ItemBase): + id: uuid.UUID + owner_id: uuid.UUID + + +class ItemsPublic(SQLModel): + data: list[ItemPublic] + count: int diff --git a/backend/app/models/misc.py b/backend/app/models/misc.py new file mode 100644 index 0000000000..356b09f11c --- /dev/null +++ b/backend/app/models/misc.py @@ -0,0 +1,84 @@ +from sqlmodel import Field, SQLModel + + +# Generic message +class Message(SQLModel): + """ + Generic message. + + Defines a simple structure for generic messages. + + Args: + message (str): The content of the message. + + Returns: + None + + Notes: + This class can be used for various messaging purposes throughout the application. + """ + + message: str + + +# JSON payload containing access token +class Token(SQLModel): + """ + JSON payload containing access token. + + Defines the structure for an authentication token response. + + Args: + access_token (str): The access token string. + token_type (str): The type of token, defaults to "bearer". + + Returns: + None + + Notes: + This class is used in the authentication process to return token information. + """ + + access_token: str + token_type: str = "bearer" + + +# Contents of JWT token +class TokenPayload(SQLModel): + """ + Contents of JWT token. + + Defines the structure for the payload of a JWT token. + + Args: + sub (str, optional): The subject of the token, usually the user ID. + + Returns: + None + + Notes: + This class represents the decoded contents of a JWT token. + """ + + sub: str | None = None + + +class NewPassword(SQLModel): + """ + Class for resetting password. + + Defines the structure for a password reset request. + + Args: + token (str): The token for password reset verification. + new_password (str): The new password to set, with length constraints. + + Returns: + None + + Notes: + This class is used in the password reset process. + """ + + token: str + new_password: str = Field(min_length=8, max_length=40) diff --git a/backend/app/models.py b/backend/app/models/user.py similarity index 55% rename from backend/app/models.py rename to backend/app/models/user.py index 90ef5559e3..cbf6adb481 100644 --- a/backend/app/models.py +++ b/backend/app/models/user.py @@ -1,114 +1,116 @@ import uuid +from typing import TYPE_CHECKING from pydantic import EmailStr from sqlmodel import Field, Relationship, SQLModel +if TYPE_CHECKING: + from .item import Item + # Shared properties class UserBase(SQLModel): + """ + Base class for user properties. + """ + + # User's email address, must be unique and indexed email: EmailStr = Field(unique=True, index=True, max_length=255) + # Flag to indicate if the user account is active is_active: bool = True + # Flag to indicate if the user has superuser privileges is_superuser: bool = False + # User's full name, optional full_name: str | None = Field(default=None, max_length=255) # Properties to receive via API on creation class UserCreate(UserBase): + """ + Class for creating a new user. + """ + + # Password field with length constraints password: str = Field(min_length=8, max_length=40) class UserRegister(SQLModel): + """ + Class for user registration. + """ + + # Email field for registration email: EmailStr = Field(max_length=255) + # Password field with length constraints password: str = Field(min_length=8, max_length=40) + # Optional full name field full_name: str | None = Field(default=None, max_length=255) # Properties to receive via API on update, all are optional class UserUpdate(UserBase): + """ + Class for updating user information. + """ + + # Optional email field for updates email: EmailStr | None = Field(default=None, max_length=255) # type: ignore + # Optional password field for updates password: str | None = Field(default=None, min_length=8, max_length=40) class UserUpdateMe(SQLModel): + """ + Class for updating user information. + """ + + # Optional full name field for self-updates full_name: str | None = Field(default=None, max_length=255) + # Optional email field for self-updates email: EmailStr | None = Field(default=None, max_length=255) class UpdatePassword(SQLModel): + """ + Class for updating user password. + """ + + # Current password field current_password: str = Field(min_length=8, max_length=40) + # New password field new_password: str = Field(min_length=8, max_length=40) # Database model, database table inferred from class name class User(UserBase, table=True): + """ + Database model for user. + """ + + # Unique identifier for the user id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + # Hashed password field hashed_password: str + # Relationship to items model items: list["Item"] = Relationship(back_populates="owner", cascade_delete=True) # Properties to return via API, id is always required class UserPublic(UserBase): + """ + Public properties for user. + """ + + # Public user ID id: uuid.UUID class UsersPublic(SQLModel): - data: list[UserPublic] - count: int - - -# Shared properties -class ItemBase(SQLModel): - title: str = Field(min_length=1, max_length=255) - description: str | None = Field(default=None, max_length=255) - - -# Properties to receive on item creation -class ItemCreate(ItemBase): - pass - - -# Properties to receive on item update -class ItemUpdate(ItemBase): - title: str | None = Field(default=None, min_length=1, max_length=255) # type: ignore - + """ + Public properties for users. + """ -# Database model, database table inferred from class name -class Item(ItemBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - title: str = Field(max_length=255) - owner_id: uuid.UUID = Field( - foreign_key="user.id", nullable=False, ondelete="CASCADE" - ) - owner: User | None = Relationship(back_populates="items") - - -# Properties to return via API, id is always required -class ItemPublic(ItemBase): - id: uuid.UUID - owner_id: uuid.UUID - - -class ItemsPublic(SQLModel): - data: list[ItemPublic] + # List of public user data + data: list[UserPublic] + # Total count of users count: int - - -# Generic message -class Message(SQLModel): - message: str - - -# JSON payload containing access token -class Token(SQLModel): - access_token: str - token_type: str = "bearer" - - -# Contents of JWT token -class TokenPayload(SQLModel): - sub: str | None = None - - -class NewPassword(SQLModel): - token: str - new_password: str = Field(min_length=8, max_length=40) diff --git a/backend/app/scripts/__init__.py b/backend/app/scripts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/app/scripts/backend_pre_start.py b/backend/app/scripts/backend_pre_start.py new file mode 100644 index 0000000000..fa3a3d09c9 --- /dev/null +++ b/backend/app/scripts/backend_pre_start.py @@ -0,0 +1,85 @@ +import logging + +from sqlalchemy import Engine +from sqlmodel import Session, select +from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed + +from app.core.db import engine + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Define retry parameters +max_tries = 60 * 5 # 5 minutes +wait_seconds = 1 + + +# Decorator for retrying the database initialization +@retry( + stop=stop_after_attempt(max_tries), # Stop after max_tries attempts + wait=wait_fixed(wait_seconds), # Wait for wait_seconds between attempts + before=before_log(logger, logging.INFO), # Log before each attempt + after=after_log(logger, logging.WARN), # Log after each failed attempt +) +def init(db_engine: Engine) -> None: + """ + Initialize the database + + Attempts to create a database session and execute a simple query to check if the database is awake. + This function is decorated with retry to handle potential connection issues. + + Args: + db_engine: The SQLAlchemy engine instance to connect to the database. + + Returns: + None + + Raises: + Exception: If there's an error during database initialization after all retry attempts. + + Notes: + The function uses a retry decorator to attempt the initialization multiple times + before giving up. It logs the initialization process and any errors that occur. + """ + try: + with Session(db_engine) as session: + # Try to create session to check if DB is awake + session.exec(select(1)) + except Exception as e: + # Log any errors that occur during initialization + logger.error(f"Database initialization error: {e}") + raise + finally: + # Ensure the session is closed and any open transactions are rolled back + session.rollback() + session.close() + + +def main() -> None: + """ + Main function to initialize the service + + Calls the init function to initialize the database and logs the process. + + Args: + None + + Returns: + None + + Raises: + None + + Notes: + This function is the entry point for the service initialization process. + It logs the start and end of the initialization process. + """ + logger.info("Initializing service") + init(engine) # Call the init function with the engine + logger.info("Service finished initializing") + + +# Entry point of the script +if __name__ == "__main__": + main() diff --git a/backend/app/scripts/initial_data.py b/backend/app/scripts/initial_data.py new file mode 100644 index 0000000000..06b8df5adc --- /dev/null +++ b/backend/app/scripts/initial_data.py @@ -0,0 +1,79 @@ +import logging + +from sqlmodel import Session + +from app.core.db import engine, init_db + +# Configure logging to display INFO level messages +logging.basicConfig(level=logging.INFO) +# Create a logger instance for this module +logger = logging.getLogger(__name__) + + +def init() -> None: + """ + Initialize the database with initial data. + + Creates a new database session and calls the init_db function to set up the initial database state. + This function is not protected and can be called directly. + + Args: + None + + Returns: + None + + Raises: + SQLAlchemyError: If there's an error during database operations. + + Notes: + This function uses a context manager to ensure proper handling of the database session. + """ + # Create a new database session + with Session(engine) as session: + # Call the init_db function to set up the initial database state + init_db(session) + + +def main() -> None: + """ + Main function to create initial data. + + Logs the start of the data creation process, calls the init function to initialize the database, + and logs the completion of the process. This function is not protected and can be called directly. + + Args: + None + + Returns: + None + + Notes: + This function uses logging to provide information about the data creation process. + """ + # Log the start of the data creation process + logger.info("Creating initial data") + # Call the init function to initialize the database + init() + # Log the completion of the data creation process + logger.info("Initial data created") + + +if __name__ == "__main__": + """ + Entry point for the script. + + Executes the main function when the script is run directly. + This block is not protected and will run if the script is executed as the main program. + + Args: + None + + Returns: + None + + Notes: + This is a standard Python idiom for scripts that can be both imported and run directly. + """ + # Execute the main function when the script is run directly + main() diff --git a/backend/app/scripts/tests_pre_start.py b/backend/app/scripts/tests_pre_start.py new file mode 100644 index 0000000000..d94b18cbfe --- /dev/null +++ b/backend/app/scripts/tests_pre_start.py @@ -0,0 +1,85 @@ +import logging + +from sqlalchemy import Engine +from sqlmodel import Session, select +from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed + +from app.core.db import engine + +# Configure logging to display INFO level messages +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Define retry parameters +max_tries = 60 * 5 # 5 minutes +wait_seconds = 1 + + +# Decorator for retrying the database initialization +@retry( + stop=stop_after_attempt(max_tries), # Stop after max_tries attempts + wait=wait_fixed(wait_seconds), # Wait for wait_seconds between attempts + before=before_log(logger, logging.INFO), # Log before each attempt + after=after_log(logger, logging.WARN), # Log after each failed attempt +) +def init(db_engine: Engine) -> None: + """ + Initialize the database. + + Attempts to create a session and execute a simple query to check if the database is awake. + This function is not protected and can be called directly. + + Args: + db_engine: The SQLAlchemy engine instance to use for database connection. + + Returns: + None + + Raises: + Exception: If there's an error during database initialization. + + Notes: + This function uses a retry decorator to attempt initialization multiple times. + """ + try: + # Try to create session to check if DB is awake + with Session(db_engine) as session: + session.exec(select(1)) + except Exception as e: + # Log any errors that occur during initialization + logger.error(e) + raise e + finally: + # Ensure the session is closed and any open transactions are rolled back + session.rollback() + session.close() + + +def main() -> None: + """ + Initialize the service. + + Logs the start of service initialization, calls the init function to initialize the database, + and logs the completion of the process. This function is not protected and can be called directly. + + Args: + None + + Returns: + None + + Notes: + This function uses logging to provide information about the initialization process. + """ + # Log the start of service initialization + logger.info("Initializing service") + # Call the init function with the engine to initialize the database + init(engine) + # Log the completion of service initialization + logger.info("Service finished initializing") + + +# Entry point of the script +if __name__ == "__main__": + # Call the main function when the script is run directly + main() diff --git a/backend/app/tests/api/routes/test_login.py b/backend/app/tests/api/routes/test_login.py index 34fe8ee560..6fcc2b2a88 100644 --- a/backend/app/tests/api/routes/test_login.py +++ b/backend/app/tests/api/routes/test_login.py @@ -1,104 +1,634 @@ from unittest.mock import patch from fastapi.testclient import TestClient -from sqlmodel import Session, select +from pytest import mark +from sqlmodel import Session +from app import crud from app.core.config import settings from app.core.security import verify_password -from app.models import User +from app.models import UserCreate +from app.tests.utils.utils import random_email from app.utils import generate_password_reset_token def test_get_access_token(client: TestClient) -> None: + """ + Test the get access token endpoint. + + This function tests the API endpoint for obtaining an access token. + + Args: + client (TestClient): The test client for making HTTP requests. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test prepares login data with superuser credentials, sends a POST request + to the login endpoint, and verifies the response contains a valid access token. + """ + # Prepare login data with superuser credentials login_data = { "username": settings.FIRST_SUPERUSER, "password": settings.FIRST_SUPERUSER_PASSWORD, } + # Send POST request to login endpoint r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) tokens = r.json() + # Assert successful response and presence of access token assert r.status_code == 200 assert "access_token" in tokens assert tokens["access_token"] +def test_login_(client: TestClient) -> None: + """ + Test the login without rate limit. + + This function tests the login endpoint without rate limiting to ensure + the limiter is disabled for testing unless explicitly enabled. + + Args: + client (TestClient): The test client for making HTTP requests. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test prepares login data with superuser credentials, sends a POST request + to the login endpoint, and verifies that the request is not rate limited. + """ + # Prepare login data with superuser credentials + login_data = { + "username": settings.FIRST_SUPERUSER, + "password": settings.FIRST_SUPERUSER_PASSWORD, + } + # Send POST request to login endpoint + r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + # Assert successful response, indicating no rate limiting + assert r.status_code == 200, "Request should not be rate limited" + + def test_get_access_token_incorrect_password(client: TestClient) -> None: + """ + Test the get access token endpoint with an incorrect password. + + This function tests the API's response when attempting to obtain + an access token with an incorrect password. + + Args: + client (TestClient): The test client for making HTTP requests. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test prepares login data with an incorrect password, + sends a POST request to the login endpoint, and verifies + that the API returns a bad request response. + """ + # Prepare login data with incorrect password login_data = { "username": settings.FIRST_SUPERUSER, "password": "incorrect", } + # Send POST request to login endpoint + r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + # Assert bad request response + assert r.status_code == 400 + + +def test_get_access_token_user_not_found(client: TestClient) -> None: + """ + Test the get access token endpoint with a non-existent user. + + This function tests the API's response when attempting to obtain + an access token for a user that does not exist in the system. + + Args: + client (TestClient): The test client for making HTTP requests. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test prepares login data with a non-existent user, + sends a POST request to the login endpoint, and verifies + that the API returns a bad request response with the correct error message. + """ + # Prepare login data with non-existent user + login_data = { + "username": "nonexistent@example.com", + "password": "password", + } + # Send POST request to login endpoint r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + # Assert bad request response and correct error message assert r.status_code == 400 + assert r.json()["detail"] == "Incorrect email or password" + + +def test_get_access_token_inactive_user(client: TestClient, db: Session) -> None: + """ + Test the get access token endpoint with an inactive user. + + This function tests the API's response when attempting to obtain + an access token for an inactive user. + + Args: + client (TestClient): The test client for making HTTP requests. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test creates an inactive user, prepares login data for this user, + sends a POST request to the login endpoint, and verifies that the API + returns a bad request response with the correct error message. + """ + # Create an inactive user + user_in = UserCreate( + email=random_email(), + password="password1234!", + is_active=False, + ) + user = crud.create_user(session=db, user_create=user_in) + # Prepare login data for inactive user + login_data = { + "username": user.email, + "password": "password1234!", + } + # Send POST request to login endpoint + r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + # Assert bad request response and correct error message + assert r.status_code == 400 + assert r.json()["detail"] == "Inactive user" def test_use_access_token( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test the use access token endpoint. + + This function tests the API endpoint for using an access token. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test sends a POST request to the test token endpoint with a superuser token, + and verifies that the response is successful and contains the expected data. + """ + # Send POST request to test token endpoint with superuser token r = client.post( f"{settings.API_V1_STR}/login/test-token", headers=superuser_token_headers, ) result = r.json() + # Assert successful response and presence of email in result assert r.status_code == 200 assert "email" in result +@mark.order("last") def test_recovery_password( client: TestClient, normal_user_token_headers: dict[str, str] ) -> None: + """ + Test the recovery password endpoint. + + This function tests the API endpoint for password recovery. + + Args: + client (TestClient): The test client for making HTTP requests. + normal_user_token_headers (dict[str, str]): Headers containing a normal user's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + """ + with ( patch("app.core.config.settings.SMTP_HOST", "smtp.example.com"), patch("app.core.config.settings.SMTP_USER", "admin@example.com"), ): - email = "test@example.com" + email = settings.FIRST_SUPERUSER + # Send POST request to password recovery endpoint r = client.post( f"{settings.API_V1_STR}/password-recovery/{email}", headers=normal_user_token_headers, ) + # Assert successful response and correct message assert r.status_code == 200 - assert r.json() == {"message": "Password recovery email sent"} + assert r.json() == { + "message": "Password recovery email sent if the account exists." + } +@mark.order("last") def test_recovery_password_user_not_exits( client: TestClient, normal_user_token_headers: dict[str, str] ) -> None: - email = "jVgQr@example.com" + """ + Test the recovery password endpoint with a non-existent user. + + This function tests the API's response when attempting to recover + the password for a user that does not exist in the system. + + Args: + client (TestClient): The test client for making HTTP requests. + normal_user_token_headers (dict[str, str]): Headers containing a normal user's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test sends a POST request to the password recovery endpoint with a non-existent email, + and verifies that the API returns a successful response with the expected message, + maintaining user privacy by not disclosing whether the account exists. + """ + # Prepare non-existent email + email = "this_email_does_not_exist@example.com" + # Send POST request to password recovery endpoint r = client.post( f"{settings.API_V1_STR}/password-recovery/{email}", headers=normal_user_token_headers, ) - assert r.status_code == 404 + # Assert successful response and correct message + assert r.status_code == 200 + assert r.json() == { + "message": "Password recovery email sent if the account exists." + } +@mark.order("last") def test_reset_password( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test the reset password endpoint. + + This function tests the API endpoint for resetting a user's password. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test generates a password reset token, sends a POST request to reset the password, + verifies the response, checks the password change in the database, and then reverts + the password back to the original for other tests to run correctly. + """ + # Generate password reset token for superuser token = generate_password_reset_token(email=settings.FIRST_SUPERUSER) + # Prepare reset password data data = {"new_password": "changethis", "token": token} + # Send POST request to reset password endpoint r = client.post( f"{settings.API_V1_STR}/reset-password/", headers=superuser_token_headers, json=data, ) + # Assert successful response and correct message assert r.status_code == 200 assert r.json() == {"message": "Password updated successfully"} - - user_query = select(User).where(User.email == settings.FIRST_SUPERUSER) - user = db.exec(user_query).first() + # Verify password change in database + user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) assert user assert verify_password(data["new_password"], user.hashed_password) + # Revert the password back to FIRST_SUPERUSER_PASSWORD so the other tests can continue to run + revert_token = generate_password_reset_token(email=settings.FIRST_SUPERUSER) + revert_data = { + "new_password": settings.FIRST_SUPERUSER_PASSWORD, + "token": revert_token, + } + revert_r = client.post( + f"{settings.API_V1_STR}/reset-password/", + headers=superuser_token_headers, + json=revert_data, + ) + # Assert successful revert + assert revert_r.status_code == 200 + assert revert_r.json() == {"message": "Password updated successfully"} +@mark.order("last") def test_reset_password_invalid_token( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test the reset password endpoint with an invalid token. + + This function tests the API's response when attempting to reset + a password with an invalid token. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test prepares reset password data with an invalid token, + sends a POST request to the reset password endpoint, and verifies + that the API returns a bad request response with the correct error message. + """ + # Prepare reset password data with invalid token data = {"new_password": "changethis", "token": "invalid"} + # Send POST request to reset password endpoint r = client.post( f"{settings.API_V1_STR}/reset-password/", headers=superuser_token_headers, json=data, ) response = r.json() - + # Assert bad request response and correct error message assert "detail" in response assert r.status_code == 400 assert response["detail"] == "Invalid token" + + +@mark.order("last") +def test_reset_password_user_not_found( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + """ + Test the reset password endpoint with a non-existent user. + + This function tests the API's response when attempting to reset + the password for a user that does not exist in the system. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test generates a token for a non-existent user, sends a POST request + to the reset password endpoint, and verifies that the API returns a not found + response with the correct error message. + """ + # Generate token for non-existent user + token = generate_password_reset_token(email="this_email_does_not_exist@example.com") + data = {"new_password": "changethis", "token": token} + # Send POST request to reset password endpoint + r = client.post( + f"{settings.API_V1_STR}/reset-password/", + headers=superuser_token_headers, + json=data, + ) + # Assert not found response and correct error message + assert r.status_code == 404 + assert ( + r.json()["detail"] == "The user with this email does not exist in the system." + ) + + +@mark.order("last") +def test_reset_password_expired_token( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + """ + Test the reset password endpoint with an expired token. + + This function tests the API's response when attempting to reset + a password with an expired token. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test mocks the verify_password_reset_token function to simulate an expired token, + sends a POST request to the reset password endpoint, and verifies that the API + returns a bad request response with the correct error message. + """ + # Mock verify_password_reset_token to return None (simulating expired token) + with patch("app.utils.verify_password_reset_token", return_value=None): + data = {"new_password": "changethis", "token": "expired_token"} + # Send POST request to reset password endpoint + r = client.post( + f"{settings.API_V1_STR}/reset-password/", + headers=superuser_token_headers, + json=data, + ) + # Assert bad request response and correct error message + assert r.status_code == 400 + assert r.json()["detail"] == "Invalid token" + + +@mark.order("last") +def test_reset_password_inactive_user( + client: TestClient, superuser_token_headers: dict[str, str], db: Session +) -> None: + """ + Test the reset password endpoint with an inactive user. + + This function tests the API's response when attempting to reset + the password for an inactive user. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test creates an inactive user, generates a token for this user, + sends a POST request to the reset password endpoint, and verifies that + the API returns a bad request response with the correct error message. + """ + # Create an inactive user + user_in = UserCreate( + email=random_email(), + password="password1234!", + is_active=False, + ) + user = crud.create_user(session=db, user_create=user_in) + # Generate token for inactive user + token = generate_password_reset_token(email=user.email) + data = {"new_password": user_in.password, "token": token} + # Send POST request to reset password endpoint + r = client.post( + f"{settings.API_V1_STR}/reset-password/", + headers=superuser_token_headers, + json=data, + ) + # Assert bad request response and correct error message + assert r.status_code == 400 + assert r.json()["detail"] == "Inactive user" + + +@mark.order("last") +def test_recover_password_html_content( + client: TestClient, superuser_token_headers: dict[str, str], db: Session +) -> None: + """ + Test the recover password html content endpoint. + + This function tests the API endpoint for retrieving the HTML content + for password recovery. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test creates a test user, sends a POST request to the password recovery + HTML content endpoint, and verifies that the response is successful and + contains the expected content type and headers. + """ + # Create a test user + user_in = UserCreate( + email=random_email(), + password="password1234!", + is_active=True, + ) + user = crud.create_user(session=db, user_create=user_in) + # Send POST request to password recovery HTML content endpoint + response = client.post( + f"{settings.API_V1_STR}/password-recovery-html-content/{user.email}", + headers=superuser_token_headers, + ) + # Assert successful response and correct content type + assert response.status_code == 200 + assert response.headers.get("content-type") == "text/html; charset=utf-8" + assert "subject:" in response.headers + # assert "Reset your password" in response.text + + +@mark.order("last") +def test_recover_password_html_content_user_not_found( + client: TestClient, superuser_token_headers: dict[str, str] +) -> None: + """ + Test the recover password html content endpoint with a non-existent user. + + This function tests the API's response when attempting to retrieve + the password recovery HTML content for a user that does not exist in the system. + + Args: + client (TestClient): The test client for making HTTP requests. + superuser_token_headers (dict[str, str]): Headers containing the superuser's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test sends a POST request to the password recovery HTML content endpoint + with a non-existent email, and verifies that the API returns a not found + response with the correct error message. + """ + # Send POST request to password recovery HTML content endpoint with non-existent email + response = client.post( + f"{settings.API_V1_STR}/password-recovery-html-content/this_email_does_not_exist@example.com", + headers=superuser_token_headers, + ) + # Assert not found response and correct error message + assert response.status_code == 404 + assert ( + response.json()["detail"] + == "The user with this username does not exist in the system." + ) + + +@mark.order("last") +def test_recover_password_html_content_not_superuser( + client: TestClient, normal_user_token_headers: dict[str, str] +) -> None: + """ + Test the recover password html content endpoint with a normal user. + + This function tests the API's response when a normal user attempts to + retrieve the password recovery HTML content. + + Args: + client (TestClient): The test client for making HTTP requests. + normal_user_token_headers (dict[str, str]): Headers containing a normal user's authentication token. + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail. + + Notes: + This test sends a POST request to the password recovery HTML content endpoint + with a normal user's token, and verifies that the API returns a forbidden + response with the correct error message. + """ + # Send POST request to password recovery HTML content endpoint with normal user token + response = client.post( + f"{settings.API_V1_STR}/password-recovery-html-content/{random_email()}", + headers=normal_user_token_headers, + ) + # Assert forbidden response and correct error message + assert response.status_code == 403 + assert response.json()["detail"] == "The user doesn't have enough privileges" diff --git a/backend/app/tests/api/routes/test_users.py b/backend/app/tests/api/routes/test_users.py index ba9be65426..258d1cbaca 100644 --- a/backend/app/tests/api/routes/test_users.py +++ b/backend/app/tests/api/routes/test_users.py @@ -1,5 +1,4 @@ import uuid -from unittest.mock import patch from fastapi.testclient import TestClient from sqlmodel import Session, select @@ -14,8 +13,24 @@ def test_get_users_superuser_me( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test getting the current superuser. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Send a GET request to retrieve the current superuser r = client.get(f"{settings.API_V1_STR}/users/me", headers=superuser_token_headers) current_user = r.json() + + # Assert that the user exists and has the correct attributes assert current_user assert current_user["is_active"] is True assert current_user["is_superuser"] @@ -25,8 +40,24 @@ def test_get_users_superuser_me( def test_get_users_normal_user_me( client: TestClient, normal_user_token_headers: dict[str, str] ) -> None: + """ + Test getting the current normal user. + + Args: + client (TestClient): The test client. + normal_user_token_headers (dict[str, str]): The normal user token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Send a GET request to retrieve the current normal user r = client.get(f"{settings.API_V1_STR}/users/me", headers=normal_user_token_headers) current_user = r.json() + + # Assert that the user exists and has the correct attributes assert current_user assert current_user["is_active"] is True assert current_user["is_superuser"] is False @@ -36,52 +67,102 @@ def test_get_users_normal_user_me( def test_create_user_new_email( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: - with ( - patch("app.utils.send_email", return_value=None), - patch("app.core.config.settings.SMTP_HOST", "smtp.example.com"), - patch("app.core.config.settings.SMTP_USER", "admin@example.com"), - ): - username = random_email() - password = random_lower_string() - data = {"email": username, "password": password} - r = client.post( - f"{settings.API_V1_STR}/users/", - headers=superuser_token_headers, - json=data, - ) - assert 200 <= r.status_code < 300 - created_user = r.json() - user = crud.get_user_by_email(session=db, email=username) - assert user - assert user.email == created_user["email"] + """ + Test creating a new user with a new email. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Generate random email and password for the new user + username = random_email() + password = random_lower_string() + data = {"email": username, "password": password} + # Send a POST request to create a new user + r = client.post( + f"{settings.API_V1_STR}/users/", + headers=superuser_token_headers, + json=data, + ) + # Assert that the request was successful + assert 200 <= r.status_code < 300 + created_user = r.json() + # Retrieve the user from the database and assert its correctness + user = crud.get_user_by_email(session=db, email=username) + assert user + assert user.email == created_user["email"] + assert created_user["email"] == username def test_get_existing_user( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test getting an existing user. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create a new user username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) user = crud.create_user(session=db, user_create=user_in) user_id = user.id + + # Send a GET request to retrieve the created user r = client.get( f"{settings.API_V1_STR}/users/{user_id}", headers=superuser_token_headers, ) + + # Assert that the request was successful assert 200 <= r.status_code < 300 api_user = r.json() + + # Retrieve the user from the database and assert its correctness existing_user = crud.get_user_by_email(session=db, email=username) assert existing_user assert existing_user.email == api_user["email"] def test_get_existing_user_current_user(client: TestClient, db: Session) -> None: + """ + Test getting the current user as an existing user. + + Args: + client (TestClient): The test client. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create a new user username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) user = crud.create_user(session=db, user_create=user_in) user_id = user.id + # Log in as the created user login_data = { "username": username, "password": password, @@ -91,12 +172,17 @@ def test_get_existing_user_current_user(client: TestClient, db: Session) -> None a_token = tokens["access_token"] headers = {"Authorization": f"Bearer {a_token}"} + # Send a GET request to retrieve the current user r = client.get( f"{settings.API_V1_STR}/users/{user_id}", headers=headers, ) + + # Assert that the request was successful assert 200 <= r.status_code < 300 api_user = r.json() + + # Retrieve the user from the database and assert its correctness existing_user = crud.get_user_by_email(session=db, email=username) assert existing_user assert existing_user.email == api_user["email"] @@ -105,10 +191,26 @@ def test_get_existing_user_current_user(client: TestClient, db: Session) -> None def test_get_existing_user_permissions_error( client: TestClient, normal_user_token_headers: dict[str, str] ) -> None: + """ + Test getting an existing user with insufficient permissions. + + Args: + client (TestClient): The test client. + normal_user_token_headers (dict[str, str]): The normal user token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to retrieve a user with insufficient privileges r = client.get( f"{settings.API_V1_STR}/users/{uuid.uuid4()}", headers=normal_user_token_headers, ) + + # Assert that the request was forbidden assert r.status_code == 403 assert r.json() == {"detail": "The user doesn't have enough privileges"} @@ -116,11 +218,27 @@ def test_get_existing_user_permissions_error( def test_create_user_existing_username( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test creating a user with an existing username. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create a new user username = random_email() - # username = email password = random_lower_string() user_in = UserCreate(email=username, password=password) crud.create_user(session=db, user_create=user_in) + + # Attempt to create a user with the same email data = {"email": username, "password": password} r = client.post( f"{settings.API_V1_STR}/users/", @@ -128,6 +246,8 @@ def test_create_user_existing_username( json=data, ) created_user = r.json() + + # Assert that the request was unsuccessful assert r.status_code == 400 assert "_id" not in created_user @@ -135,6 +255,20 @@ def test_create_user_existing_username( def test_create_user_by_normal_user( client: TestClient, normal_user_token_headers: dict[str, str] ) -> None: + """ + Test creating a user by a normal user. + + Args: + client (TestClient): The test client. + normal_user_token_headers (dict[str, str]): The normal user token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to create a new user with normal user privileges username = random_email() password = random_lower_string() data = {"email": username, "password": password} @@ -143,12 +277,29 @@ def test_create_user_by_normal_user( headers=normal_user_token_headers, json=data, ) + + # Assert that the request was forbidden assert r.status_code == 403 def test_retrieve_users( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test retrieving all users. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create two new users username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) @@ -159,18 +310,73 @@ def test_retrieve_users( user_in2 = UserCreate(email=username2, password=password2) crud.create_user(session=db, user_create=user_in2) + # Retrieve all users r = client.get(f"{settings.API_V1_STR}/users/", headers=superuser_token_headers) all_users = r.json() + # Assert that multiple users were retrieved assert len(all_users["data"]) > 1 assert "count" in all_users for item in all_users["data"]: assert "email" in item +def test_read_users_pagination( + client: TestClient, superuser_token_headers: dict[str, str], db: Session +) -> None: + """ + Test reading users with pagination. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create multiple users + for _ in range(5): + crud.create_user( + session=db, + user_create=UserCreate( + email=random_email(), password=random_lower_string() + ), + ) + + # Retrieve users with pagination + r = client.get( + f"{settings.API_V1_STR}/users/?skip=1&limit=2", headers=superuser_token_headers + ) + all_users = r.json() + + # Assert that the pagination works correctly + assert r.status_code == 200 + assert len(all_users["data"]) == 2 + assert all_users["count"] > 2 + + def test_update_user_me( client: TestClient, normal_user_token_headers: dict[str, str], db: Session ) -> None: + """ + Test updating the current user. + + Args: + client (TestClient): The test client. + normal_user_token_headers (dict[str, str]): The normal user token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Update the current user's full name and email full_name = "Updated Name" email = random_email() data = {"full_name": full_name, "email": email} @@ -179,11 +385,14 @@ def test_update_user_me( headers=normal_user_token_headers, json=data, ) + + # Assert that the update was successful assert r.status_code == 200 updated_user = r.json() assert updated_user["email"] == email assert updated_user["full_name"] == full_name + # Verify the update in the database user_query = select(User).where(User.email == email) user_db = db.exec(user_query).first() assert user_db @@ -194,6 +403,21 @@ def test_update_user_me( def test_update_password_me( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test updating the current user's password. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Update the current user's password new_password = random_lower_string() data = { "current_password": settings.FIRST_SUPERUSER_PASSWORD, @@ -204,10 +428,13 @@ def test_update_password_me( headers=superuser_token_headers, json=data, ) + + # Assert that the update was successful assert r.status_code == 200 updated_user = r.json() assert updated_user["message"] == "Password updated successfully" + # Verify the password update in the database user_query = select(User).where(User.email == settings.FIRST_SUPERUSER) user_db = db.exec(user_query).first() assert user_db @@ -233,6 +460,20 @@ def test_update_password_me( def test_update_password_me_incorrect_password( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test updating the current user's password with an incorrect current password. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to update password with incorrect current password new_password = random_lower_string() data = {"current_password": new_password, "new_password": new_password} r = client.patch( @@ -240,6 +481,8 @@ def test_update_password_me_incorrect_password( headers=superuser_token_headers, json=data, ) + + # Assert that the update was unsuccessful assert r.status_code == 400 updated_user = r.json() assert updated_user["detail"] == "Incorrect password" @@ -248,17 +491,35 @@ def test_update_password_me_incorrect_password( def test_update_user_me_email_exists( client: TestClient, normal_user_token_headers: dict[str, str], db: Session ) -> None: + """ + Test updating the current user's email to an existing email. + + Args: + client (TestClient): The test client. + normal_user_token_headers (dict[str, str]): The normal user token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create a new user username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) user = crud.create_user(session=db, user_create=user_in) + # Attempt to update current user's email to an existing email data = {"email": user.email} r = client.patch( f"{settings.API_V1_STR}/users/me", headers=normal_user_token_headers, json=data, ) + + # Assert that the update was unsuccessful assert r.status_code == 409 assert r.json()["detail"] == "User with this email already exists" @@ -266,6 +527,20 @@ def test_update_user_me_email_exists( def test_update_password_me_same_password_error( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test updating the current user's password to the same password. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to update password to the same password data = { "current_password": settings.FIRST_SUPERUSER_PASSWORD, "new_password": settings.FIRST_SUPERUSER_PASSWORD, @@ -275,6 +550,8 @@ def test_update_password_me_same_password_error( headers=superuser_token_headers, json=data, ) + + # Assert that the update was unsuccessful assert r.status_code == 400 updated_user = r.json() assert ( @@ -283,6 +560,20 @@ def test_update_password_me_same_password_error( def test_register_user(client: TestClient, db: Session) -> None: + """ + Test registering a new user. + + Args: + client (TestClient): The test client. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Register a new user username = random_email() password = random_lower_string() full_name = random_lower_string() @@ -291,11 +582,14 @@ def test_register_user(client: TestClient, db: Session) -> None: f"{settings.API_V1_STR}/users/signup", json=data, ) + + # Assert that the registration was successful assert r.status_code == 200 created_user = r.json() assert created_user["email"] == username assert created_user["full_name"] == full_name + # Verify the user in the database user_query = select(User).where(User.email == username) user_db = db.exec(user_query).first() assert user_db @@ -305,6 +599,19 @@ def test_register_user(client: TestClient, db: Session) -> None: def test_register_user_already_exists_error(client: TestClient) -> None: + """ + Test registering a user with an existing email. + + Args: + client (TestClient): The test client. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to register a user with an existing email password = random_lower_string() full_name = random_lower_string() data = { @@ -316,6 +623,8 @@ def test_register_user_already_exists_error(client: TestClient) -> None: f"{settings.API_V1_STR}/users/signup", json=data, ) + + # Assert that the registration was unsuccessful assert r.status_code == 400 assert r.json()["detail"] == "The user with this email already exists in the system" @@ -323,22 +632,41 @@ def test_register_user_already_exists_error(client: TestClient) -> None: def test_update_user( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test updating a user. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create a new user username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) user = crud.create_user(session=db, user_create=user_in) + # Update the user's full name data = {"full_name": "Updated_full_name"} r = client.patch( f"{settings.API_V1_STR}/users/{user.id}", headers=superuser_token_headers, json=data, ) + + # Assert that the update was successful assert r.status_code == 200 updated_user = r.json() assert updated_user["full_name"] == "Updated_full_name" + # Verify the update in the database user_query = select(User).where(User.email == username) user_db = db.exec(user_query).first() db.refresh(user_db) @@ -346,15 +674,31 @@ def test_update_user( assert user_db.full_name == "Updated_full_name" -def test_update_user_not_exists( +def test_update_user_non_existent( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test updating a non-existent user. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to update a non-existent user data = {"full_name": "Updated_full_name"} r = client.patch( f"{settings.API_V1_STR}/users/{uuid.uuid4()}", headers=superuser_token_headers, json=data, ) + + # Assert that the update was unsuccessful assert r.status_code == 404 assert r.json()["detail"] == "The user with this id does not exist in the system" @@ -362,6 +706,21 @@ def test_update_user_not_exists( def test_update_user_email_exists( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test updating a user's email to an existing email. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create two users username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) @@ -372,23 +731,41 @@ def test_update_user_email_exists( user_in2 = UserCreate(email=username2, password=password2) user2 = crud.create_user(session=db, user_create=user_in2) + # Attempt to update the first user's email to the second user's email data = {"email": user2.email} r = client.patch( f"{settings.API_V1_STR}/users/{user.id}", headers=superuser_token_headers, json=data, ) + + # Assert that the update was unsuccessful assert r.status_code == 409 assert r.json()["detail"] == "User with this email already exists" def test_delete_user_me(client: TestClient, db: Session) -> None: + """ + Test deleting the current user. + + Args: + client (TestClient): The test client. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create a new user username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) user = crud.create_user(session=db, user_create=user_in) user_id = user.id + # Log in as the created user login_data = { "username": username, "password": password, @@ -398,13 +775,18 @@ def test_delete_user_me(client: TestClient, db: Session) -> None: a_token = tokens["access_token"] headers = {"Authorization": f"Bearer {a_token}"} + # Delete the current user r = client.delete( f"{settings.API_V1_STR}/users/me", headers=headers, ) + + # Assert that the deletion was successful assert r.status_code == 200 deleted_user = r.json() assert deleted_user["message"] == "User deleted successfully" + + # Verify that the user no longer exists in the database result = db.exec(select(User).where(User.id == user_id)).first() assert result is None @@ -416,10 +798,26 @@ def test_delete_user_me(client: TestClient, db: Session) -> None: def test_delete_user_me_as_superuser( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test deleting the current user as a superuser. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to delete the superuser account r = client.delete( f"{settings.API_V1_STR}/users/me", headers=superuser_token_headers, ) + + # Assert that the deletion was unsuccessful assert r.status_code == 403 response = r.json() assert response["detail"] == "Super users are not allowed to delete themselves" @@ -428,29 +826,86 @@ def test_delete_user_me_as_superuser( def test_delete_user_super_user( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test deleting a user as a superuser. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Create a new user username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) user = crud.create_user(session=db, user_create=user_in) user_id = user.id + + # Delete the user as a superuser r = client.delete( f"{settings.API_V1_STR}/users/{user_id}", headers=superuser_token_headers, ) + + # Assert that the deletion was successful assert r.status_code == 200 deleted_user = r.json() assert deleted_user["message"] == "User deleted successfully" + + # Verify that the user no longer exists in the database result = db.exec(select(User).where(User.id == user_id)).first() assert result is None +def test_delete_user_self( + client: TestClient, superuser_token_headers: dict[str, str], db: Session +) -> None: + """ + Test deleting the current user as a superuser. + + Args: + client (TestClient): The test client. + superuser_token_headers (dict[str, str]): The superuser token headers. + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + """ + # Attempt to delete the superuser account + super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) + assert super_user + r = client.delete( + f"{settings.API_V1_STR}/users/{super_user.id}", + headers=superuser_token_headers, + ) + + # Assert that the deletion was unsuccessful + assert r.status_code == 403 + assert r.json()["detail"] == "Super users are not allowed to delete themselves" + + def test_delete_user_not_found( client: TestClient, superuser_token_headers: dict[str, str] ) -> None: + """ + Test that the delete_user_not_found function raises an HTTPException when the user does not exist. + """ + # Attempt to delete a non-existent user r = client.delete( f"{settings.API_V1_STR}/users/{uuid.uuid4()}", headers=superuser_token_headers, ) + + # Assert that the deletion was unsuccessful assert r.status_code == 404 assert r.json()["detail"] == "User not found" @@ -458,6 +913,10 @@ def test_delete_user_not_found( def test_delete_user_current_super_user_error( client: TestClient, superuser_token_headers: dict[str, str], db: Session ) -> None: + """ + Test that the delete_user_current_super_user_error function raises an HTTPException when the current superuser tries to delete another superuser. + """ + # Attempt to delete the superuser account super_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) assert super_user user_id = super_user.id @@ -466,6 +925,8 @@ def test_delete_user_current_super_user_error( f"{settings.API_V1_STR}/users/{user_id}", headers=superuser_token_headers, ) + + # Assert that the deletion was unsuccessful assert r.status_code == 403 assert r.json()["detail"] == "Super users are not allowed to delete themselves" @@ -473,14 +934,21 @@ def test_delete_user_current_super_user_error( def test_delete_user_without_privileges( client: TestClient, normal_user_token_headers: dict[str, str], db: Session ) -> None: + """ + Test that the delete_user_without_privileges function raises an HTTPException when the user does not have enough privileges. + """ + # Create a new user username = random_email() password = random_lower_string() user_in = UserCreate(email=username, password=password) user = crud.create_user(session=db, user_create=user_in) + # Attempt to delete the user as a normal user r = client.delete( f"{settings.API_V1_STR}/users/{user.id}", headers=normal_user_token_headers, ) + + # Assert that the deletion was unsuccessful assert r.status_code == 403 assert r.json()["detail"] == "The user doesn't have enough privileges" diff --git a/backend/app/tests/api/test_api_deps.py b/backend/app/tests/api/test_api_deps.py new file mode 100644 index 0000000000..a7a02b1451 --- /dev/null +++ b/backend/app/tests/api/test_api_deps.py @@ -0,0 +1,265 @@ +from unittest.mock import MagicMock, patch + +import jwt +import pytest +from fastapi import HTTPException +from sqlmodel import Session, select, text + +from app import crud +from app.api import deps +from app.models import TokenPayload, UserCreate +from app.tests.utils.utils import random_email, random_lower_string + + +def test_get_db() -> None: + """ + Test database session retrieval. + + This function tests that the get_db function successfully retrieves a database session. + The function is not protected and should be accessible without authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test creates a database session, performs a test operation, + and ensures the session is properly closed and disposed. + """ + # Get a database session using the get_db function + db = next(deps.get_db()) + # Assert that the returned object is an instance of Session + assert isinstance(db, Session) + try: + # Perform a test operation to ensure the session is working + db.exec(select(text("1"))) + # Rollback any changes to keep the database clean + db.rollback() + finally: + # Ensure the session is closed and connections are returned to the pool + db.close() + + # Flush the connection pool to release all connections + deps.engine.dispose() # type: ignore[attr-defined] + + +def test_get_current_user_valid_token(db: Session) -> None: + """ + Test current user retrieval with valid token. + + This function tests that the get_current_user function successfully retrieves the current user + when provided with a valid token. The function is protected and requires authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test creates a random user, mocks a valid token, and verifies + that the correct user is retrieved. + """ + # Create a random user + email = random_email() + password = random_lower_string() + user = crud.create_user( + session=db, user_create=UserCreate(email=email, password=password) + ) + + # Create a token payload with the user's ID + token_payload = TokenPayload(sub=str(user.id)) + token = MagicMock() + + # Mock the jwt.decode function to return our token payload + with patch("jwt.decode", return_value=token_payload.model_dump()): + # Call get_current_user and check if it returns the correct user + current_user = deps.get_current_user(db, token) + assert current_user.id == user.id + assert current_user.email == email + + +def test_get_current_user_invalid_token(db: Session) -> None: + """ + Test current user retrieval with invalid token. + + This function tests that the get_current_user function raises an HTTPException + when provided with an invalid token. The function is protected and requires authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test mocks an invalid token and verifies that the correct exception is raised. + """ + token = MagicMock() + + # Mock jwt.decode to raise an InvalidTokenError + with patch("jwt.decode", side_effect=jwt.exceptions.InvalidTokenError): + # Check if get_current_user raises the correct HTTPException + with pytest.raises(HTTPException) as exc_info: + deps.get_current_user(db, token) + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "Could not validate credentials" + + +def test_get_current_user_user_not_found(db: Session) -> None: + """ + Test current user retrieval with non-existent user. + + This function tests that the get_current_user function raises an HTTPException + when the user associated with the token is not found in the database. + The function is protected and requires authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test mocks a token with a non-existent user ID and verifies + that the correct exception is raised. + """ + # Create a token payload with a non-existent user ID + token_payload = TokenPayload(sub="00000000-0000-0000-0000-000000000000") + token = MagicMock() + + # Mock jwt.decode to return our token payload + with patch("jwt.decode", return_value=token_payload.model_dump()): + # Check if get_current_user raises the correct HTTPException + with pytest.raises(HTTPException) as exc_info: + deps.get_current_user(db, token) + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == "User not found" + + +def test_get_current_user_inactive_user(db: Session) -> None: + """ + Test current user retrieval with inactive user. + + This function tests that the get_current_user function raises an HTTPException + when the user associated with the token is inactive. + The function is protected and requires authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test creates an inactive user, mocks a token for that user, + and verifies that the correct exception is raised. + """ + # Create an inactive user + email = random_email() + password = random_lower_string() + user = crud.create_user( + session=db, + user_create=UserCreate(email=email, password=password, is_active=False), + ) + + # Create a token payload with the inactive user's ID + token_payload = TokenPayload(sub=str(user.id)) + token = MagicMock() + + # Mock jwt.decode to return our token payload + with patch("jwt.decode", return_value=token_payload.model_dump()): + # Check if get_current_user raises the correct HTTPException + with pytest.raises(HTTPException) as exc_info: + deps.get_current_user(db, token) + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Inactive user" + + +def test_get_current_active_superuser(db: Session) -> None: + """ + Test current active superuser retrieval. + + This function tests that the get_current_active_superuser function + successfully retrieves the current superuser. + The function is protected and requires superuser authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test creates a superuser and verifies that the function + correctly identifies and returns the superuser. + """ + # Create a superuser + email = random_email() + password = random_lower_string() + superuser = crud.create_user( + session=db, + user_create=UserCreate(email=email, password=password, is_superuser=True), + ) + + # Call get_current_active_superuser and check if it returns the correct superuser + current_user = deps.get_current_active_superuser(superuser) + assert current_user.id == superuser.id + assert current_user.is_superuser + + +def test_get_current_active_superuser_not_superuser(db: Session) -> None: + """ + Test current active superuser retrieval with non-superuser. + + This function tests that the get_current_active_superuser function + raises an HTTPException when the user is not a superuser. + The function is protected and requires superuser authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test creates a regular user (not a superuser) and verifies + that the correct exception is raised when trying to access + superuser-only functionality. + """ + # Create a regular user (not a superuser) + email = random_email() + password = random_lower_string() + user = crud.create_user( + session=db, + user_create=UserCreate(email=email, password=password, is_superuser=False), + ) + + # Check if get_current_active_superuser raises the correct HTTPException + with pytest.raises(HTTPException) as exc_info: + deps.get_current_active_superuser(user) + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "The user doesn't have enough privileges" diff --git a/backend/app/tests/core/test_core_config.py b/backend/app/tests/core/test_core_config.py new file mode 100644 index 0000000000..691f94a453 --- /dev/null +++ b/backend/app/tests/core/test_core_config.py @@ -0,0 +1,352 @@ +import pytest +from pydantic import ValidationError + +from app.core.config import Settings, parse_cors + + +def test_parse_cors() -> None: + """ + Test CORS parsing functionality. + + This function tests that the parse_cors function successfully parses a string of origins. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + ValueError: If an invalid input is provided to parse_cors. + + Notes: + This test covers parsing of comma-separated strings, lists, and single origins. + It also checks for proper error handling with invalid input. + """ + # Test parsing a comma-separated string of origins + assert parse_cors("http://localhost,https://example.com") == [ + "http://localhost", + "https://example.com", + ] + + # Test parsing a list of origins + assert parse_cors(["http://localhost", "https://example.com"]) == [ + "http://localhost", + "https://example.com", + ] + + # Test parsing a single origin + assert parse_cors("http://localhost") == ["http://localhost"] + + # Test that an invalid input raises a ValueError + with pytest.raises(ValueError): + parse_cors(123) + + +def test_settings_default_values() -> None: + """ + Test default Settings initialization. + + This function tests that the Settings class successfully initializes with default values. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test creates a Settings instance with minimal required parameters and checks + that default values are set correctly for various configuration options. + """ + # Create a Settings instance with minimal required parameters + settings = Settings( + PROJECT_NAME="Test Project", + POSTGRES_SERVER="localhost", + POSTGRES_USER="postgres", + FIRST_SUPERUSER="admin@example.com", + FIRST_SUPERUSER_PASSWORD="password123", + ) + + # Assert that default values are set correctly + assert settings.API_V1_STR == "/api/v1" + assert len(settings.SECRET_KEY) >= 32 # Ensure SECRET_KEY is sufficiently long + assert settings.ACCESS_TOKEN_EXPIRE_MINUTES == 60 * 24 * 8 # 8 days + assert settings.FRONTEND_HOST == "http://localhost:5173" + assert settings.ENVIRONMENT == "local" + + +def test_settings_custom_values() -> None: + """ + Test custom Settings initialization. + + This function tests that the Settings class successfully initializes with custom values. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test creates a Settings instance with custom values and checks + that these values are correctly set for various configuration options. + """ + # Define custom settings + custom_settings = { + "API_V1_STR": "/custom/api", + "SECRET_KEY": "mysecretkey", + "ACCESS_TOKEN_EXPIRE_MINUTES": 30, + "FRONTEND_HOST": "https://example.com", + "ENVIRONMENT": "production", + "PROJECT_NAME": "Test Project", + "POSTGRES_SERVER": "db.example.com", + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpass", + "POSTGRES_DB": "testdb", + "FIRST_SUPERUSER": "admin@example.com", + "FIRST_SUPERUSER_PASSWORD": "adminpass", + } + + # Create a Settings instance with custom values + settings = Settings(**custom_settings) # type: ignore + + # Assert that custom values are set correctly + assert settings.API_V1_STR == "/custom/api" + assert settings.SECRET_KEY == "mysecretkey" + assert settings.ACCESS_TOKEN_EXPIRE_MINUTES == 30 + assert settings.FRONTEND_HOST == "https://example.com" + assert settings.ENVIRONMENT == "production" + assert settings.PROJECT_NAME == "Test Project" + assert settings.POSTGRES_SERVER == "db.example.com" + assert settings.POSTGRES_USER == "testuser" + assert settings.POSTGRES_PASSWORD == "testpass" + assert settings.POSTGRES_DB == "testdb" + assert settings.FIRST_SUPERUSER == "admin@example.com" + assert settings.FIRST_SUPERUSER_PASSWORD == "adminpass" + + +def test_settings_computed_fields() -> None: + """ + Test computed fields in Settings. + + This function tests that the Settings class successfully initializes with custom values + and correctly computes derived fields. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test focuses on computed fields like all_cors_origins and SQLALCHEMY_DATABASE_URI. + """ + # Define custom settings + custom_settings = { + "BACKEND_CORS_ORIGINS": ["http://localhost", "https://example.com"], + "FRONTEND_HOST": "http://frontend.com", + "POSTGRES_SERVER": "db.example.com", + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpass", + "POSTGRES_DB": "testdb", + "PROJECT_NAME": "Test Project", + "FIRST_SUPERUSER": "admin@example.com", + "FIRST_SUPERUSER_PASSWORD": "adminpass", + } + + # Create a Settings instance with custom values + settings = Settings(**custom_settings) # type: ignore[arg-type] + + # Assert that computed fields are correct + assert settings.all_cors_origins == [ + "http://localhost", + "https://example.com", + "http://frontend.com", + ] + assert ( + str(settings.SQLALCHEMY_DATABASE_URI) + == "postgresql+psycopg://testuser:testpass@db.example.com:5432/testdb" + ) + + +def test_settings_email_configuration() -> None: + """ + Test email configuration in Settings. + + This function tests that the Settings class successfully initializes with custom email configuration. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test focuses on email-related settings and checks if they are correctly set. + """ + # Define custom settings with email configuration + custom_settings = { + "SMTP_TLS": True, + "SMTP_PORT": 587, + "SMTP_HOST": "smtp.example.com", + "SMTP_USER": "user@example.com", + "SMTP_PASSWORD": "password123", + "EMAILS_FROM_EMAIL": "noreply@example.com", + "PROJECT_NAME": "Test Project", + "FIRST_SUPERUSER": "admin@example.com", + "FIRST_SUPERUSER_PASSWORD": "adminpass", + "POSTGRES_SERVER": "localhost", + "POSTGRES_USER": "postgres", + } + + # Create a Settings instance with custom email configuration + settings = Settings(**custom_settings) # type: ignore[arg-type] + + # Assert that email configuration is set correctly + assert settings.SMTP_TLS is True + assert settings.SMTP_PORT == 587 + assert settings.SMTP_HOST == "smtp.example.com" + assert settings.SMTP_USER == "user@example.com" + assert settings.SMTP_PASSWORD == "password123" + assert settings.EMAILS_FROM_EMAIL == "noreply@example.com" + assert settings.EMAILS_FROM_NAME == "Test Project" + assert settings.emails_enabled is True + + +def test_settings_validation() -> None: + """ + Test Settings validation. + + This function tests that the Settings class successfully validates the environment and configuration. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + ValidationError: If invalid settings are provided. + + Notes: + This test checks for proper validation of environment, POSTGRES_PORT, and BACKEND_CORS_ORIGINS. + """ + # Test that an invalid environment raises a ValidationError + with pytest.raises(ValidationError): + Settings( + ENVIRONMENT="invalid_environment", # type: ignore[arg-type] + PROJECT_NAME="Test Project", + POSTGRES_SERVER="localhost", + POSTGRES_USER="postgres", + FIRST_SUPERUSER="admin@example.com", + FIRST_SUPERUSER_PASSWORD="password123", + ) + + # Test that an invalid POSTGRES_PORT raises a ValidationError + with pytest.raises(ValidationError): + Settings( + POSTGRES_PORT="not_an_integer", # type: ignore[arg-type] + PROJECT_NAME="Test Project", + POSTGRES_SERVER="localhost", + POSTGRES_USER="postgres", + FIRST_SUPERUSER="admin@example.com", + FIRST_SUPERUSER_PASSWORD="password123", + ) + + # Test that an invalid BACKEND_CORS_ORIGINS raises a ValidationError + with pytest.raises(ValidationError): + Settings( + BACKEND_CORS_ORIGINS="not_a_valid_url", + PROJECT_NAME="Test Project", + POSTGRES_SERVER="localhost", + POSTGRES_USER="postgres", + FIRST_SUPERUSER="admin@example.com", + FIRST_SUPERUSER_PASSWORD="password123", + ) + + +def test_settings_default_secrets_warning() -> None: + """ + Test default secrets warning in Settings. + + This function tests that the Settings class raises a warning when default secrets are used in a non-production environment. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test checks for a UserWarning when default secrets are used in a non-production environment. + """ + # Test that using default secrets in non-production environment raises a warning + with pytest.warns(UserWarning): + Settings( + SECRET_KEY="changethis", + POSTGRES_PASSWORD="changethis", + FIRST_SUPERUSER_PASSWORD="changethis", + PROJECT_NAME="Test", + FIRST_SUPERUSER="admin@example.com", + POSTGRES_SERVER="localhost", + POSTGRES_USER="postgres", + ) + + +def test_settings_default_secrets_error() -> None: + """ + Test default secrets error in Settings. + + This function tests that the Settings class raises an error when default secrets are used in a production environment. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + ValueError: If default secrets are used in a production environment. + + Notes: + This test checks for a ValueError when default secrets are used in a production environment. + """ + # Test that using default secrets in production environment raises a ValueError + with pytest.raises(ValueError): + Settings( + SECRET_KEY="changethis", + POSTGRES_PASSWORD="changethis", + FIRST_SUPERUSER_PASSWORD="changethis", + ENVIRONMENT="production", + PROJECT_NAME="Test", + FIRST_SUPERUSER="admin@example.com", + POSTGRES_SERVER="localhost", + POSTGRES_USER="postgres", + ) diff --git a/backend/app/tests/core/test_core_db.py b/backend/app/tests/core/test_core_db.py new file mode 100644 index 0000000000..a96b07b259 --- /dev/null +++ b/backend/app/tests/core/test_core_db.py @@ -0,0 +1,137 @@ +from sqlmodel import Session + +from app import crud +from app.core.config import settings +from app.core.db import engine, init_db +from app.models import UserCreate + + +def test_engine() -> None: + """ + Test database engine initialization. + + This function tests if the database engine is properly initialized. + The function is not protected and does not require authentication. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test checks if the engine object is not None. + """ + # Test if the database engine is properly initialized + assert engine is not None + + +def test_init_db_creates_superuser(db: Session) -> None: + """ + Test superuser creation during database initialization. + + This function tests if init_db creates a superuser when one doesn't exist. + The function is not protected and does not require authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test ensures that a superuser is created with the correct attributes + when init_db is called and no superuser exists. + """ + + # Ensure the superuser doesn't exist by deleting it if present + user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) + if user: + crud.delete_user(session=db, user_id=user.id) + + # Run init_db to create the superuser + init_db(db) + + # Check if the superuser was created with correct attributes + user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) + assert user is not None + assert user.email == settings.FIRST_SUPERUSER + assert user.is_superuser + + +def test_init_db_doesnt_create_duplicate_superuser(db: Session) -> None: + """ + Test prevention of duplicate superuser creation. + + This function tests if init_db doesn't create a duplicate superuser when one already exists. + The function is not protected and does not require authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test ensures that only one superuser exists after running init_db + when a superuser is already present in the database. + """ + + # Ensure the superuser exists by creating one if not present + existing_user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) + if not existing_user: + user_in = UserCreate( + email=settings.FIRST_SUPERUSER, + password=settings.FIRST_SUPERUSER_PASSWORD, + is_superuser=True, + ) + existing_user = crud.create_user(session=db, user_create=user_in) + + # Run init_db + init_db(db) + + # Check that only one superuser exists and it's the same as the existing one + users = crud.get_users(session=db) + superusers = [user for user in users if user.email == settings.FIRST_SUPERUSER] + assert len(superusers) == 1 + assert superusers[0].id == existing_user.id + + +def test_init_db_with_migrations(db: Session) -> None: + """ + Test database initialization with migrations. + + This function tests if init_db creates a superuser when run with migrations. + The function is not protected and does not require authentication. + + Args: + db (Session): The database session. + + Returns: + None + + Raises: + AssertionError: If the test fails. + + Notes: + This test verifies that a superuser is created with the correct attributes + when init_db is run, simulating a scenario with migrations. + """ + + # Run init_db + init_db(db) + + # Verify that the superuser is created with correct attributes + user = crud.get_user_by_email(session=db, email=settings.FIRST_SUPERUSER) + assert user is not None + assert user.email == settings.FIRST_SUPERUSER + assert user.is_superuser diff --git a/backend/app/tests/crud/test_crud_item.py b/backend/app/tests/crud/test_crud_item.py new file mode 100644 index 0000000000..a1d77b1e17 --- /dev/null +++ b/backend/app/tests/crud/test_crud_item.py @@ -0,0 +1,118 @@ +from sqlmodel import Session + +from app import crud +from app.models import ItemCreate, ItemUpdate, User, UserCreate +from app.tests.utils.utils import random_email, random_lower_string + + +def test_create_item(db: Session) -> None: + title = random_lower_string() + description = random_lower_string() + owner = create_random_user(db) + item_in = ItemCreate(title=title, description=description) + item = crud.create_item(session=db, item_in=item_in, owner_id=owner.id) + assert item.title == title + assert item.description == description + assert item.owner_id == owner.id + + +def test_get_item(db: Session) -> None: + title = random_lower_string() + description = random_lower_string() + owner = create_random_user(db) + item_in = ItemCreate(title=title, description=description) + item = crud.create_item(session=db, item_in=item_in, owner_id=owner.id) + stored_item = crud.get_item(session=db, item_id=item.id) + assert stored_item + assert item.id == stored_item.id + assert item.title == stored_item.title + assert item.description == stored_item.description + assert item.owner_id == stored_item.owner_id + + +def test_update_item(db: Session) -> None: + title = random_lower_string() + description = random_lower_string() + owner = create_random_user(db) + item_in = ItemCreate(title=title, description=description) + item = crud.create_item(session=db, item_in=item_in, owner_id=owner.id) + new_title = random_lower_string() + item_update = ItemUpdate(title=new_title) + updated_item = crud.update_item(session=db, db_item=item, item_in=item_update) + assert updated_item.title == new_title + assert updated_item.description == description + assert updated_item.id == item.id + assert updated_item.owner_id == owner.id + + +def test_delete_item(db: Session) -> None: + title = random_lower_string() + description = random_lower_string() + owner = create_random_user(db) + item_in = ItemCreate(title=title, description=description) + item = crud.create_item(session=db, item_in=item_in, owner_id=owner.id) + deleted_item = crud.delete_item(session=db, item_id=item.id) + assert deleted_item + stored_item = crud.get_item(session=db, item_id=item.id) + assert stored_item is None + + +def create_random_user(db: Session) -> User: + email = random_email() + password = random_lower_string() + user_in = UserCreate(email=email, password=password) + return crud.create_user(session=db, user_create=user_in) + + +def test_get_items(db: Session) -> None: + owner = create_random_user(db) + item1 = crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner.id + ) + item2 = crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner.id + ) + items = crud.get_items(session=db) + assert len(items) >= 2 + assert item1 in items + assert item2 in items + + +def test_get_items_by_owner(db: Session) -> None: + owner1 = create_random_user(db) + owner2 = create_random_user(db) + item1 = crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner1.id + ) + item2 = crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner2.id + ) + items = crud.get_items_by_owner(session=db, owner_id=owner1.id) + assert len(items) == 1 + assert item1 in items + assert item2 not in items + + +def test_get_item_count(db: Session) -> None: + initial_count = crud.get_item_count(session=db) + owner = create_random_user(db) + crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner.id + ) + crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner.id + ) + final_count = crud.get_item_count(session=db) + assert final_count == initial_count + 2 + + +def test_get_item_count_by_owner(db: Session) -> None: + owner = create_random_user(db) + crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner.id + ) + crud.create_item( + session=db, item_in=ItemCreate(title=random_lower_string()), owner_id=owner.id + ) + count = crud.get_item_count_by_owner(session=db, owner_id=owner.id) + assert count == 2 diff --git a/backend/app/tests/crud/test_user.py b/backend/app/tests/crud/test_crud_user.py similarity index 58% rename from backend/app/tests/crud/test_user.py rename to backend/app/tests/crud/test_crud_user.py index e9eb4a0391..1f29e99a44 100644 --- a/backend/app/tests/crud/test_user.py +++ b/backend/app/tests/crud/test_crud_user.py @@ -76,16 +76,79 @@ def test_get_user(db: Session) -> None: assert jsonable_encoder(user) == jsonable_encoder(user_2) -def test_update_user(db: Session) -> None: +def test_update_user_attributes(db: Session) -> None: + """ + Test updating user attributes. + + This test verifies that the update_user_attributes function correctly updates + a user's attributes in the database. + + Args: + db: The database session. + + Returns: + None + + Raises: + AssertionError: If the user attributes are not updated correctly. + """ + # Create a new user password = random_lower_string() email = random_email() - user_in = UserCreate(email=email, password=password, is_superuser=True) + user_in = UserCreate(email=email, password=password, is_superuser=False) user = crud.create_user(session=db, user_create=user_in) + + # Update user attributes + new_email = random_email() + user_in_update = UserUpdate(email=new_email, is_superuser=True) + updated_user = crud.update_user_attributes( + session=db, db_user=user, user_in=user_in_update + ) + + # Verify the updates + assert updated_user.email == new_email + assert updated_user.is_superuser is True + assert updated_user.id == user.id # Ensure it's the same user + + # Fetch the user from the database to double-check + db.refresh(user) + assert user.email == new_email + assert user.is_superuser is True + + +def test_update_user_password(db: Session) -> None: + """ + Test updating user password. + + This test verifies that the update_user_password function correctly updates + a user's password in the database. + + Args: + db: The database session. + + Returns: + None + + Raises: + AssertionError: If the user password is not updated correctly. + """ + # Create a new user + password = random_lower_string() + email = random_email() + user_in = UserCreate(email=email, password=password) + user = crud.create_user(session=db, user_create=user_in) + + # Update user password new_password = random_lower_string() - user_in_update = UserUpdate(password=new_password, is_superuser=True) - if user.id is not None: - crud.update_user(session=db, db_user=user, user_in=user_in_update) - user_2 = db.get(User, user.id) - assert user_2 - assert user.email == user_2.email - assert verify_password(new_password, user_2.hashed_password) + updated_user = crud.update_user_password( + session=db, db_user=user, new_password=new_password + ) + + # Verify the password update + assert verify_password(new_password, updated_user.hashed_password) + assert not verify_password(password, updated_user.hashed_password) + + # Fetch the user from the database to double-check + db.refresh(user) + assert verify_password(new_password, user.hashed_password) + assert not verify_password(password, user.hashed_password) diff --git a/backend/app/tests/scripts/test_backend_pre_start.py b/backend/app/tests/scripts/test_backend_pre_start.py index 631690fcf6..e903fd9eae 100644 --- a/backend/app/tests/scripts/test_backend_pre_start.py +++ b/backend/app/tests/scripts/test_backend_pre_start.py @@ -2,16 +2,38 @@ from sqlmodel import select -from app.backend_pre_start import init, logger +from app.scripts.backend_pre_start import init, logger, main def test_init_successful_connection() -> None: + """ + Test successful database connection initialization. + + This test verifies that the init function successfully connects to the database + by mocking the necessary components and asserting the expected behavior. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the connection is not successful or if the session + does not execute the select statement as expected. + + Notes: + This test uses mocking to simulate the database engine and session. + """ + # Create a mock for the database engine engine_mock = MagicMock() + # Create a mock for the database session and its exec method session_mock = MagicMock() exec_mock = MagicMock(return_value=True) session_mock.configure_mock(**{"exec.return_value": exec_mock}) + # Patch necessary dependencies with ( patch("sqlmodel.Session", return_value=session_mock), patch.object(logger, "info"), @@ -19,15 +41,56 @@ def test_init_successful_connection() -> None: patch.object(logger, "warn"), ): try: + # Attempt to initialize the database connection init(engine_mock) connection_successful = True except Exception: connection_successful = False + # Assert that the connection was successful assert ( connection_successful ), "The database connection should be successful and not raise an exception." + # Assert that the session executed a select statement once assert session_mock.exec.called_once_with( select(1) ), "The session should execute a select statement once." + + +def test_main() -> None: + """ + Test the main function for service initialization. + + This test verifies that the main function successfully initializes the service + by mocking the necessary dependencies and asserting the expected behavior. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the init function is not called with the mocked engine + or if the logger does not log the expected messages. + + Notes: + This test uses mocking to simulate the init function, engine, and logger. + """ + # Patch necessary dependencies + with ( + patch("app.scripts.backend_pre_start.init") as mock_init, + patch("app.scripts.backend_pre_start.engine") as mock_engine, + patch.object(logger, "info") as mock_logger_info, + ): + # Call the main function + main() + + # Assert that init was called once with the mocked engine + mock_init.assert_called_once_with(mock_engine) + # Assert that logger.info was called twice + assert mock_logger_info.call_count == 2 + # Assert that the correct log messages were printed + mock_logger_info.assert_any_call("Initializing service") + mock_logger_info.assert_any_call("Service finished initializing") diff --git a/backend/app/tests/scripts/test_scripts_inital_data.py b/backend/app/tests/scripts/test_scripts_inital_data.py new file mode 100644 index 0000000000..e0ec3ddb59 --- /dev/null +++ b/backend/app/tests/scripts/test_scripts_inital_data.py @@ -0,0 +1,105 @@ +from unittest.mock import Mock, patch + +import pytest +from sqlmodel import Session + +from app.core.db import init_db +from app.scripts.initial_data import logger, main + + +def test_init_db_creates_superuser() -> None: + """ + Test init_db function's superuser creation. + + This test verifies that the init_db function correctly creates a superuser + with the specified email and password. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail during the test. + + Notes: + This test uses mock objects to simulate database interactions and + isolate the functionality being tested. + """ + # Create a mock session object + mock_session = Mock(spec=Session) + + # Use multiple patch decorators to mock various dependencies + with ( + patch("app.scripts.initial_data.Session", return_value=mock_session), + patch("app.scripts.initial_data.init_db"), + patch("app.core.db.crud.get_user_by_email", return_value=None) as mock_get_user, + patch("app.core.db.crud.create_user") as mock_create_user, + patch("app.core.db.settings") as mock_settings, + ): + # Set up mock settings for superuser + mock_settings.FIRST_SUPERUSER = "test@example.com" + mock_settings.FIRST_SUPERUSER_PASSWORD = "testpassword" + + # Call the function under test + init_db(mock_session) + + # Assert that get_user_by_email was called with correct arguments + mock_get_user.assert_called_once_with( + session=mock_session, email="test@example.com" + ) + # Assert that create_user was called + mock_create_user.assert_called_once() + # Get the arguments passed to create_user + create_user_args = mock_create_user.call_args[1] + # Assert that the correct session was passed + assert create_user_args["session"] == mock_session + # Assert that the correct email was used + assert create_user_args["user_create"].email == "test@example.com" + # Assert that the correct password was used + assert create_user_args["user_create"].password == "testpassword" + # Assert that the user was created as a superuser + assert create_user_args["user_create"].is_superuser is True + + +def test_main() -> None: + """ + Test main function's service initialization. + + This test ensures that the main function successfully initializes the service + and logs the appropriate messages. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If any of the assertions fail during the test. + + Notes: + This test uses mock objects to simulate the initialization process and + verify the logging behavior. + """ + # Use patch decorators to mock dependencies + with ( + patch("app.scripts.initial_data.init") as mock_init, + patch.object(logger, "info") as mock_logger_info, + ): + # Call the function under test + main() + + # Assert that init was called once + mock_init.assert_called_once() + # Assert that logger.info was called twice + assert mock_logger_info.call_count == 2 + # Assert that the correct log messages were printed + mock_logger_info.assert_any_call("Creating initial data") + mock_logger_info.assert_any_call("Initial data created") + + +if __name__ == "__main__": + # Run the tests if this script is executed directly + pytest.main([__file__]) diff --git a/backend/app/tests/scripts/test_test_pre_start.py b/backend/app/tests/scripts/test_test_pre_start.py index a176f380de..6ea65ed3f3 100644 --- a/backend/app/tests/scripts/test_test_pre_start.py +++ b/backend/app/tests/scripts/test_test_pre_start.py @@ -2,16 +2,40 @@ from sqlmodel import select -from app.tests_pre_start import init, logger +from app.scripts.tests_pre_start import init, logger def test_init_successful_connection() -> None: + """ + Test successful database connection initialization. + + This test verifies that the init function successfully connects to the database + by mocking the necessary components and asserting the expected behavior. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the connection is not successful or if the session + does not execute the select statement as expected. + + Notes: + This test uses mocking to simulate the database engine and session. + """ + # Create a mock for the database engine engine_mock = MagicMock() + # Create a mock for the database session session_mock = MagicMock() + # Create a mock for the exec method of the session exec_mock = MagicMock(return_value=True) + # Configure the session mock to return the exec mock when exec is called session_mock.configure_mock(**{"exec.return_value": exec_mock}) + # Use patch to mock various components with ( patch("sqlmodel.Session", return_value=session_mock), patch.object(logger, "info"), @@ -19,15 +43,18 @@ def test_init_successful_connection() -> None: patch.object(logger, "warn"), ): try: + # Attempt to initialize the database connection init(engine_mock) connection_successful = True except Exception: connection_successful = False + # Assert that the connection was successful assert ( connection_successful ), "The database connection should be successful and not raise an exception." + # Assert that the session executed a select statement once assert session_mock.exec.called_once_with( select(1) ), "The session should execute a select statement once." diff --git a/backend/app/tests/test_utils.py b/backend/app/tests/test_utils.py new file mode 100644 index 0000000000..38990befdb --- /dev/null +++ b/backend/app/tests/test_utils.py @@ -0,0 +1,264 @@ +from datetime import datetime, timezone +from unittest.mock import patch + +import jwt +import pytest + +from app.core.config import settings +from app.utils import ( + EmailData, + generate_new_account_email, + generate_password_reset_token, + generate_reset_password_email, + generate_test_email, + send_email, + verify_password_reset_token, +) + + +# Test the send_email function with different SMTP configurations +@pytest.mark.parametrize( + "smtp_config", + [ + {"SMTP_TLS": True, "SMTP_SSL": False}, + {"SMTP_TLS": False, "SMTP_SSL": True}, + {"SMTP_TLS": False, "SMTP_SSL": False}, + ], +) +def test_send_email(smtp_config: dict[str, bool]) -> None: + """ + Test the send_email function with different SMTP configurations. + + This test verifies that the send_email function correctly sends an email + using various SMTP configurations. + + Args: + smtp_config (dict[str, bool]): A dictionary containing SMTP configuration options. + + Returns: + None + + Raises: + AssertionError: If the email message is not created or sent as expected. + + Notes: + This test uses mocking to simulate different SMTP configurations and verify + the behavior of the send_email function. + """ + with ( + patch("app.utils.emails.Message") as mock_message, + patch("app.utils.settings") as mock_settings, + ): + # Mock the settings for the email configuration + mock_settings.emails_enabled = True + mock_settings.SMTP_HOST = "localhost" + mock_settings.SMTP_PORT = 25 + mock_settings.SMTP_USER = "user" + mock_settings.SMTP_PASSWORD = "password" + mock_settings.EMAILS_FROM_NAME = "Test" + mock_settings.EMAILS_FROM_EMAIL = "test@example.com" + mock_settings.SMTP_TLS = smtp_config["SMTP_TLS"] + mock_settings.SMTP_SSL = smtp_config["SMTP_SSL"] + + # Call the send_email function + send_email( + email_to="to@example.com", subject="Test", html_content="
Test
" + ) + + # Assert that the email message was created and sent + mock_message.assert_called_once() + mock_message.return_value.send.assert_called_once() + + +def test_generate_test_email() -> None: + """ + Test the generate_test_email function. + + This test verifies that the generate_test_email function correctly generates + an email with the expected properties. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the generated email does not have the expected properties. + + Notes: + This test uses mocking to simulate the email template rendering process. + """ + email_to = "test@example.com" + with patch("app.utils.render_email_template") as mock_render: + # Mock the email template rendering + mock_render.return_value = "Test Email
" + result = generate_test_email(email_to) + + # Assert that the generated email has the correct properties + assert isinstance(result, EmailData) + assert result.subject == f"{settings.PROJECT_NAME} - Test email" + assert result.html_content == "Test Email
" + + +def test_generate_reset_password_email() -> None: + """ + Test the generate_reset_password_email function. + + This test verifies that the generate_reset_password_email function correctly + generates an email with the expected properties for password reset. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the generated email does not have the expected properties. + + Notes: + This test uses mocking to simulate the email template rendering process. + """ + email_to = "test@example.com" + email = "user@example.com" + token = "test_token" + with patch("app.utils.render_email_template") as mock_render: + # Mock the email template rendering + mock_render.return_value = "Reset Password
" + result = generate_reset_password_email(email_to, email, token) + + # Assert that the generated email has the correct properties + assert isinstance(result, EmailData) + assert ( + result.subject + == f"{settings.PROJECT_NAME} - Password recovery for user {email}" + ) + assert result.html_content == "Reset Password
" + + +def test_generate_new_account_email() -> None: + """ + Test the generate_new_account_email function. + + This test verifies that the generate_new_account_email function correctly + generates an email with the expected properties for a new account. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the generated email does not have the expected properties. + + Notes: + This test uses mocking to simulate the email template rendering process. + """ + email_to = "test@example.com" + username = "testuser" + password = "testpass" + with patch("app.utils.render_email_template") as mock_render: + # Mock the email template rendering + mock_render.return_value = "New Account
" + result = generate_new_account_email(email_to, username, password) + + # Assert that the generated email has the correct properties + assert isinstance(result, EmailData) + assert ( + result.subject == f"{settings.PROJECT_NAME} - New account for user {username}" + ) + assert result.html_content == "New Account
" + + +def test_generate_password_reset_token() -> str: + """ + Test the generate_password_reset_token function. + + This test verifies that the generate_password_reset_token function correctly + generates a token for password reset. + + Args: + None + + Returns: + str: The generated token. + + Raises: + AssertionError: If the generated token is not as expected. + + Notes: + This test uses mocking to simulate the current datetime and JWT encoding process. + """ + email = "test@example.com" + with ( + patch("app.utils.datetime") as mock_datetime, + patch("app.utils.jwt.encode") as mock_encode, + ): + # Mock the current datetime and JWT encoding + mock_datetime.now.return_value = datetime(2023, 1, 1, tzinfo=timezone.utc) + mock_encode.return_value = "encoded_token" + result = generate_password_reset_token(email) + + # Assert that the generated token is correct + assert result == "encoded_token" + return result + + +def test_verify_password_reset_token() -> None: + """ + Test the verify_password_reset_token function. + + This test verifies that the verify_password_reset_token function correctly + verifies a valid password reset token. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the verified email is not as expected. + + Notes: + This test uses mocking to simulate the JWT decoding process. + """ + token = "valid_token" + with patch("app.utils.jwt.decode") as mock_decode: + # Mock the JWT decoding + mock_decode.return_value = {"sub": "test@example.com"} + result = verify_password_reset_token(token) + + # Assert that the verified email is correct + assert result == "test@example.com" + + +def test_verify_password_reset_token_invalid() -> None: + """ + Test the verify_password_reset_token function with an invalid token. + + This test verifies that the verify_password_reset_token function correctly + handles an invalid password reset token. + + Args: + None + + Returns: + None + + Raises: + AssertionError: If the result is not None for an invalid token. + + Notes: + This test uses mocking to simulate the JWT decoding process raising an InvalidTokenError. + """ + token = "invalid_token" + with patch("app.utils.jwt.decode") as mock_decode: + # Mock the JWT decoding to raise an InvalidTokenError + mock_decode.side_effect = jwt.exceptions.InvalidTokenError + result = verify_password_reset_token(token) + + # Assert that the result is None for an invalid token + assert result is None diff --git a/backend/app/tests/utils/user.py b/backend/app/tests/utils/user.py index 9c1b073109..b570422a3a 100644 --- a/backend/app/tests/utils/user.py +++ b/backend/app/tests/utils/user.py @@ -3,27 +3,73 @@ from app import crud from app.core.config import settings -from app.models import User, UserCreate, UserUpdate +from app.models import User, UserCreate from app.tests.utils.utils import random_email, random_lower_string def user_authentication_headers( *, client: TestClient, email: str, password: str ) -> dict[str, str]: + """ + Generate authentication headers for a user. + + This function creates and returns a valid token for the user with the given email and password. + + Args: + client (TestClient): The test client used to make requests. + email (str): The email of the user. + password (str): The password of the user. + + Returns: + dict[str, str]: A dictionary containing the authentication headers. + + Raises: + None + + Notes: + The function sends a POST request to obtain an access token and then creates the authentication headers. + """ + # Prepare the login data data = {"username": email, "password": password} + # Send a POST request to obtain an access token r = client.post(f"{settings.API_V1_STR}/login/access-token", data=data) response = r.json() + + # Extract the access token from the response auth_token = response["access_token"] + + # Create and return the authentication headers headers = {"Authorization": f"Bearer {auth_token}"} return headers def create_random_user(db: Session) -> User: + """ + Create a random user in the database. + + This function generates random email and password, creates a new user with these credentials, and returns the user object. + + Args: + db (Session): The database session. + + Returns: + User: The created user object. + + Raises: + None + + Notes: + This function uses utility functions to generate random email and password. + """ + # Generate random email and password email = random_email() password = random_lower_string() + + # Create a new user with the random email and password user_in = UserCreate(email=email, password=password) user = crud.create_user(session=db, user_create=user_in) + return user @@ -31,19 +77,39 @@ def authentication_token_from_email( *, client: TestClient, email: str, db: Session ) -> dict[str, str]: """ - Return a valid token for the user with given email. + Get or create a user and return their authentication token. - If the user doesn't exist it is created first. + This function returns a valid token for the user with the given email. If the user doesn't exist, it is created first. + + Args: + client (TestClient): The test client used to make requests. + email (str): The email of the user. + db (Session): The database session. + + Returns: + dict[str, str]: A dictionary containing the authentication headers. + + Raises: + None + + Notes: + If the user already exists, their password is updated before generating the authentication token. """ + # Generate a random password password = random_lower_string() + + # Try to get the user by email user = crud.get_user_by_email(session=db, email=email) + if not user: + # If the user doesn't exist, create a new one user_in_create = UserCreate(email=email, password=password) user = crud.create_user(session=db, user_create=user_in_create) else: - user_in_update = UserUpdate(password=password) - if not user.id: - raise Exception("User id not set") - user = crud.update_user(session=db, db_user=user, user_in=user_in_update) + # If the user exists, update their password + user = crud.update_user_password( + session=db, db_user=user, new_password=password + ) + # Get and return the authentication headers for the user return user_authentication_headers(client=client, email=email, password=password) diff --git a/backend/app/tests/utils/utils.py b/backend/app/tests/utils/utils.py index 184bac44d9..3804c5d7eb 100644 --- a/backend/app/tests/utils/utils.py +++ b/backend/app/tests/utils/utils.py @@ -7,20 +7,80 @@ def random_lower_string() -> str: + """ + Generate a random lowercase string. + + This function creates and returns a random string of 32 lowercase ASCII characters. + + Args: + None + + Returns: + str: A random string of 32 lowercase ASCII characters. + + Raises: + None + + Notes: + This function uses the random.choices method to select characters from string.ascii_lowercase. + """ + # Generate a random string of 32 lowercase ASCII characters return "".join(random.choices(string.ascii_lowercase, k=32)) def random_email() -> str: + """ + Generate a random email address. + + This function creates a random email address by combining two random lowercase strings. + + Args: + None + + Returns: + str: A randomly generated email address. + + Raises: + None + + Notes: + This function uses the random_lower_string function to generate the local and domain parts of the email. + """ + # Create a random email by combining two random lowercase strings return f"{random_lower_string()}@{random_lower_string()}.com" def get_superuser_token_headers(client: TestClient) -> dict[str, str]: + """ + Get authentication token headers for the superuser. + + This function obtains an access token for the superuser and returns it in the format required for authorization headers. + + Args: + client (TestClient): The test client used to make requests. + + Returns: + dict[str, str]: A dictionary containing the authorization header with the access token. + + Raises: + ValueError: If no access token is found in the response. + + Notes: + This function uses the settings module to get the superuser credentials and API endpoint. + """ + # Prepare login data for the superuser login_data = { "username": settings.FIRST_SUPERUSER, "password": settings.FIRST_SUPERUSER_PASSWORD, } + # Send a POST request to obtain an access token r = client.post(f"{settings.API_V1_STR}/login/access-token", data=login_data) + # Parse the JSON response tokens = r.json() - a_token = tokens["access_token"] - headers = {"Authorization": f"Bearer {a_token}"} - return headers + # Extract the access token from the response + a_token = tokens.get("access_token") + # Raise an error if no access token is found + if not a_token: + raise ValueError(f"No access token found in response: {tokens}") + # Return the token in the format required for authorization headers + return {"Authorization": f"Bearer {a_token}"} diff --git a/backend/app/tests_pre_start.py b/backend/app/tests_pre_start.py deleted file mode 100644 index 0ce6045635..0000000000 --- a/backend/app/tests_pre_start.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging - -from sqlalchemy import Engine -from sqlmodel import Session, select -from tenacity import after_log, before_log, retry, stop_after_attempt, wait_fixed - -from app.core.db import engine - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -max_tries = 60 * 5 # 5 minutes -wait_seconds = 1 - - -@retry( - stop=stop_after_attempt(max_tries), - wait=wait_fixed(wait_seconds), - before=before_log(logger, logging.INFO), - after=after_log(logger, logging.WARN), -) -def init(db_engine: Engine) -> None: - try: - # Try to create session to check if DB is awake - with Session(db_engine) as session: - session.exec(select(1)) - except Exception as e: - logger.error(e) - raise e - - -def main() -> None: - logger.info("Initializing service") - init(engine) - logger.info("Service finished initializing") - - -if __name__ == "__main__": - main() diff --git a/backend/app/utils.py b/backend/app/utils.py index ac029f6342..7250f0910a 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -12,20 +12,44 @@ from app.core import security from app.core.config import settings -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - @dataclass class EmailData: - html_content: str - subject: str + """ + Class for email data. + """ + + html_content: str # The HTML content of the email + subject: str # The subject line of the email def render_email_template(*, template_name: str, context: dict[str, Any]) -> str: + """ + Render an email template. + + This function reads an email template file, renders it with the provided context, + and returns the resulting HTML content. + + Args: + template_name (str): The name of the template file to render. + context (dict[str, Any]): A dictionary containing the context data for rendering the template. + + Returns: + str: The rendered HTML content of the email. + + Raises: + FileNotFoundError: If the specified template file is not found. + jinja2.exceptions.TemplateError: If there's an error in rendering the template. + + Notes: + The template files are expected to be located in the 'email-templates/build' directory + relative to the current file's location. + """ + # Construct the path to the email template file template_str = ( Path(__file__).parent / "email-templates" / "build" / template_name ).read_text() + # Render the template with the provided context html_content = Template(template_str).render(context) return html_content @@ -36,12 +60,38 @@ def send_email( subject: str = "", html_content: str = "", ) -> None: + """ + Send an email. + + This function sends an email using the configured SMTP settings. + + Args: + email_to (str): The recipient's email address. + subject (str, optional): The subject of the email. Defaults to an empty string. + html_content (str, optional): The HTML content of the email. Defaults to an empty string. + + Returns: + None + + Raises: + AssertionError: If email functionality is not enabled in the settings. + SMTPException: If there's an error in sending the email. + + Notes: + This function relies on the 'emails' library and the application's settings + for SMTP configuration. + """ + # Ensure email functionality is enabled in settings assert settings.emails_enabled, "no provided configuration for email variables" + + # Create an email message object with subject, content, and sender information message = emails.Message( subject=subject, html=html_content, mail_from=(settings.EMAILS_FROM_NAME, settings.EMAILS_FROM_EMAIL), ) + + # Configure SMTP options based on settings smtp_options = {"host": settings.SMTP_HOST, "port": settings.SMTP_PORT} if settings.SMTP_TLS: smtp_options["tls"] = True @@ -51,13 +101,30 @@ def send_email( smtp_options["user"] = settings.SMTP_USER if settings.SMTP_PASSWORD: smtp_options["password"] = settings.SMTP_PASSWORD + + # Send the email using the configured SMTP options and log the result response = message.send(to=email_to, smtp=smtp_options) - logger.info(f"send email result: {response}") + logging.info(f"send email result: {response}") def generate_test_email(email_to: str) -> EmailData: + """ + Generate a test email. + + This function creates a test email with a predefined subject and content. + + Args: + email_to (str): The recipient's email address. + + Returns: + EmailData: An object containing the HTML content and subject of the test email. + + Notes: + The email content is generated using a template named 'test_email.html'. + """ project_name = settings.PROJECT_NAME subject = f"{project_name} - Test email" + # Render the test email template with project name and recipient email html_content = render_email_template( template_name="test_email.html", context={"project_name": settings.PROJECT_NAME, "email": email_to}, @@ -66,9 +133,27 @@ def generate_test_email(email_to: str) -> EmailData: def generate_reset_password_email(email_to: str, email: str, token: str) -> EmailData: + """ + Generate a reset password email. + + This function creates an email for password reset with a link containing a reset token. + + Args: + email_to (str): The recipient's email address. + email (str): The user's email address (may be different from email_to). + token (str): The password reset token. + + Returns: + EmailData: An object containing the HTML content and subject of the reset password email. + + Notes: + The email content is generated using a template named 'reset_password.html'. + """ project_name = settings.PROJECT_NAME subject = f"{project_name} - Password recovery for user {email}" + # Construct the reset password link with the provided token link = f"{settings.FRONTEND_HOST}/reset-password?token={token}" + # Render the reset password email template with necessary context html_content = render_email_template( template_name="reset_password.html", context={ @@ -85,8 +170,26 @@ def generate_reset_password_email(email_to: str, email: str, token: str) -> Emai def generate_new_account_email( email_to: str, username: str, password: str ) -> EmailData: + """ + Generate a new account email. + + This function creates an email for a newly created account with login credentials. + + Args: + email_to (str): The recipient's email address. + username (str): The username for the new account. + password (str): The password for the new account. + + Returns: + EmailData: An object containing the HTML content and subject of the new account email. + + Notes: + The email content is generated using a template named 'new_account.html'. + Sending passwords via email is generally not recommended for security reasons. + """ project_name = settings.PROJECT_NAME subject = f"{project_name} - New account for user {username}" + # Render the new account email template with account details html_content = render_email_template( template_name="new_account.html", context={ @@ -101,10 +204,27 @@ def generate_new_account_email( def generate_password_reset_token(email: str) -> str: + """ + Generate a password reset token. + + This function creates a JWT token for password reset purposes. + + Args: + email (str): The email address of the user requesting a password reset. + + Returns: + str: A JWT token encoded as a string. + + Notes: + The token includes expiration time, not-before time, and the user's email as subject. + The token is signed using the application's SECRET_KEY. + """ + # Calculate token expiration time delta = timedelta(hours=settings.EMAIL_RESET_TOKEN_EXPIRE_HOURS) now = datetime.now(timezone.utc) expires = now + delta exp = expires.timestamp() + # Encode the JWT token with expiration, not-before, and subject claims encoded_jwt = jwt.encode( {"exp": exp, "nbf": now, "sub": email}, settings.SECRET_KEY, @@ -114,10 +234,29 @@ def generate_password_reset_token(email: str) -> str: def verify_password_reset_token(token: str) -> str | None: + """ + Verify a password reset token. + + This function decodes and verifies a JWT token used for password reset. + + Args: + token (str): The JWT token to verify. + + Returns: + str | None: The email address (subject) from the token if valid, None otherwise. + + Raises: + jwt.exceptions.InvalidTokenError: If the token is invalid or expired. + + Notes: + The token is verified using the application's SECRET_KEY. + """ try: + # Attempt to decode and verify the JWT token decoded_token = jwt.decode( token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] ) return str(decoded_token["sub"]) except InvalidTokenError: + # Return None if the token is invalid return None diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 1c77b83ded..691a6b3562 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -39,7 +39,7 @@ build-backend = "hatchling.build" [tool.mypy] strict = true -exclude = ["venv", ".venv", "alembic"] +exclude = ["venv", ".venv", "alembic", "app/tests/core/test_config.py", "app/tests/api/test_deps.py"] [tool.ruff] target-version = "py310" diff --git a/backend/pytest.ini b/backend/pytest.ini new file mode 100644 index 0000000000..3f05abf289 --- /dev/null +++ b/backend/pytest.ini @@ -0,0 +1,16 @@ +[pytest] +markers = + order: mark test execution order + +addopts = -v -s --tb=short + +python_files = test_*.py + +filterwarnings = + ignore::DeprecationWarning + ignore::pytest.PytestUnknownMarkWarning + +log_cli = True +log_cli_level = INFO + +console_output_style = classic diff --git a/backend/scripts/lint.sh b/backend/scripts/lint.sh old mode 100644 new mode 100755 diff --git a/backend/scripts/prestart.sh b/backend/scripts/prestart.sh old mode 100644 new mode 100755 index 1b395d513f..89339b017f --- a/backend/scripts/prestart.sh +++ b/backend/scripts/prestart.sh @@ -4,10 +4,10 @@ set -e set -x # Let the DB start -python app/backend_pre_start.py +python app/scripts/backend_pre_start.py # Run migrations alembic upgrade head # Create initial data in DB -python app/initial_data.py +python app/scripts/initial_data.py diff --git a/backend/scripts/tests-start.sh b/backend/scripts/tests-start.sh old mode 100644 new mode 100755 index 89dcb0da23..c736e3e9e4 --- a/backend/scripts/tests-start.sh +++ b/backend/scripts/tests-start.sh @@ -2,6 +2,6 @@ set -e set -x -python app/tests_pre_start.py +python app/scripts/tests_pre_start.py bash scripts/test.sh "$@" diff --git a/frontend/src/client/core/OpenAPI.ts b/frontend/src/client/core/OpenAPI.ts index 746df5e61d..843f3b10ff 100644 --- a/frontend/src/client/core/OpenAPI.ts +++ b/frontend/src/client/core/OpenAPI.ts @@ -51,7 +51,7 @@ export const OpenAPI: OpenAPIConfig = { RESULT: "body", TOKEN: undefined, USERNAME: undefined, - VERSION: "0.1.0", + VERSION: "1", WITH_CREDENTIALS: false, interceptors: { request: new Interceptors(), response: new Interceptors() }, } diff --git a/frontend/src/client/models.ts b/frontend/src/client/models.ts index 2c8074ddd6..a4215ca5a8 100644 --- a/frontend/src/client/models.ts +++ b/frontend/src/client/models.ts @@ -33,25 +33,75 @@ export type ItemsPublic = { count: number } +/** + * Generic message. + * + * Defines a simple structure for generic messages. + * + * Args: + * message (str): The content of the message. + * + * Returns: + * None + * + * Notes: + * This class can be used for various messaging purposes throughout the application. + */ export type Message = { message: string } +/** + * Class for resetting password. + * + * Defines the structure for a password reset request. + * + * Args: + * token (str): The token for password reset verification. + * new_password (str): The new password to set, with length constraints. + * + * Returns: + * None + * + * Notes: + * This class is used in the password reset process. + */ export type NewPassword = { token: string new_password: string } +/** + * JSON payload containing access token. + * + * Defines the structure for an authentication token response. + * + * Args: + * access_token (str): The access token string. + * token_type (str): The type of token, defaults to "bearer". + * + * Returns: + * None + * + * Notes: + * This class is used in the authentication process to return token information. + */ export type Token = { access_token: string token_type?: string } +/** + * Class for updating user password. + */ export type UpdatePassword = { current_password: string new_password: string } +/** + * Class for creating a new user. + */ export type UserCreate = { email: string is_active?: boolean @@ -60,6 +110,9 @@ export type UserCreate = { password: string } +/** + * Public properties for user. + */ export type UserPublic = { email: string is_active?: boolean @@ -68,12 +121,18 @@ export type UserPublic = { id: string } +/** + * Class for user registration. + */ export type UserRegister = { email: string password: string full_name?: string | null } +/** + * Class for updating user information. + */ export type UserUpdate = { email?: string | null is_active?: boolean @@ -82,11 +141,17 @@ export type UserUpdate = { password?: string | null } +/** + * Class for updating user information. + */ export type UserUpdateMe = { full_name?: string | null email?: string | null } +/** + * Public properties for users. + */ export type UsersPublic = { data: Array