diff --git a/docs/customization.qmd b/docs/customization.qmd index d032cac..00e9685 100644 --- a/docs/customization.qmd +++ b/docs/customization.qmd @@ -25,7 +25,8 @@ The following fixtures, defined in `tests/conftest.py`, are available in the tes - `set_up_database`: Sets up the test database before running the test suite by dropping all tables and recreating them to ensure a clean state. - `session`: Provides a session for database operations in tests. - `clean_db`: Cleans up the database tables before each test by deleting all entries in the `PasswordResetToken` and `User` tables. -- `client`: Provides a `TestClient` instance with the session fixture, overriding the `get_session` dependency to use the test session. +- `auth_client`: Provides a `TestClient` instance with access and refresh token cookies set, overriding the `get_session` dependency to use the `session` fixture. +- `unauth_client`: Provides a `TestClient` instance without authentication cookies set, overriding the `get_session` dependency to use the `session` fixture. - `test_user`: Creates a test user in the database with a predefined name, email, and hashed password. To run the tests, use these commands: @@ -43,7 +44,7 @@ The project uses type annotations and mypy for static type checking. To run mypy mypy ``` -We find that mypy is an enormous time-saver, catching many errors early and greatly reducing time spent debugging unit tests. However, note that mypy requires you type annotate every variable, function, and method in your code base, so taking advantage of it is a lifestyle change! +We find that mypy is an enormous time-saver, catching many errors early and greatly reducing time spent debugging unit tests. However, note that mypy requires you type annotate every variable, function, and method in your code base, so taking advantage of it requires a lifestyle change! ## Project structure @@ -80,7 +81,9 @@ We also create POST endpoints, which accept form submissions so the user can cre #### Routing patterns in this template -In this template, GET routes are defined in the main entry point for the application, `main.py`. POST routes are organized into separate modules within the `routers/` directory. We name our GET routes using the convention `read_`, where `` is the name of the page, to indicate that they are read-only endpoints that do not modify the database. +In this template, GET routes are defined in the main entry point for the application, `main.py`. POST routes are organized into separate modules within the `routers/` directory. + +We name our GET routes using the convention `read_`, where `` is the name of the page, to indicate that they are read-only endpoints that do not modify the database. We divide our GET routes into authenticated and unauthenticated routes, using commented section headers in our code that look like this: @@ -88,7 +91,9 @@ We divide our GET routes into authenticated and unauthenticated routes, using co # -- Authenticated Routes -- ``` -Some of our routes take request parameters, which we pass as keyword arguments to the route handler. These parameters should be type annotated for validation purposes. Some parameters are shared across all authenticated or unauthenticated routes, so we define them in the `common_authenticated_parameters` and `common_unauthenticated_parameters` dependencies defined in `main.py`. +Some of our routes take request parameters, which we pass as keyword arguments to the route handler. These parameters should be type annotated for validation purposes. + +Some parameters are shared across all authenticated or unauthenticated routes, so we define them in the `common_authenticated_parameters` and `common_unauthenticated_parameters` dependencies defined in `main.py`. ### HTML templating with Jinja2 @@ -109,7 +114,7 @@ async def welcome(request: Request): ) ``` -In this example, the `welcome.html` template will receive two pieces of context: the user's `request`, which is always passed automatically by FastAPI, and a `username` variable, which we specify as "Alice". We can then use the `{{ username }}` syntax in the `welcome.html` template (or any of its parent or child templates) to insert the value into the HTML. +In this example, the `welcome.html` template will receive two pieces of context: the user's `request`, which is always passed automatically by FastAPI, and a `username` variable, which we specify as "Alice". We can then use the `{{{ username }}}` syntax in the `welcome.html` template (or any of its parent or child templates) to insert the value into the HTML. #### Form validation strategy diff --git a/docs/installation.qmd b/docs/installation.qmd index 7167589..2014027 100644 --- a/docs/installation.qmd +++ b/docs/installation.qmd @@ -20,6 +20,8 @@ If you use VSCode with Docker to develop in a container, the following VSCode De Simply create a `.devcontainer` folder in the root of the project and add a `devcontainer.json` file in the folder with the above content. VSCode may prompt you to install the Dev Container extension if you haven't already, and/or to open the project in a container. If not, you can manually select "Dev Containers: Reopen in Container" from View > Command Palette. +*IMPORTANT: If using this dev container configuration, you will need to set the `DB_HOST` environment variable to "host.docker.internal" in the `.env` file.* + ## Install development dependencies manually ### Python and Docker @@ -103,15 +105,32 @@ Set your desired database name, username, and password in the .env file. To use password recovery, register a [Resend](https://resend.com/) account, verify a domain, get an API key, and paste the API key into the .env file. +If using the dev container configuration, you will need to set the `DB_HOST` environment variable to "host.docker.internal" in the .env file. Otherwise, set `DB_HOST` to "localhost" for local development. (In production, `DB_HOST` will be set to the hostname of the database server.) + ## Start development database +To start the development database, run the following command in your terminal from the root directory: + ``` bash docker compose up -d ``` +If at any point you change the environment variables in the .env file, you will need to stop the database service *and tear down the volume*: + +``` bash +# Don't forget the -v flag to tear down the volume! +docker compose down -v +``` + +You may also need to restart the terminal session to pick up the new environment variables. You can also add the `--force-recreate` and `--build` flags to the startup command to ensure the container is rebuilt: + +``` bash +docker compose up -d --force-recreate --build +``` + ## Run the development server -Make sure the development database is running and tables and default permissions/roles are created first. +Before running the development server, make sure the development database is running and tables and default permissions/roles are created first. Then run the following command in your terminal from the root directory: ``` bash uvicorn main:app --host 0.0.0.0 --port 8000 --reload diff --git a/index.qmd b/index.qmd index 0266e7b..b23f439 100644 --- a/index.qmd +++ b/index.qmd @@ -107,6 +107,8 @@ To use password recovery, register a [Resend](https://resend.com/) account, veri ### Start development database +To start the development database, run the following command in your terminal from the root directory: + ``` bash docker compose up -d ``` diff --git a/main.py b/main.py index 676808e..5642f6a 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ from fastapi.exceptions import RequestValidationError, HTTPException, StarletteHTTPException from sqlmodel import Session from routers import authentication, organization, role, user -from utils.auth import get_authenticated_user, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError +from utils.auth import get_authenticated_user, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError, AuthenticationError from utils.models import User from utils.db import get_session, set_up_db @@ -37,6 +37,15 @@ async def lifespan(app: FastAPI): # -- Exception Handling Middlewares -- +# Handle AuthenticationError by redirecting to login page +@app.exception_handler(AuthenticationError) +async def authentication_error_handler(request: Request, exc: AuthenticationError): + return RedirectResponse( + url="/login", + status_code=status.HTTP_303_SEE_OTHER + ) + + # Handle NeedsNewTokens by setting new tokens and redirecting to same page @app.exception_handler(NeedsNewTokens) async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens): @@ -104,10 +113,6 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE # Handle StarletteHTTPException (including 404, 405, etc.) by rendering the error page @app.exception_handler(StarletteHTTPException) async def http_exception_handler(request: Request, exc: StarletteHTTPException): - # Don't handle redirects - if exc.status_code in [301, 302, 303, 307, 308]: - raise exc - return templates.TemplateResponse( request, "errors/error.html", diff --git a/routers/authentication.py b/routers/authentication.py index 0a1098b..d487575 100644 --- a/routers/authentication.py +++ b/routers/authentication.py @@ -114,7 +114,6 @@ class UserRead(BaseModel): organization_id: Optional[int] created_at: datetime updated_at: datetime - deleted: bool # -- Routes -- diff --git a/routers/organization.py b/routers/organization.py index 9d4f1de..ea48a94 100644 --- a/routers/organization.py +++ b/routers/organization.py @@ -27,7 +27,6 @@ class OrganizationRead(BaseModel): name: str created_at: datetime updated_at: datetime - deleted: bool class OrganizationUpdate(BaseModel): @@ -113,9 +112,7 @@ def delete_organization( if not db_org: raise HTTPException(status_code=404, detail="Organization not found") - db_org.deleted = True - db_org.updated_at = datetime.utcnow() - session.add(db_org) + session.delete(db_org) session.commit() return RedirectResponse(url="/organizations", status_code=303) diff --git a/routers/role.py b/routers/role.py index cb2488f..a6429c4 100644 --- a/routers/role.py +++ b/routers/role.py @@ -31,7 +31,6 @@ class RoleRead(BaseModel): name: str created_at: datetime updated_at: datetime - deleted: bool permissions: List[ValidPermissions] @@ -74,7 +73,7 @@ def create_role( @router.get("/{role_id}", response_model=RoleRead) def read_role(role_id: int, session: Session = Depends(get_session)): db_role: Role | None = session.get(Role, role_id) - if not db_role or not db_role.id or db_role.deleted: + if not db_role or not db_role.id: raise HTTPException(status_code=404, detail="Role not found") permissions = [ @@ -88,7 +87,6 @@ def read_role(role_id: int, session: Session = Depends(get_session)): name=db_role.name, created_at=db_role.created_at, updated_at=db_role.updated_at, - deleted=db_role.deleted, permissions=permissions ) @@ -99,7 +97,7 @@ def update_role( session: Session = Depends(get_session) ) -> RedirectResponse: db_role: Role | None = session.get(Role, role.id) - if not db_role or not db_role.id or db_role.deleted: + if not db_role or not db_role.id: raise HTTPException(status_code=404, detail="Role not found") role_data = role.model_dump(exclude_unset=True) for key, value in role_data.items(): @@ -131,8 +129,6 @@ def delete_role( db_role = session.get(Role, role_id) if not db_role: raise HTTPException(status_code=404, detail="Role not found") - db_role.deleted = True - db_role.updated_at = utc_time() - session.add(db_role) + session.delete(db_role) session.commit() return RedirectResponse(url="/roles", status_code=303) diff --git a/routers/user.py b/routers/user.py index df7d698..043509d 100644 --- a/routers/user.py +++ b/routers/user.py @@ -11,7 +11,8 @@ # -- Server Request and Response Models -- -class UserProfile(BaseModel): +class UpdateProfile(BaseModel): + """Request model for updating user profile information""" name: str email: EmailStr avatar_url: str @@ -40,26 +41,16 @@ async def as_form( # -- Routes -- -@router.get("/profile", response_class=RedirectResponse) -async def view_profile( - current_user: User = Depends(get_authenticated_user) -): - # Render the profile page with the current user's data - return {"user": current_user} - - -@router.post("/edit_profile", response_class=RedirectResponse) -async def edit_profile( - name: str = Form(...), - email: str = Form(...), - avatar_url: str = Form(...), +@router.post("/update_profile", response_class=RedirectResponse) +async def update_profile( + user_profile: UpdateProfile = Depends(UpdateProfile.as_form), current_user: User = Depends(get_authenticated_user), session: Session = Depends(get_session) ): # Update user details - current_user.name = name - current_user.email = email - current_user.avatar_url = avatar_url + current_user.name = user_profile.name + current_user.email = user_profile.email + current_user.avatar_url = user_profile.avatar_url session.commit() session.refresh(current_user) return RedirectResponse(url="/profile", status_code=303) @@ -67,14 +58,23 @@ async def edit_profile( @router.post("/delete_account", response_class=RedirectResponse) async def delete_account( - confirm_delete_password: str = Form(...), + user_delete_account: UserDeleteAccount = Depends( + UserDeleteAccount.as_form), current_user: User = Depends(get_authenticated_user), session: Session = Depends(get_session) ): - if not verify_password(confirm_delete_password, current_user.hashed_password): - raise HTTPException(status_code=400, detail="Password is incorrect") + if not verify_password( + user_delete_account.confirm_delete_password, + current_user.hashed_password + ): + raise HTTPException( + status_code=400, + detail="Password is incorrect" + ) - # Mark the user as deleted - current_user.deleted = True + # Delete the user + session.delete(current_user) session.commit() - return RedirectResponse(url="/", status_code=303) + + # Log out the user + return RedirectResponse(url="/auth/logout", status_code=303) diff --git a/templates/authentication/register.html b/templates/authentication/register.html index 69320a1..ceb8aac 100644 --- a/templates/authentication/register.html +++ b/templates/authentication/register.html @@ -25,7 +25,7 @@
diff --git a/templates/users/profile.html b/templates/users/profile.html index 5afa794..448c0d6 100644 --- a/templates/users/profile.html +++ b/templates/users/profile.html @@ -34,7 +34,7 @@

User Profile

Edit Profile
-
+
diff --git a/tests/conftest.py b/tests/conftest.py index da1eec2..f9b90ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ from fastapi.testclient import TestClient from utils.db import get_connection_url, set_up_db, tear_down_db, get_session from utils.models import User, PasswordResetToken -from utils.auth import get_password_hash +from utils.auth import get_password_hash, create_access_token, create_refresh_token from main import app load_dotenv() @@ -54,22 +54,6 @@ def clean_db(session: Session): session.commit() -# Test client fixture -@pytest.fixture() -def client(session: Session): - """ - Provides a TestClient instance with the session fixture. - Overrides the get_session dependency to use the test session. - """ - def get_session_override(): - return session - - app.dependency_overrides[get_session] = get_session_override - client = TestClient(app) - yield client - app.dependency_overrides.clear() - - # Test user fixture @pytest.fixture() def test_user(session: Session): @@ -85,3 +69,41 @@ def test_user(session: Session): session.commit() session.refresh(user) return user + + +# Unauthenticated client fixture +@pytest.fixture() +def unauth_client(session: Session): + """ + Provides a TestClient instance without authentication. + """ + def get_session_override(): + return session + + app.dependency_overrides[get_session] = get_session_override + client = TestClient(app) + yield client + app.dependency_overrides.clear() + + +# Authenticated client fixture +@pytest.fixture() +def auth_client(session: Session, test_user: User): + """ + Provides a TestClient instance with valid authentication tokens. + """ + def get_session_override(): + return session + + app.dependency_overrides[get_session] = get_session_override + client = TestClient(app) + + # Create and set valid tokens + access_token = create_access_token({"sub": test_user.email}) + refresh_token = create_refresh_token({"sub": test_user.email}) + + client.cookies.set("access_token", access_token) + client.cookies.set("refresh_token", refresh_token) + + yield client + app.dependency_overrides.clear() diff --git a/tests/test_authentication.py b/tests/test_authentication.py index ac0df9e..0ba2331 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -86,9 +86,9 @@ def test_invalid_token_type(): # --- API Endpoint Tests --- -def test_register_endpoint(client: TestClient, session: Session): - response = client.post( - "/auth/register", +def test_register_endpoint(unauth_client: TestClient, session: Session): + response = unauth_client.post( + app.url_path_for("register"), data={ "name": "New User", "email": "new@example.com", @@ -107,9 +107,9 @@ def test_register_endpoint(client: TestClient, session: Session): assert verify_password("NewPass123!@#", user.hashed_password) -def test_login_endpoint(client: TestClient, test_user: User): - response = client.post( - "/auth/login", +def test_login_endpoint(unauth_client: TestClient, test_user: User): + response = unauth_client.post( + app.url_path_for("login"), data={ "email": test_user.email, "password": "Test123!@#" @@ -124,18 +124,18 @@ def test_login_endpoint(client: TestClient, test_user: User): assert "refresh_token" in cookies -def test_refresh_token_endpoint(client: TestClient, test_user: User): - # Create expired access token and valid refresh token - access_token = create_access_token( +def test_refresh_token_endpoint(auth_client: TestClient, test_user: User): + # Override just the access token to be expired, keeping the valid refresh token + expired_access_token = create_access_token( {"sub": test_user.email}, timedelta(minutes=-10) ) - refresh_token = create_refresh_token({"sub": test_user.email}) + auth_client.cookies.set("access_token", expired_access_token) - client.cookies.set("access_token", access_token) - client.cookies.set("refresh_token", refresh_token) - - response = client.post("/auth/refresh", follow_redirects=False) + response = auth_client.post( + app.url_path_for("refresh_token"), + follow_redirects=False + ) assert response.status_code == 303 # Check for new tokens in headers @@ -155,10 +155,10 @@ def test_refresh_token_endpoint(client: TestClient, test_user: User): assert decoded["sub"] == test_user.email -def test_password_reset_flow(client: TestClient, session: Session, test_user: User, mock_resend_send): +def test_password_reset_flow(unauth_client: TestClient, session: Session, test_user: User, mock_resend_send): # Test forgot password request - response = client.post( - "/auth/forgot_password", + response = unauth_client.post( + app.url_path_for("forgot_password"), data={"email": test_user.email}, follow_redirects=False ) @@ -188,8 +188,8 @@ def test_password_reset_flow(client: TestClient, session: Session, test_user: Us assert not reset_token.used # Test password reset - response = client.post( - "/auth/reset_password", + response = unauth_client.post( + app.url_path_for("reset_password"), data={ "email": test_user.email, "token": reset_token.token, @@ -207,12 +207,11 @@ def test_password_reset_flow(client: TestClient, session: Session, test_user: Us assert reset_token.used -def test_logout_endpoint(client: TestClient): - # First set some cookies - client.cookies.set("access_token", "some_access_token") - client.cookies.set("refresh_token", "some_refresh_token") - - response = client.get("/auth/logout", follow_redirects=False) +def test_logout_endpoint(auth_client: TestClient): + response = auth_client.get( + app.url_path_for("logout"), + follow_redirects=False + ) assert response.status_code == 303 # Check for cookie deletion in headers @@ -226,9 +225,9 @@ def test_logout_endpoint(client: TestClient): # --- Error Case Tests --- -def test_register_with_existing_email(client: TestClient, test_user: User): - response = client.post( - "/auth/register", +def test_register_with_existing_email(unauth_client: TestClient, test_user: User): + response = unauth_client.post( + app.url_path_for("register"), data={ "name": "Another User", "email": test_user.email, @@ -239,9 +238,9 @@ def test_register_with_existing_email(client: TestClient, test_user: User): assert response.status_code == 400 -def test_login_with_invalid_credentials(client: TestClient, test_user: User): - response = client.post( - "/auth/login", +def test_login_with_invalid_credentials(unauth_client: TestClient, test_user: User): + response = unauth_client.post( + app.url_path_for("login"), data={ "email": test_user.email, "password": "WrongPass123!@#" @@ -250,9 +249,9 @@ def test_login_with_invalid_credentials(client: TestClient, test_user: User): assert response.status_code == 400 -def test_password_reset_with_invalid_token(client: TestClient, test_user: User): - response = client.post( - "/auth/reset_password", +def test_password_reset_with_invalid_token(unauth_client: TestClient, test_user: User): + response = unauth_client.post( + app.url_path_for("reset_password"), data={ "email": test_user.email, "token": "invalid_token", @@ -263,7 +262,7 @@ def test_password_reset_with_invalid_token(client: TestClient, test_user: User): assert response.status_code == 400 -def test_password_reset_url_generation(client: TestClient): +def test_password_reset_url_generation(unauth_client: TestClient): """ Tests that the password reset URL is correctly formatted and contains the required query parameters. @@ -290,12 +289,12 @@ def test_password_reset_url_generation(client: TestClient): assert query_params["token"][0] == test_token -def test_password_reset_email_url(client: TestClient, session: Session, test_user: User, mock_resend_send): +def test_password_reset_email_url(unauth_client: TestClient, session: Session, test_user: User, mock_resend_send): """ Tests that the password reset email contains a properly formatted reset URL. """ - response = client.post( - "/auth/forgot_password", + response = unauth_client.post( + app.url_path_for("forgot_password"), data={"email": test_user.email}, follow_redirects=False ) diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..a8dd554 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,21 @@ +from fastapi.testclient import TestClient + +from utils.models import User +from main import app + + +def test_read_profile_unauthorized(unauth_client: TestClient): + """Test that unauthorized users cannot view profile""" + response = unauth_client.get(app.url_path_for( + "read_profile"), follow_redirects=False) + assert response.status_code == 303 # Redirect to login + assert response.headers["location"] == app.url_path_for("read_login") + + +def test_read_profile_authorized(auth_client: TestClient, test_user: User): + """Test that authorized users can view their profile""" + response = auth_client.get(app.url_path_for("read_profile")) + assert response.status_code == 200 + # Check that the response contains the expected HTML content + assert test_user.email in response.text + assert test_user.name in response.text diff --git a/tests/test_user.py b/tests/test_user.py new file mode 100644 index 0000000..d6f42d0 --- /dev/null +++ b/tests/test_user.py @@ -0,0 +1,83 @@ +from fastapi.testclient import TestClient +from httpx import Response +from sqlmodel import Session + +from main import app +from utils.models import User + + +def test_update_profile_unauthorized(unauth_client: TestClient): + """Test that unauthorized users cannot edit profile""" + response: Response = unauth_client.post( + app.url_path_for("update_profile"), + data={ + "name": "New Name", + "email": "new@example.com", + "avatar_url": "https://example.com/avatar.jpg" + }, + follow_redirects=False + ) + assert response.status_code == 303 # Redirect to login + assert response.headers["location"] == app.url_path_for("read_login") + + +def test_update_profile_authorized(auth_client: TestClient, test_user: User, session: Session): + """Test that authorized users can edit their profile""" + + # Update profile + response: Response = auth_client.post( + app.url_path_for("update_profile"), + data={ + "name": "Updated Name", + "email": "updated@example.com", + "avatar_url": "https://example.com/new-avatar.jpg" + }, + follow_redirects=False + ) + assert response.status_code == 303 + assert response.headers["location"] == app.url_path_for("read_profile") + + # Verify changes in database + session.refresh(test_user) + assert test_user.name == "Updated Name" + assert test_user.email == "updated@example.com" + assert test_user.avatar_url == "https://example.com/new-avatar.jpg" + + +def test_delete_account_unauthorized(unauth_client: TestClient): + """Test that unauthorized users cannot delete account""" + response: Response = unauth_client.post( + app.url_path_for("delete_account"), + data={"confirm_delete_password": "Test123!@#"}, + follow_redirects=False + ) + assert response.status_code == 303 # Redirect to login + assert response.headers["location"] == app.url_path_for("read_login") + + +def test_delete_account_wrong_password(auth_client: TestClient, test_user: User): + """Test that account deletion fails with wrong password""" + response: Response = auth_client.post( + app.url_path_for("delete_account"), + data={"confirm_delete_password": "WrongPassword123!"}, + follow_redirects=False + ) + assert response.status_code == 400 + assert "Password is incorrect" in response.text.strip() + + +def test_delete_account_success(auth_client: TestClient, test_user: User, session: Session): + """Test successful account deletion""" + + # Delete account + response: Response = auth_client.post( + app.url_path_for("delete_account"), + data={"confirm_delete_password": "Test123!@#"}, + follow_redirects=False + ) + assert response.status_code == 303 + assert response.headers["location"] == app.url_path_for("logout") + + # Verify user is deleted from database + user = session.get(User, test_user.id) + assert user is None diff --git a/utils/auth.py b/utils/auth.py index 3bf7dac..5793c3e 100644 --- a/utils/auth.py +++ b/utils/auth.py @@ -75,7 +75,7 @@ def validate_password_strength(v: str) -> str: """ logger.debug(f"Validating password for {field_name}") pattern = re.compile( - r"(?=.*\d)(?=.*[a-z])(?=.*[A-Z])(?=.*[@$!%*?&{}<>.,\\'#\-_=+\(\)\[\]:;|~])[A-Za-z\d@$!%*?&{}<>.,\\'#\-_=+\(\)\[\]:;|~]{8,}") + r"(?=.*\d)(?=.*[a-z])(?=.*[A-Z])(?=.*[@$!%*?&{}<>.,\\'#\-_=+\(\)\[\]:;|~/])[A-Za-z\d@$!%*?&{}<>.,\\'#\-_=+\(\)\[\]:;|~/]{8,}") if not pattern.match(v): logger.debug(f"Password for { field_name} does not satisfy the security policy") @@ -180,7 +180,8 @@ def validate_token_and_get_user( if decoded_token: user_email = decoded_token.get("sub") user = session.exec(select(User).where( - User.email == user_email)).first() + User.email == user_email + )).first() if user: if token_type == "refresh": new_access_token = create_access_token( @@ -215,6 +216,14 @@ def get_user_from_tokens( return None, None, None +class AuthenticationError(HTTPException): + def __init__(self): + super().__init__( + status_code=status.HTTP_303_SEE_OTHER, + headers={"Location": "/login"} + ) + + def get_authenticated_user( tokens: tuple[Optional[str], Optional[str] ] = Depends(oauth2_scheme_cookie), @@ -228,11 +237,7 @@ def get_authenticated_user( raise NeedsNewTokens(user, new_access_token, new_refresh_token) return user - # If both tokens are invalid or missing, redirect to login - raise HTTPException( - status_code=status.HTTP_307_TEMPORARY_REDIRECT, - headers={"Location": "/login"} - ) + raise AuthenticationError() def get_optional_user( @@ -275,7 +280,9 @@ def generate_password_reset_url(email: str, token: str) -> str: def send_reset_email(email: str, session: Session): # Check for an existing unexpired token - user = session.exec(select(User).where(User.email == email)).first() + user = session.exec(select(User).where( + User.email == email + )).first() if user: existing_token = session.exec( select(PasswordResetToken) @@ -316,18 +323,19 @@ def send_reset_email(email: str, session: Session): def get_user_from_reset_token(email: str, token: str, session: Session) -> tuple[Optional[User], Optional[PasswordResetToken]]: - reset_token = session.exec(select(PasswordResetToken).where( - PasswordResetToken.token == token, - PasswordResetToken.expires_at > datetime.now(UTC), - PasswordResetToken.used == False - )).first() + result = session.exec( + select(User, PasswordResetToken) + .where( + User.email == email, + PasswordResetToken.token == token, + PasswordResetToken.expires_at > datetime.now(UTC), + PasswordResetToken.used == False, + PasswordResetToken.user_id == User.id + ) + ).first() - if not reset_token: + if not result: return None, None - user = session.exec(select(User).where( - User.email == email, - User.id == reset_token.user_id - )).first() - + user, reset_token = result return user, reset_token diff --git a/utils/models.py b/utils/models.py index 5941f3f..cd43f34 100644 --- a/utils/models.py +++ b/utils/models.py @@ -29,7 +29,6 @@ class Organization(SQLModel, table=True): name: str created_at: datetime = Field(default_factory=utc_time) updated_at: datetime = Field(default_factory=utc_time) - deleted: bool = Field(default=False) users: List["User"] = Relationship(back_populates="organization") @@ -41,7 +40,6 @@ class Role(SQLModel, table=True): default=None, foreign_key="organization.id") created_at: datetime = Field(default_factory=utc_time) updated_at: datetime = Field(default_factory=utc_time) - deleted: bool = Field(default=False) users: List["User"] = Relationship(back_populates="role") role_permission_links: List["RolePermissionLink"] = Relationship( @@ -54,7 +52,6 @@ class Permission(SQLModel, table=True): sa_column=Column(SQLAlchemyEnum(ValidPermissions, create_type=False))) created_at: datetime = Field(default_factory=utc_time) updated_at: datetime = Field(default_factory=utc_time) - deleted: bool = Field(default=False) role_permission_links: List["RolePermissionLink"] = Relationship( back_populates="permission") @@ -97,13 +94,14 @@ class User(SQLModel, table=True): role_id: Optional[int] = Field(default=None, foreign_key="role.id") created_at: datetime = Field(default_factory=utc_time) updated_at: datetime = Field(default_factory=utc_time) - deleted: bool = Field(default=False) organization: Optional["Organization"] = Relationship( back_populates="users") role: Optional["Role"] = Relationship(back_populates="users") password_reset_tokens: List["PasswordResetToken"] = Relationship( - back_populates="user") + back_populates="user", + sa_relationship_kwargs={"cascade": "all, delete-orphan"} + ) class UserOrganizationLink(SQLModel, table=True):