diff --git a/.gitignore b/.gitignore index b155443..3335ab9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ __pycache__ /.quarto/ _docs/ .pytest_cache/ -.mypy_cache/ \ No newline at end of file +.mypy_cache/ +.cursorrules \ No newline at end of file diff --git a/docs/customization.qmd b/docs/customization.qmd index 6bc663d..e43169e 100644 --- a/docs/customization.qmd +++ b/docs/customization.qmd @@ -41,7 +41,7 @@ To run the tests, use these commands: The project uses type annotations and mypy for static type checking. To run mypy, use this command from the root directory: ```bash -mypy +mypy . ``` We find that mypy is an enormous time-saver, catching many errors early and greatly reducing time spent debugging unit tests. However, note that mypy requires you type annotate every variable, function, and method in your code base, so taking advantage of it requires a lifestyle change! @@ -274,9 +274,9 @@ graph.write_png('static/schema.png') ![Database Schema](static/schema.png) -#### Database operations +#### Database helpers -Database operations are handled by helper functions in `utils/db.py`. Key functions include: +Database operations are facilitated by helper functions in `utils/db.py`. Key functions include: - `set_up_db()`: Initializes the database schema and default data (which we do on every application start in `main.py`) - `get_connection_url()`: Creates a database connection URL from environment variables in `.env` @@ -292,3 +292,30 @@ async def get_users(session: Session = Depends(get_session)): ``` The session automatically handles transaction management, ensuring that database operations are atomic and consistent. + +#### Cascade deletes + +Cascade deletes (in which deleting a record from one table deletes related records from another table) can be handled at either the ORM level or the database level. This template handles cascade deletes at the ORM level, via SQLModel relationships. Inside a SQLModel `Relationship`, we set: + +```python +sa_relationship_kwargs={ + "cascade": "all, delete-orphan" +} +``` + +This tells SQLAlchemy to cascade all operations (e.g., `SELECT`, `INSERT`, `UPDATE`, `DELETE`) to the related table. Since this happens through the ORM, we need to be careful to do all our database operations through the ORM using supported syntax. That generally means loading database records into Python objects and then deleting those objects rather than deleting records in the database directly. + +For example, + +```python +session.exec(delete(Role)) +``` + +will not trigger the cascade delete. Instead, we need to select the role objects and then delete them: + +```python +for role in session.exec(select(Role)).all(): + session.delete(role) +``` + +This is slower than deleting the records directly, but it makes [many-to-many relationships](https://sqlmodel.tiangolo.com/tutorial/many-to-many/create-models-with-link/#create-the-tables) much easier to manage. diff --git a/docs/static/llms.txt b/docs/static/llms.txt index 8f5edce..0635d39 100644 --- a/docs/static/llms.txt +++ b/docs/static/llms.txt @@ -105,6 +105,8 @@ To use password recovery, register a [Resend](https://resend.com/) account, veri ### Start development database +To start the development database, run the following command in your terminal from the root directory: + ``` bash docker compose up -d ``` @@ -515,6 +517,8 @@ If you use VSCode with Docker to develop in a container, the following VSCode De Simply create a `.devcontainer` folder in the root of the project and add a `devcontainer.json` file in the folder with the above content. VSCode may prompt you to install the Dev Container extension if you haven't already, and/or to open the project in a container. If not, you can manually select "Dev Containers: Reopen in Container" from View > Command Palette. +*IMPORTANT: If using this dev container configuration, you will need to set the `DB_HOST` environment variable to "host.docker.internal" in the `.env` file.* + ## Install development dependencies manually ### Python and Docker @@ -598,15 +602,32 @@ Set your desired database name, username, and password in the .env file. To use password recovery, register a [Resend](https://resend.com/) account, verify a domain, get an API key, and paste the API key into the .env file. +If using the dev container configuration, you will need to set the `DB_HOST` environment variable to "host.docker.internal" in the .env file. Otherwise, set `DB_HOST` to "localhost" for local development. (In production, `DB_HOST` will be set to the hostname of the database server.) + ## Start development database +To start the development database, run the following command in your terminal from the root directory: + ``` bash docker compose up -d ``` +If at any point you change the environment variables in the .env file, you will need to stop the database service *and tear down the volume*: + +``` bash +# Don't forget the -v flag to tear down the volume! +docker compose down -v +``` + +You may also need to restart the terminal session to pick up the new environment variables. You can also add the `--force-recreate` and `--build` flags to the startup command to ensure the container is rebuilt: + +``` bash +docker compose up -d --force-recreate --build +``` + ## Run the development server -Make sure the development database is running and tables and default permissions/roles are created first. +Before running the development server, make sure the development database is running and tables and default permissions/roles are created first. Then run the following command in your terminal from the root directory: ``` bash uvicorn main:app --host 0.0.0.0 --port 8000 --reload @@ -646,7 +667,8 @@ The following fixtures, defined in `tests/conftest.py`, are available in the tes - `set_up_database`: Sets up the test database before running the test suite by dropping all tables and recreating them to ensure a clean state. - `session`: Provides a session for database operations in tests. - `clean_db`: Cleans up the database tables before each test by deleting all entries in the `PasswordResetToken` and `User` tables. -- `client`: Provides a `TestClient` instance with the session fixture, overriding the `get_session` dependency to use the test session. +- `auth_client`: Provides a `TestClient` instance with access and refresh token cookies set, overriding the `get_session` dependency to use the `session` fixture. +- `unauth_client`: Provides a `TestClient` instance without authentication cookies set, overriding the `get_session` dependency to use the `session` fixture. - `test_user`: Creates a test user in the database with a predefined name, email, and hashed password. To run the tests, use these commands: @@ -661,10 +683,10 @@ To run the tests, use these commands: The project uses type annotations and mypy for static type checking. To run mypy, use this command from the root directory: ```bash -mypy +mypy . ``` -We find that mypy is an enormous time-saver, catching many errors early and greatly reducing time spent debugging unit tests. However, note that mypy requires you type annotate every variable, function, and method in your code base, so taking advantage of it is a lifestyle change! +We find that mypy is an enormous time-saver, catching many errors early and greatly reducing time spent debugging unit tests. However, note that mypy requires you type annotate every variable, function, and method in your code base, so taking advantage of it requires a lifestyle change! ### Developing with LLMs @@ -705,7 +727,9 @@ We also create POST endpoints, which accept form submissions so the user can cre #### Routing patterns in this template -In this template, GET routes are defined in the main entry point for the application, `main.py`. POST routes are organized into separate modules within the `routers/` directory. We name our GET routes using the convention `read_`, where `` is the name of the page, to indicate that they are read-only endpoints that do not modify the database. +In this template, GET routes are defined in the main entry point for the application, `main.py`. POST routes are organized into separate modules within the `routers/` directory. + +We name our GET routes using the convention `read_`, where `` is the name of the page, to indicate that they are read-only endpoints that do not modify the database. We divide our GET routes into authenticated and unauthenticated routes, using commented section headers in our code that look like this: @@ -713,7 +737,9 @@ We divide our GET routes into authenticated and unauthenticated routes, using co # -- Authenticated Routes -- ``` -Some of our routes take request parameters, which we pass as keyword arguments to the route handler. These parameters should be type annotated for validation purposes. Some parameters are shared across all authenticated or unauthenticated routes, so we define them in the `common_authenticated_parameters` and `common_unauthenticated_parameters` dependencies defined in `main.py`. +Some of our routes take request parameters, which we pass as keyword arguments to the route handler. These parameters should be type annotated for validation purposes. + +Some parameters are shared across all authenticated or unauthenticated routes, so we define them in the `common_authenticated_parameters` and `common_unauthenticated_parameters` dependencies defined in `main.py`. ### HTML templating with Jinja2 @@ -734,7 +760,7 @@ async def welcome(request: Request): ) ``` -In this example, the `welcome.html` template will receive two pieces of context: the user's `request`, which is always passed automatically by FastAPI, and a `username` variable, which we specify as "Alice". We can then use the `{{ username }}` syntax in the `welcome.html` template (or any of its parent or child templates) to insert the value into the HTML. +In this example, the `welcome.html` template will receive two pieces of context: the user's `request`, which is always passed automatically by FastAPI, and a `username` variable, which we specify as "Alice". We can then use the `{{{ username }}}` syntax in the `welcome.html` template (or any of its parent or child templates) to insert the value into the HTML. #### Form validation strategy @@ -890,9 +916,9 @@ graph.write_png('static/schema.png') ![Database Schema](static/schema.png) -#### Database operations +#### Database helpers -Database operations are handled by helper functions in `utils/db.py`. Key functions include: +Database operations are facilitated by helper functions in `utils/db.py`. Key functions include: - `set_up_db()`: Initializes the database schema and default data (which we do on every application start in `main.py`) - `get_connection_url()`: Creates a database connection URL from environment variables in `.env` @@ -909,6 +935,33 @@ async def get_users(session: Session = Depends(get_session)): The session automatically handles transaction management, ensuring that database operations are atomic and consistent. +#### Cascade deletes + +Cascade deletes (in which deleting a record from one table deletes related records from another table) can be handled at either the ORM level or the database level. This template handles cascade deletes at the ORM level, via SQLModel relationships. Inside a SQLModel `Relationship`, we set: + +```python +sa_relationship_kwargs={ + "cascade": "all, delete-orphan" +} +``` + +This tells SQLAlchemy to cascade all operations (e.g., `SELECT`, `INSERT`, `UPDATE`, `DELETE`) to the related table. Since this happens through the ORM, we need to be careful to do all our database operations through the ORM using supported syntax. That generally means loading database records into Python objects and then deleting those objects rather than deleting records in the database directly. + +For example, + +```python +session.exec(delete(Role)) +``` + +will not trigger the cascade delete. Instead, we need to select the role objects and then delete them: + +```python +for role in session.exec(select(Role)).all(): + session.delete(role) +``` + +This is slower than deleting the records directly, but it makes [many-to-many relationships](https://sqlmodel.tiangolo.com/tutorial/many-to-many/create-models-with-link/#create-the-tables) much easier to manage. + # Deployment diff --git a/docs/static/schema.png b/docs/static/schema.png index 0656c09..ae8fbf9 100644 Binary files a/docs/static/schema.png and b/docs/static/schema.png differ diff --git a/main.py b/main.py index 5642f6a..82e1cc4 100644 --- a/main.py +++ b/main.py @@ -8,11 +8,10 @@ from fastapi.exceptions import RequestValidationError, HTTPException, StarletteHTTPException from sqlmodel import Session from routers import authentication, organization, role, user -from utils.auth import get_authenticated_user, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError, AuthenticationError -from utils.models import User +from utils.auth import get_user_with_relations, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError, AuthenticationError +from utils.models import User, Organization from utils.db import get_session, set_up_db - logger = logging.getLogger("uvicorn.error") logger.setLevel(logging.DEBUG) @@ -20,7 +19,7 @@ @asynccontextmanager async def lifespan(app: FastAPI): # Optional startup logic - set_up_db(drop=False) + set_up_db() yield # Optional shutdown logic @@ -229,8 +228,8 @@ async def read_reset_password( # Define a dependency for common parameters async def common_authenticated_parameters( request: Request, - user: User = Depends(get_authenticated_user), - error_message: Optional[str] = None, + user: User = Depends(get_user_with_relations), + error_message: Optional[str] = None ) -> dict: return {"request": request, "user": user, "error_message": error_message} @@ -240,8 +239,6 @@ async def common_authenticated_parameters( async def read_dashboard( params: dict = Depends(common_authenticated_parameters) ): - if not params["user"]: - return RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND) return templates.TemplateResponse(params["request"], "dashboard/index.html", params) @@ -249,12 +246,27 @@ async def read_dashboard( async def read_profile( params: dict = Depends(common_authenticated_parameters) ): - if not params["user"]: - # Changed to 302 - return RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND) return templates.TemplateResponse(params["request"], "users/profile.html", params) +@app.get("/organizations/{org_id}") +async def read_organization( + org_id: int, + params: dict = Depends(common_authenticated_parameters) +): + # Get the organization only if the user is a member of it + org: Organization = params["user"].organizations.get(org_id) + if not org: + raise organization.OrganizationNotFoundError() + + # Eagerly load roles and users + org.roles + org.users + params["organization"] = org + + return templates.TemplateResponse(params["request"], "users/organization.html", params) + + # -- Include Routers -- diff --git a/routers/authentication.py b/routers/authentication.py index d487575..0832954 100644 --- a/routers/authentication.py +++ b/routers/authentication.py @@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse from pydantic import BaseModel, EmailStr, ConfigDict from sqlmodel import Session, select -from utils.models import User +from utils.models import User, UserPassword from utils.auth import ( get_session, get_user_from_reset_token, @@ -119,20 +119,25 @@ class UserRead(BaseModel): # -- Routes -- +# TODO: Use custom error message in the case where the user is already registered @router.post("/register", response_class=RedirectResponse) async def register( user: UserRegister = Depends(UserRegister.as_form), session: Session = Depends(get_session), ) -> RedirectResponse: + # Check if the email is already registered db_user = session.exec(select(User).where( User.email == user.email)).first() if db_user: raise HTTPException(status_code=400, detail="Email already registered") + # Hash the password hashed_password = get_password_hash(user.password) + + # Create the user db_user = User(name=user.name, email=user.email, - hashed_password=hashed_password) + password=UserPassword(hashed_password=hashed_password)) session.add(db_user) session.commit() session.refresh(db_user) @@ -154,9 +159,11 @@ async def login( user: UserLogin = Depends(UserLogin.as_form), session: Session = Depends(get_session), ) -> RedirectResponse: + # Check if the email is registered db_user = session.exec(select(User).where( User.email == user.email)).first() - if not db_user or not verify_password(user.password, db_user.hashed_password): + + if not db_user or not db_user.password or not verify_password(user.password, db_user.password.hashed_password): raise HTTPException(status_code=400, detail="Invalid credentials") # Create access token @@ -258,7 +265,17 @@ async def reset_password( raise HTTPException(status_code=400, detail="Invalid or expired token") # Update password and mark token as used - authorized_user.hashed_password = get_password_hash(user.new_password) + if authorized_user.password: + authorized_user.password.hashed_password = get_password_hash( + user.new_password + ) + else: + logger.warning( + "User password not found during password reset; creating new password for user") + authorized_user.password = UserPassword( + hashed_password=get_password_hash(user.new_password) + ) + reset_token.used = True session.commit() session.refresh(authorized_user) diff --git a/routers/organization.py b/routers/organization.py index ea48a94..daa8504 100644 --- a/routers/organization.py +++ b/routers/organization.py @@ -1,20 +1,58 @@ from logging import getLogger from fastapi import APIRouter, Depends, HTTPException, Form from fastapi.responses import RedirectResponse -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, field_validator from sqlmodel import Session, select from utils.db import get_session -from utils.models import Organization +from utils.auth import get_authenticated_user, get_user_with_relations, InsufficientPermissionsError +from utils.models import Organization, User, Role, utc_time, default_roles, ValidPermissions from datetime import datetime logger = getLogger("uvicorn.error") +# -- Custom Exceptions -- + + +class EmptyOrganizationNameError(HTTPException): + def __init__(self): + super().__init__( + status_code=400, + detail="Organization name cannot be empty" + ) + + +class OrganizationNotFoundError(HTTPException): + def __init__(self): + super().__init__( + status_code=404, + detail="Organization not found" + ) + + +class OrganizationNameTakenError(HTTPException): + def __init__(self): + super().__init__( + status_code=400, + detail="Organization name already taken" + ) + + router = APIRouter(prefix="/organizations", tags=["organizations"]) +# -- Server Request and Response Models -- + + class OrganizationCreate(BaseModel): name: str + @field_validator('name') + @classmethod + def validate_name(cls, name: str) -> str: + if not name.strip(): + raise EmptyOrganizationNameError() + return name.strip() + @classmethod async def as_form(cls, name: str = Form(...)): return cls(name=name) @@ -33,56 +71,72 @@ class OrganizationUpdate(BaseModel): id: int name: str + @field_validator('name') + @classmethod + def validate_name(cls, name: str) -> str: + if not name.strip(): + raise EmptyOrganizationNameError() + return name.strip() + @classmethod async def as_form(cls, id: int = Form(...), name: str = Form(...)): return cls(id=id, name=name) -@router.post("/", response_class=RedirectResponse) +# -- Routes -- + +@router.post("/create", response_class=RedirectResponse) def create_organization( org: OrganizationCreate = Depends(OrganizationCreate.as_form), + user: User = Depends(get_authenticated_user), session: Session = Depends(get_session) ) -> RedirectResponse: - # Validate organization name is not empty - if not org.name.strip(): - raise HTTPException( - status_code=400, detail="Organization name cannot be empty") - + # Check if organization already exists db_org = session.exec(select(Organization).where( Organization.name == org.name)).first() if db_org: - raise HTTPException( - status_code=400, detail="Organization already exists") + raise OrganizationNameTakenError() + # Create organization first db_org = Organization(name=org.name) session.add(db_org) + # This gets us the org ID without committing + session.flush() + + # Create default roles with organization_id + initial_roles = [ + Role(name=name, organization_id=db_org.id) + for name in default_roles + ] + session.add_all(initial_roles) + session.flush() + + # Get owner role for user assignment + owner_role = next(role for role in db_org.roles if role.name == "Owner") + + # Assign user to owner role + user.roles.append(owner_role) + + # Commit changes session.commit() session.refresh(db_org) return RedirectResponse(url=f"/organizations/{db_org.id}", status_code=303) -@router.get("/{org_id}", response_model=OrganizationRead) -def read_organization(org_id: int, session: Session = Depends(get_session)): - db_org = session.get(Organization, org_id) - if not db_org: - raise HTTPException(status_code=404, detail="Organization not found") - return db_org - - -@router.put("/{org_id}", response_class=RedirectResponse) +@router.post("/update/{org_id}", name="update_organization", response_class=RedirectResponse) def update_organization( org: OrganizationUpdate = Depends(OrganizationUpdate.as_form), + user: User = Depends(get_user_with_relations), session: Session = Depends(get_session) ) -> RedirectResponse: - # Validate organization name is not empty - if not org.name.strip(): - raise HTTPException( - status_code=400, detail="Organization name cannot be empty") + # This will raise appropriate exceptions if org doesn't exist or user lacks access + organization: Organization | None = next( + (org for org in user.organizations if org.id == org.id), None) - db_org = session.get(Organization, org.id) - if not db_org: - raise HTTPException(status_code=404, detail="Organization not found") + # Check if user has permission to edit organization + if not organization or not user.has_permission(ValidPermissions.EDIT_ORGANIZATION, organization): + raise InsufficientPermissionsError() # Check if new name already exists for another organization existing_org = session.exec( @@ -91,28 +145,35 @@ def update_organization( .where(Organization.id != org.id) ).first() if existing_org: - raise HTTPException( - status_code=400, detail="Organization name already taken") + raise OrganizationNameTakenError() - db_org.name = org.name - db_org.updated_at = datetime.utcnow() - session.add(db_org) + # Update organization name + organization.name = org.name + organization.updated_at = utc_time() + session.add(organization) session.commit() - session.refresh(db_org) - return RedirectResponse(url=f"/organizations/{org.id}", status_code=303) + return RedirectResponse(url=f"/profile", status_code=303) -@router.delete("/{org_id}", response_class=RedirectResponse) +@router.post("/delete/{org_id}", response_class=RedirectResponse) def delete_organization( org_id: int, + user: User = Depends(get_user_with_relations), session: Session = Depends(get_session) ) -> RedirectResponse: - db_org = session.get(Organization, org_id) - if not db_org: - raise HTTPException(status_code=404, detail="Organization not found") - - session.delete(db_org) + # Check if user has permission to delete organization + organization: Organization | None = next( + (org for org in user.organizations if org.id == org_id), None) + if not organization or not any( + p.name == ValidPermissions.DELETE_ORGANIZATION + for role in organization.roles + for p in role.permissions + ): + raise InsufficientPermissionsError() + + # Delete organization + session.delete(organization) session.commit() - return RedirectResponse(url="/organizations", status_code=303) + return RedirectResponse(url="/profile", status_code=303) diff --git a/routers/role.py b/routers/role.py index a6429c4..1a89f2d 100644 --- a/routers/role.py +++ b/routers/role.py @@ -1,134 +1,258 @@ -from typing import List -from datetime import datetime +# TODO: User with permission to create/edit roles can only assign permissions +# they themselves have. +from typing import List, Sequence, Optional from logging import getLogger -from fastapi import APIRouter, Depends, HTTPException, Form +from fastapi import APIRouter, Depends, Form, HTTPException from fastapi.responses import RedirectResponse -from pydantic import BaseModel, ConfigDict -from sqlmodel import Session, select +from pydantic import BaseModel, ConfigDict, field_validator +from sqlmodel import Session, select, col +from sqlalchemy.orm import selectinload from utils.db import get_session -from utils.models import Role, RolePermissionLink, ValidPermissions, utc_time +from utils.auth import get_authenticated_user, InsufficientPermissionsError +from utils.models import Role, Permission, ValidPermissions, utc_time, User, DataIntegrityError logger = getLogger("uvicorn.error") router = APIRouter(prefix="/roles", tags=["roles"]) +# -- Custom Exceptions -- + + +class InvalidPermissionError(HTTPException): + """Raised when a user attempts to assign an invalid permission to a role""" + + def __init__(self, permission: ValidPermissions): + super().__init__( + status_code=400, + detail=f"Invalid permission: {permission}" + ) + + +class RoleAlreadyExistsError(HTTPException): + """Raised when attempting to create a role with a name that already exists""" + + def __init__(self): + super().__init__(status_code=400, detail="Role already exists") + + +class RoleNotFoundError(HTTPException): + """Raised when a requested role does not exist""" + + def __init__(self): + super().__init__(status_code=404, detail="Role not found") + + +class RoleHasUsersError(HTTPException): + """Raised when a requested role to be deleted has users""" + + def __init__(self): + super().__init__( + status_code=400, + detail="Role cannot be deleted until users with that role are reassigned" + ) + + +# -- Server Request Models -- + class RoleCreate(BaseModel): model_config = ConfigDict(from_attributes=True) name: str + organization_id: int permissions: List[ValidPermissions] @classmethod - async def as_form(cls, name: str = Form(...), permissions: List[ValidPermissions] = Form(...)): - return cls(name=name, permissions=permissions) + async def as_form( + cls, + name: str = Form(...), + organization_id: int = Form(...), + permissions: List[ValidPermissions] = Form(...) + ): + # Pass session to validator context + return cls( + name=name, + organization_id=organization_id, + permissions=permissions + ) -class RoleRead(BaseModel): +class RoleUpdate(BaseModel): model_config = ConfigDict(from_attributes=True) id: int name: str - created_at: datetime - updated_at: datetime + organization_id: int permissions: List[ValidPermissions] + @field_validator("id") + @classmethod + def validate_role_exists(cls, id: int, info): + session = info.context.get("session") + if session: + role = session.get(Role, id) + if not role or not role.id: + raise RoleNotFoundError() + return id -class RoleUpdate(BaseModel): + @classmethod + async def as_form( + cls, + id: int = Form(...), + name: str = Form(...), + organization_id: int = Form(...), + permissions: List[ValidPermissions] = Form(...) + ): + return cls( + id=id, + name=name, + organization_id=organization_id, + permissions=permissions + ) + + +class RoleDelete(BaseModel): model_config = ConfigDict(from_attributes=True) id: int - name: str - permissions: List[ValidPermissions] + organization_id: int @classmethod - async def as_form(cls, id: int = Form(...), name: str = Form(...), permissions: List[ValidPermissions] = Form(...)): - return cls(id=id, name=name, permissions=permissions) + async def as_form( + cls, + id: int = Form(...), + organization_id: int = Form(...) + ): + return cls(id=id, organization_id=organization_id) -@router.post("/", response_class=RedirectResponse) +# -- Routes -- + + +@router.post("/create", response_class=RedirectResponse) def create_role( role: RoleCreate = Depends(RoleCreate.as_form), + user: User = Depends(get_authenticated_user), session: Session = Depends(get_session) ) -> RedirectResponse: - db_role = session.exec(select(Role).where(Role.name == role.name)).first() - if db_role: - raise HTTPException(status_code=400, detail="Role already exists") - - # Create role and permissions in a single transaction - db_role = Role(name=role.name) + # Check that the user-selected role name is unique for the organization + if session.exec( + select(Role).where( + Role.name == role.name, + Role.organization_id == role.organization_id + ) + ).first(): + raise RoleAlreadyExistsError() - # Create RolePermissionLink objects and associate them with the role - db_role.permissions = [ - RolePermissionLink(permission_id=permission.name) - for permission in role.permissions - ] + # Check that the user is authorized to create roles in the organization + if not user.has_permission(ValidPermissions.CREATE_ROLE, role.organization_id): + raise InsufficientPermissionsError() + # Create role + db_role = Role( + name=role.name, + organization_id=role.organization_id + ) session.add(db_role) - session.commit() # Commit once after all operations - - return RedirectResponse(url="/roles", status_code=303) - -@router.get("/{role_id}", response_model=RoleRead) -def read_role(role_id: int, session: Session = Depends(get_session)): - db_role: Role | None = session.get(Role, role_id) - if not db_role or not db_role.id: - raise HTTPException(status_code=404, detail="Role not found") + # Select Permission records corresponding to the user-selected permissions + # and associate them with the newly created role + permissions: Sequence[Permission] = session.exec( + select(Permission).where(col(Permission.name).in_(role.permissions)) + ).all() + db_role.permissions.extend(permissions) - permissions = [ - ValidPermissions(link.permission.name) - for link in db_role.role_permission_links - if link.permission is not None - ] + # Commit transaction + session.commit() - return RoleRead( - id=db_role.id, - name=db_role.name, - created_at=db_role.created_at, - updated_at=db_role.updated_at, - permissions=permissions - ) + return RedirectResponse(url="/profile", status_code=303) -@router.put("/{role_id}", response_class=RedirectResponse) +@router.post("/update", response_class=RedirectResponse) def update_role( role: RoleUpdate = Depends(RoleUpdate.as_form), + user: User = Depends(get_authenticated_user), session: Session = Depends(get_session) ) -> RedirectResponse: - db_role: Role | None = session.get(Role, role.id) - if not db_role or not db_role.id: - raise HTTPException(status_code=404, detail="Role not found") - role_data = role.model_dump(exclude_unset=True) - for key, value in role_data.items(): - setattr(db_role, key, value) - db_role.updated_at = utc_time() - session.add(db_role) - session.commit() + # Check that the user is authorized to update the role + if not user.has_permission(ValidPermissions.EDIT_ROLE, role.organization_id): + raise InsufficientPermissionsError() + + # Select db_role to update, along with its permissions, by ID + db_role: Optional[Role] = session.exec( + select(Role).where(Role.id == role.id).options( + selectinload(Role.permissions)) + ).first() - # Correctly delete RolePermissionLinks for the role - session.delete(RolePermissionLink.role_id == role.id) + if not db_role: + raise RoleNotFoundError() + + # If any user-selected permissions are not valid, raise an error + for permission in role.permissions: + if permission not in ValidPermissions: + raise InvalidPermissionError(permission) + # Add any user-selected permissions that are not already associated with the role for permission in role.permissions: - db_role_permission_link = RolePermissionLink( - role_id=db_role.id, - permission_id=permission.name + if permission not in [p.name for p in db_role.permissions]: + db_permission: Optional[Permission] = session.exec( + select(Permission).where(Permission.name == permission) + ).first() + if db_permission: + db_role.permissions.append(db_permission) + else: + raise DataIntegrityError(resource=f"Permission: {permission}") + + # Remove any permissions that are not user-selected + for db_permission in db_role.permissions: + if db_permission.name not in role.permissions: + db_role.permissions.remove(db_permission) + + # Check that no existing organization role has the same name but a different ID + if session.exec( + select(Role).where( + Role.name == role.name, + Role.organization_id == role.organization_id, + Role.id != role.id ) - session.add(db_role_permission_link) + ).first(): + raise RoleAlreadyExistsError() + + # Update role name and updated_at timestamp + db_role.name = role.name + db_role.updated_at = utc_time() session.commit() session.refresh(db_role) - return RedirectResponse(url=f"/roles/{role.id}", status_code=303) + return RedirectResponse(url="/profile", status_code=303) -@router.delete("/{role_id}", response_class=RedirectResponse) +@router.post("/delete", response_class=RedirectResponse) def delete_role( - role_id: int, + role: RoleDelete = Depends(RoleDelete.as_form), + user: User = Depends(get_authenticated_user), session: Session = Depends(get_session) ) -> RedirectResponse: - db_role = session.get(Role, role_id) + # Check that the user is authorized to delete the role + if not user.has_permission(ValidPermissions.DELETE_ROLE, role.organization_id): + raise InsufficientPermissionsError() + + # Select the role to delete by ID, along with its users + db_role: Role | None = session.exec( + select(Role).where(Role.id == role.id).options( + selectinload(Role.users) + ) + ).first() + if not db_role: - raise HTTPException(status_code=404, detail="Role not found") + raise RoleNotFoundError() + + # Check that no users have the role + if db_role.users: + raise RoleHasUsersError() + + # Delete the role session.delete(db_role) session.commit() - return RedirectResponse(url="/roles", status_code=303) + + return RedirectResponse(url="/profile", status_code=303) diff --git a/routers/user.py b/routers/user.py index 043509d..5639d79 100644 --- a/routers/user.py +++ b/routers/user.py @@ -63,9 +63,15 @@ async def delete_account( current_user: User = Depends(get_authenticated_user), session: Session = Depends(get_session) ): + if not current_user.password: + raise HTTPException( + status_code=500, + detail="User password not found in database; please contact a system administrator" + ) + if not verify_password( user_delete_account.confirm_delete_password, - current_user.hashed_password + current_user.password.hashed_password ): raise HTTPException( status_code=400, diff --git a/templates/authentication/register.html b/templates/authentication/register.html index ceb8aac..e508fd1 100644 --- a/templates/authentication/register.html +++ b/templates/authentication/register.html @@ -24,6 +24,7 @@
+ +
+ Organizations + +
+
+ +
+
+
+ + +
+ +
+
+ + + {% if organizations %} +
+ {% for org in organizations %} + +
+
{{ org.name }}
+ Joined {{ org.created_at.strftime('%Y-%m-%d') }} +
+
+ {% endfor %} +
+ {% else %} +

You are not a member of any organizations.

+ {% endif %} +
+
+{% endmacro %} diff --git a/templates/users/organization.html b/templates/users/organization.html new file mode 100644 index 0000000..2e75fa9 --- /dev/null +++ b/templates/users/organization.html @@ -0,0 +1,90 @@ +{% extends "base.html" %} +{% from 'components/silhouette.html' import render_silhouette %} + +{% block title %}{{ organization.name }}{% endblock %} + +{% block content %} +
+

{{ organization.name }}

+ + +
+
+ Roles +
+
+
+ + + + + + + + + + {% for role in organization.roles %} + + + + + + {% endfor %} + +
Role NameMembersPermissions
{{ role.name }}{{ role.users|length }} +
    + {% for permission in role.permissions %} +
  • {{ permission.name.value }}
  • + {% endfor %} +
+
+
+
+
+ + +
+
+ Members +
+
+
+ + + + + + + + + + + {% for role in organization.roles %} + {% for user in role.users %} + + + + + + + {% endfor %} + {% endfor %} + +
NameEmailRoles
+ {% if user.avatar_url %} + User Avatar + {% else %} + {{ render_silhouette(width=40, height=40) }} + {% endif %} + {{ user.name }}{{ user.email }} + {% for user_role in user.roles %} + {% if user_role.organization_id == organization.id %} + {{ user_role.name }} + {% endif %} + {% endfor %} +
+
+
+
+
+{% endblock %} diff --git a/templates/users/profile.html b/templates/users/profile.html index 448c0d6..0053993 100644 --- a/templates/users/profile.html +++ b/templates/users/profile.html @@ -1,5 +1,6 @@ {% extends "base.html" %} {% from 'components/silhouette.html' import render_silhouette %} +{% from 'components/organizations.html' import render_organizations %} {% block title %}Profile{% endblock %} @@ -58,7 +59,7 @@

User Profile

Change Password
- +

To change your password, please confirm your email. A password reset link will be sent to your email address.

@@ -67,6 +68,9 @@

User Profile

+ + {{ render_organizations(user.roles|map(attribute='organization')|list) }} +
diff --git a/tests/conftest.py b/tests/conftest.py index f9b90ae..565a1ca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,24 @@ import pytest from dotenv import load_dotenv -from sqlmodel import create_engine, Session, delete +from sqlmodel import create_engine, Session, select +from sqlalchemy import Engine from fastapi.testclient import TestClient from utils.db import get_connection_url, set_up_db, tear_down_db, get_session -from utils.models import User, PasswordResetToken +from utils.models import User, PasswordResetToken, Organization, Role, UserPassword from utils.auth import get_password_hash, create_access_token, create_refresh_token from main import app load_dotenv() +# Define a custom exception for test setup errors +class SetupError(Exception): + """Exception raised for errors in the test setup process.""" + pass + + @pytest.fixture(scope="session") -def engine(): +def engine() -> Engine: """ Create a new SQLModel engine for the test database. Use an in-memory SQLite database for testing. @@ -47,9 +54,9 @@ def clean_db(session: Session): """ Cleans up the database tables before each test. """ - # Exempt from mypy until SQLModel overload properly supports delete() - session.exec(delete(PasswordResetToken)) # type: ignore - session.exec(delete(User)) # type: ignore + for model in (PasswordResetToken, User, Role, Organization): + for record in session.exec(select(model)).all(): + session.delete(record) session.commit() @@ -63,7 +70,7 @@ def test_user(session: Session): user = User( name="Test User", email="test@example.com", - hashed_password=get_password_hash("Test123!@#") + password=UserPassword(hashed_password=get_password_hash("Test123!@#")) ) session.add(user) session.commit() @@ -107,3 +114,12 @@ def get_session_override(): yield client app.dependency_overrides.clear() + + +@pytest.fixture +def test_organization(session: Session): + """Create a test organization for use in tests""" + organization = Organization(name="Test Organization") + session.add(organization) + session.commit() + return organization diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 0ba2331..58bcc04 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -17,7 +17,7 @@ validate_token, generate_password_reset_url ) - +from .conftest import SetupError # --- Fixture setup --- @@ -28,7 +28,7 @@ def mock_email_response(): """ Returns a mock Email response object """ - return resend.Email(id="6229f547-f3f6-4eb8-b0dc-82c1b09121b6") + return resend.Email(id="mock_resend_id") @pytest.fixture @@ -104,7 +104,12 @@ def test_register_endpoint(unauth_client: TestClient, session: Session): User.email == "new@example.com")).first() assert user is not None assert user.name == "New User" - assert verify_password("NewPass123!@#", user.hashed_password) + + # Verify password was hashed and matches + if not user.password: + raise SetupError( + "Test setup failed; user.password is None") + assert verify_password("NewPass123!@#", user.password.hashed_password) def test_login_endpoint(unauth_client: TestClient, test_user: User): @@ -200,10 +205,14 @@ def test_password_reset_flow(unauth_client: TestClient, session: Session, test_u ) assert response.status_code == 303 + if not test_user.password: + raise SetupError( + "Test setup failed; test_user.password is None") + # Verify password was updated and token was marked as used session.refresh(test_user) session.refresh(reset_token) - assert verify_password("NewPass123!@#", test_user.hashed_password) + assert verify_password("NewPass123!@#", test_user.password.hashed_password) assert reset_token.used diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..aa6fe79 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,181 @@ +import warnings +from sqlmodel import Session, select, inspect +from sqlalchemy import Engine +from utils.db import ( + get_connection_url, + assign_permissions_to_role, + create_default_roles, + create_permissions, + tear_down_db, + set_up_db, +) +from utils.models import Role, Permission, Organization, RolePermissionLink, ValidPermissions +from .conftest import SetupError + + +def test_get_connection_url(): + """Test that get_connection_url returns a valid URL object""" + url = get_connection_url() + assert url.drivername == "postgresql" + assert url.database is not None + + +def test_create_permissions(session: Session): + """Test that create_permissions creates all ValidPermissions""" + # Clear existing permissions + existing_permissions = session.exec(select(Permission)).all() + for permission in existing_permissions: + session.delete(permission) + session.commit() + + create_permissions(session) + session.commit() + + # Check all permissions were created + db_permissions = session.exec(select(Permission)).all() + assert len(db_permissions) == len(ValidPermissions) + assert {p.name for p in db_permissions} == {p for p in ValidPermissions} + + +def test_create_default_roles(session: Session, test_organization: Organization): + """Test that create_default_roles creates expected roles with correct permissions""" + # Create permissions first + create_permissions(session) + session.commit() + + # Create roles for test organization + if test_organization.id is not None: + roles = create_default_roles(session, test_organization.id) + session.commit() + else: + raise SetupError( + "Test setup failed; test_organization.id is None") + + # Verify roles were created + assert len(roles) == 3 # Owner, Administrator, Member + + # Check Owner role permissions + owner_role = next(r for r in roles if r.name == "Owner") + owner_permissions = session.exec( + select(Permission) + .join(RolePermissionLink) + .where(RolePermissionLink.role_id == owner_role.id) + ).all() + assert len(owner_permissions) == len(ValidPermissions) + + # Check Administrator role permissions + admin_role = next(r for r in roles if r.name == "Administrator") + admin_permissions = session.exec( + select(Permission) + .join(RolePermissionLink) + .where(RolePermissionLink.role_id == admin_role.id) + ).all() + # Admin should have all permissions except DELETE_ORGANIZATION + assert len(admin_permissions) == len(ValidPermissions) - 1 + assert ValidPermissions.DELETE_ORGANIZATION not in { + p.name for p in admin_permissions} + + +def test_assign_permissions_to_role(session: Session, test_organization: Organization): + """Test that assign_permissions_to_role correctly assigns permissions""" + # Create a test role with the organization from fixture + role = Role(name="Test Role", organization_id=test_organization.id) + session.add(role) + + # Create test permissions + perm1 = Permission(name=ValidPermissions.CREATE_ROLE) + perm2 = Permission(name=ValidPermissions.DELETE_ROLE) + session.add(perm1) + session.add(perm2) + session.commit() + + # Assign permissions + permissions = [perm1, perm2] + assign_permissions_to_role(session, role, permissions) + session.commit() + + # Verify assignments + db_permissions = session.exec( + select(Permission) + .join(RolePermissionLink) + .where(RolePermissionLink.role_id == role.id) + ).all() + + assert len(db_permissions) == 2 + assert {p.name for p in db_permissions} == { + ValidPermissions.CREATE_ROLE, ValidPermissions.DELETE_ROLE} + + +def test_assign_permissions_to_role_duplicate_check(session: Session, test_organization: Organization): + """Test that assign_permissions_to_role doesn't create duplicates""" + # Create a test role with the organization from fixture + role = Role(name="Test Role", organization_id=test_organization.id) + perm = Permission(name=ValidPermissions.CREATE_ROLE) + session.add(role) + session.add(perm) + session.commit() + + # Assign same permission twice + assign_permissions_to_role(session, role, [perm], check_first=True) + assign_permissions_to_role(session, role, [perm], check_first=True) + session.commit() + + # Verify only one assignment exists + link_count = session.exec( + select(RolePermissionLink) + .where( + RolePermissionLink.role_id == role.id, + RolePermissionLink.permission_id == perm.id + ) + ).all() + assert len(link_count) == 1 + + +def test_set_up_db_creates_tables(engine: Engine, session: Session): + """Test that set_up_db creates all expected tables without warnings""" + # First tear down any existing tables + tear_down_db() + + # Run set_up_db with drop=False since we just cleaned up + set_up_db(drop=False) + + # Use SQLAlchemy inspect to check tables + inspector = inspect(engine) + table_names = inspector.get_table_names() + + # Check for core tables + expected_tables = { + "user", + "organization", + "role", + "permission", + "rolepermissionlink", + "passwordresettoken" + } + assert expected_tables.issubset(set(table_names)) + + # Verify permissions were created + permissions = session.exec(select(Permission)).all() + assert len(permissions) == len(ValidPermissions) + + +def test_set_up_db_drop_flag(engine: Engine, session: Session): + """Test that set_up_db's drop flag properly recreates tables""" + # Set up db with drop=True + set_up_db(drop=True) + + # Verify valid permissions exist + permissions = session.exec(select(Permission)).all() + assert len(permissions) == len(ValidPermissions) + + # Create an organization + org = Organization(name="Test Organization") + session.add(org) + session.commit() + + # Set up db with drop=False + set_up_db(drop=False) + + # Verify organization exists + assert session.exec(select(Organization).where( + Organization.name == "Test Organization")).first() is not None diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..282ef0f --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,230 @@ +from datetime import timedelta, datetime, UTC +from typing import Optional +from sqlmodel import select, Session +from utils.models import ( + Permission, + Role, + RolePermissionLink, + Organization, + ValidPermissions, + User, + UserRoleLink, + PasswordResetToken +) +from .conftest import SetupError + + +def test_permissions_persist_after_role_deletion(session: Session): + """ + Test that permissions are not deleted when a related Role is deleted. + Permissions links are automatically deleted due to cascade_delete=True. + """ + # Verify all ValidPermissions exist in database + all_permissions = session.exec(select(Permission)).all() + assert len(all_permissions) == len(ValidPermissions) + + # Create an organization + organization = Organization(name="Test Organization") + session.add(organization) + session.commit() + session.refresh(organization) + + # Create a role and link two specific permissions + role = Role(name="Test Role", organization_id=organization.id) + session.add(role) + session.commit() + session.refresh(role) + + # Find specific permissions to link + delete_org_permission = next( + p for p in all_permissions if p.name == ValidPermissions.DELETE_ORGANIZATION) + edit_org_permission = next( + p for p in all_permissions if p.name == ValidPermissions.EDIT_ORGANIZATION) + + role.permissions.append(delete_org_permission) + role.permissions.append(edit_org_permission) + session.commit() + + # Verify that RolePermissionLinks exist before deletion + role_permissions = session.exec(select(RolePermissionLink)).all() + assert len(role_permissions) == 2 + + # Delete the role (this will cascade delete the permission links) + session.delete(role) + session.commit() + + # Verify that all permissions still exist + remaining_permissions = session.exec(select(Permission)).all() + assert len(remaining_permissions) == len(ValidPermissions) + assert delete_org_permission in remaining_permissions + assert edit_org_permission in remaining_permissions + + # Verify that RolePermissionLinks were cascade deleted + remaining_role_permissions = session.exec(select(RolePermissionLink)).all() + assert len(remaining_role_permissions) == 0 + + +def test_user_organizations_property(session: Session, test_user: User, test_organization: Organization): + """ + Test that User.organizations property correctly returns all organizations + the user belongs to via their roles. + """ + # Create a role in the test organization + role = Role(name="Test Role", organization_id=test_organization.id) + session.add(role) + + # Link the user to the role + test_user.roles.append(role) + session.commit() + + # Refresh the user to ensure relationships are loaded + session.refresh(test_user) + + # Test the organizations property + assert len(test_user.organizations) == 1 + assert test_user.organizations[0].id == test_organization.id + + +def test_organization_users_property(session: Session, test_user: User, test_organization: Organization): + """ + Test that Organization.users property correctly returns all users + in the organization via their roles. + """ + # Create a role in the test organization + role = Role(name="Test Role", organization_id=test_organization.id) + session.add(role) + session.commit() + + # Link the user to the role + test_user.roles.append(role) + session.commit() + + # Refresh the organization to ensure relationships are loaded + session.refresh(test_organization) + + # Test the users property + users_list: list[User] = test_organization.users + assert len(users_list) == 1 + assert test_user in users_list + + +def test_cascade_delete_organization(session: Session, test_user: User, test_organization: Organization): + """ + Test that deleting an organization cascades properly: + - Deletes associated roles + - Deletes role-user links + - Does not delete users + """ + # Create a role in the test organization + role = Role(name="Test Role", organization_id=test_organization.id) + session.add(role) + test_user.roles.append(role) + session.commit() + + # Delete the organization + session.delete(test_organization) + session.commit() + + # Verify the role was deleted + remaining_roles = session.exec(select(Role)).all() + assert len(remaining_roles) == 0 + + # Verify the user-role link was deleted + remaining_links = session.exec(select(UserRoleLink)).all() + assert len(remaining_links) == 0 + + # Verify the user still exists + remaining_user = session.exec(select(User)).first() + assert remaining_user is not None + assert remaining_user.id == test_user.id + + +def test_password_reset_token_cascade_delete(session: Session, test_user: User): + """ + Test that password reset tokens are deleted when a user is deleted + """ + # Create reset tokens for the user + token1 = PasswordResetToken(user_id=test_user.id) + token2 = PasswordResetToken(user_id=test_user.id) + session.add(token1) + session.add(token2) + session.commit() + + # Verify tokens exist + tokens = session.exec(select(PasswordResetToken)).all() + assert len(tokens) == 2 + + # Delete the user + session.delete(test_user) + session.commit() + + # Verify tokens were cascade deleted + remaining_tokens = session.exec(select(PasswordResetToken)).all() + assert len(remaining_tokens) == 0 + + +def test_password_reset_token_is_expired(session: Session, test_user: User): + """ + Test that password reset token expiration is properly set and checked + """ + # Create an expired token + expired_token = PasswordResetToken( + user_id=test_user.id, + expires_at=datetime.now(UTC) - timedelta(hours=1) + ) + session.add(expired_token) + + # Create a valid token + valid_token = PasswordResetToken( + user_id=test_user.id, + expires_at=datetime.now(UTC) + timedelta(hours=1) + ) + session.add(valid_token) + session.commit() + + # Verify expiration states + assert expired_token.is_expired() + assert not valid_token.is_expired() + + +def test_user_has_permission(session: Session, test_user: User, test_organization: Organization): + """ + Test that User.has_permission method correctly checks if a user has a specific + permission for a given organization. + """ + # Create a role with specific permissions in the test organization + role = Role(name="Test Role", organization_id=test_organization.id) + session.add(role) + session.commit() + session.refresh(role) + + # Assign permissions to the role + delete_org_permission: Optional[Permission] = session.exec( + select(Permission).where(Permission.name == + ValidPermissions.DELETE_ORGANIZATION) + ).first() + edit_org_permission: Optional[Permission] = session.exec( + select(Permission).where(Permission.name == + ValidPermissions.EDIT_ORGANIZATION) + ).first() + + if delete_org_permission is not None and edit_org_permission is not None: + role.permissions.append(delete_org_permission) + role.permissions.append(edit_org_permission) + else: + raise SetupError( + "Test setup failed; permission not found in database") + session.commit() + + # Link the user to the role + test_user.roles.append(role) + session.commit() + session.refresh(test_user) + + # Test the has_permission method + assert test_user.has_permission( + ValidPermissions.DELETE_ORGANIZATION, test_organization) is True + assert test_user.has_permission( + ValidPermissions.EDIT_ORGANIZATION, test_organization) is True + assert test_user.has_permission( + ValidPermissions.INVITE_USER, test_organization) is False diff --git a/utils/auth.py b/utils/auth.py index 5793c3e..e7de8c5 100644 --- a/utils/auth.py +++ b/utils/auth.py @@ -8,18 +8,20 @@ from dotenv import load_dotenv from pydantic import field_validator, ValidationInfo from sqlmodel import Session, select +from sqlalchemy.orm import selectinload from bcrypt import gensalt, hashpw, checkpw from datetime import UTC, datetime, timedelta from typing import Optional from fastapi import Depends, Cookie, HTTPException, status from utils.db import get_session -from utils.models import User, PasswordResetToken +from utils.models import User, Role, PasswordResetToken load_dotenv() logger = logging.getLogger("uvicorn.error") -# --- AUTH --- +# --- Constants --- + SECRET_KEY = os.getenv("SECRET_KEY") ALGORITHM = "HS256" @@ -27,12 +29,15 @@ REFRESH_TOKEN_EXPIRE_DAYS = 30 -# Define the oauth2 scheme to get the token from the cookie -def oauth2_scheme_cookie( - access_token: Optional[str] = Cookie(None, alias="access_token"), - refresh_token: Optional[str] = Cookie(None, alias="refresh_token"), -) -> tuple[Optional[str], Optional[str]]: - return access_token, refresh_token +# --- Custom Exceptions --- + + +class AuthenticationError(HTTPException): + def __init__(self): + super().__init__( + status_code=status.HTTP_303_SEE_OTHER, + headers={"Location": "/login"} + ) class PasswordValidationError(HTTPException): @@ -54,6 +59,25 @@ def __init__(self, field: str = "confirm_password"): ) +class InsufficientPermissionsError(HTTPException): + def __init__(self): + super().__init__( + status_code=403, + detail="You don't have permission to perform this action" + ) + + +# --- Helpers --- + + +# Define the oauth2 scheme to get the token from the cookie +def oauth2_scheme_cookie( + access_token: Optional[str] = Cookie(None, alias="access_token"), + refresh_token: Optional[str] = Cookie(None, alias="refresh_token"), +) -> tuple[Optional[str], Optional[str]]: + return access_token, refresh_token + + def create_password_validator(field_name: str = "password"): """ Factory function that creates a password validation decorator for Pydantic models. @@ -216,14 +240,6 @@ def get_user_from_tokens( return None, None, None -class AuthenticationError(HTTPException): - def __init__(self): - super().__init__( - status_code=status.HTTP_303_SEE_OTHER, - headers={"Location": "/login"} - ) - - def get_authenticated_user( tokens: tuple[Optional[str], Optional[str] ] = Depends(oauth2_scheme_cookie), @@ -339,3 +355,23 @@ def get_user_from_reset_token(email: str, token: str, session: Session) -> tuple user, reset_token = result return user, reset_token + + +def get_user_with_relations( + user: User = Depends(get_authenticated_user), + session: Session = Depends(get_session), +) -> User: + """ + Returns an authenticated user with fully loaded role and organization relationships. + """ + # Refresh the user instance with eagerly loaded relationships + eager_user = session.exec( + select(User) + .where(User.id == user.id) + .options( + selectinload(User.roles).selectinload(Role.organization), + selectinload(User.roles).selectinload(Role.permissions) + ) + ).one() + + return eager_user diff --git a/utils/db.py b/utils/db.py index e044591..b65af2f 100644 --- a/utils/db.py +++ b/utils/db.py @@ -1,23 +1,35 @@ import os import logging +from typing import Generator, Union, Sequence from dotenv import load_dotenv +from fastapi import HTTPException from sqlalchemy.engine import URL from sqlmodel import create_engine, Session, SQLModel, select from utils.models import Role, Permission, RolePermissionLink, default_roles, ValidPermissions +# Load environment variables from a .env file load_dotenv() +# Set up a logger for error reporting logger = logging.getLogger("uvicorn.error") -# --- Database connection --- +# --- Database connection functions --- def get_connection_url() -> URL: """ - Creates a SQLModel URL object containing the connection URL to the Postgres database. - The connection details are obtained from environment variables. - Returns the URL object. + Constructs a SQLModel URL object for connecting to the PostgreSQL database. + + The connection details are sourced from environment variables, which should include: + - DB_USER: Database username + - DB_PASSWORD: Database password + - DB_HOST: Database host address + - DB_PORT: Database port (default is 5432) + - DB_NAME: Database name + + Returns: + URL: A SQLModel URL object containing the connection details. """ database_url: URL = URL.create( drivername="postgresql", @@ -31,32 +43,116 @@ def get_connection_url() -> URL: return database_url +# Create the database engine using the connection URL engine = create_engine(get_connection_url()) -def get_session(): +def get_session() -> Generator[Session, None, None]: + """ + Provides a database session for executing queries. + + Yields: + Session: A SQLModel session object for database operations. + """ with Session(engine) as session: yield session -def create_roles(session): +def assign_permissions_to_role( + session: Session, + role: Role, + permissions: Union[list[Permission], Sequence[Permission]], + check_first: bool = False +) -> None: """ - Create default roles in the database if they do not exist. + Assigns permissions to a role in the database. + + Args: + session (Session): The database session to use for operations. + role (Role): The role to assign permissions to. + permissions (list[Permission]): The list of permissions to assign. + check_first (bool): If True, checks if the role already has the permission before assigning it. + """ + + for permission in permissions: + # Check if the role already has the permission + if check_first: + db_role_permission_link: RolePermissionLink | None = session.exec( + select(RolePermissionLink).where( + RolePermissionLink.role_id == role.id, + RolePermissionLink.permission_id == permission.id + ) + ).first() + else: + db_role_permission_link = None + + # Skip granting DELETE_ORGANIZATION permission to the Administrator role + if not db_role_permission_link: + role_permission_link = RolePermissionLink( + role_id=role.id, + permission_id=permission.id + ) + session.add(role_permission_link) + + +def create_default_roles(session: Session, organization_id: int, check_first: bool = True) -> list: + """ + Creates default roles for a specified organization in the database if they do not already exist, + and assigns permissions to the Owner and Administrator roles. + + Args: + session (Session): The database session to use for operations. + organization_id (int): The ID of the organization for which to create roles. + check_first (bool): If True, checks if the role already exists before creating it. + + Returns: + list: A list of roles that were created or already existed in the database. """ + roles_in_db = [] for role_name in default_roles: - db_role = session.exec(select(Role).where( - Role.name == role_name)).first() + db_role = session.exec( + select(Role).where( + Role.name == role_name, + Role.organization_id == organization_id + ) + ).first() if not db_role: - db_role = Role(name=role_name) + db_role = Role(name=role_name, organization_id=organization_id) session.add(db_role) roles_in_db.append(db_role) + + # TODO: Construct this role-permission mapping once at app startup and use as constant + # Fetch all permissions once + owner_permissions = session.exec(select(Permission)).all() + admin_permissions = [ + permission for permission in owner_permissions + if permission.name != ValidPermissions.DELETE_ORGANIZATION + ] + + # Get Owner and Administrator roles by name + owner_role = next(role for role in roles_in_db if role.name == "Owner") + admin_role = next( + role for role in roles_in_db if role.name == "Administrator") + + # Assign all permissions to Owner + assign_permissions_to_role( + session, owner_role, owner_permissions, check_first=check_first) + + # Assign filtered permissions to Administrator + assign_permissions_to_role( + session, admin_role, admin_permissions, check_first=check_first) + + session.commit() return roles_in_db -def create_permissions(session, roles_in_db): +def create_permissions(session: Session) -> None: """ - Create default permissions and link them to roles in the database. + Creates default permissions in the database if they do not already exist. + + Args: + session (Session): The database session to use for operations. """ for permission in ValidPermissions: db_permission = session.exec(select(Permission).where( @@ -65,37 +161,28 @@ def create_permissions(session, roles_in_db): db_permission = Permission(name=permission) session.add(db_permission) - # Create RolePermissionLink for Owner and Administrator - for role in roles_in_db[:2]: - db_role_permission_link = session.exec(select(RolePermissionLink).where( - RolePermissionLink.role_id == role.id, - RolePermissionLink.permission_id == db_permission.id)).first() - if not db_role_permission_link: - if not (permission == ValidPermissions.DELETE_ORGANIZATION and role.name == "Administrator"): - role_permission_link = RolePermissionLink( - role_id=role.id, permission_id=db_permission.id) - session.add(role_permission_link) - -def set_up_db(drop: bool = False): +def set_up_db(drop: bool = False) -> None: """ - Set up the database by creating tables and populating them with default roles and permissions. + Sets up the database by creating tables and populating them with default permissions. + + Args: + drop (bool): If True, drops all existing tables before creating new ones. """ engine = create_engine(get_connection_url()) if drop: SQLModel.metadata.drop_all(engine) SQLModel.metadata.create_all(engine) + # Create default permissions with Session(engine) as session: - roles_in_db = create_roles(session) - session.commit() - create_permissions(session, roles_in_db) + create_permissions(session) session.commit() engine.dispose() -def tear_down_db(): +def tear_down_db() -> None: """ - Tear down the database by dropping all tables. + Tears down the database by dropping all tables. """ engine = create_engine(get_connection_url()) SQLModel.metadata.drop_all(engine) diff --git a/utils/models.py b/utils/models.py index cd43f34..63d9ec2 100644 --- a/utils/models.py +++ b/utils/models.py @@ -1,15 +1,43 @@ +from logging import getLogger, DEBUG from enum import Enum from uuid import uuid4 from datetime import datetime, UTC, timedelta -from typing import Optional, List +from typing import Optional, List, Union +from fastapi import HTTPException from sqlmodel import SQLModel, Field, Relationship from sqlalchemy import Column, Enum as SQLAlchemyEnum +from sqlalchemy.orm import Mapped + +logger = getLogger("uvicorn.error") +logger.setLevel(DEBUG) + + +# --- Helper functions --- def utc_time(): return datetime.now(UTC) +# --- Custom exceptions --- + + +class DataIntegrityError(HTTPException): + def __init__( + self, + resource: str = "Database resource" + ): + super().__init__( + status_code=500, + detail=( + f"{resource} is in a broken state; please contact a system administrator" + ) + ) + + +# --- Database models --- + + default_roles = ["Owner", "Administrator", "Member"] @@ -24,89 +52,179 @@ class ValidPermissions(Enum): EDIT_ROLE = "Edit Role" -class Organization(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - name: str - created_at: datetime = Field(default_factory=utc_time) - updated_at: datetime = Field(default_factory=utc_time) +class UserRoleLink(SQLModel, table=True): + """ + Associates users with roles. This creates a many-to-many relationship + between users and roles. + """ + user_id: Optional[int] = Field(foreign_key="user.id", primary_key=True) + role_id: Optional[int] = Field(foreign_key="role.id", primary_key=True) + - users: List["User"] = Relationship(back_populates="organization") +class RolePermissionLink(SQLModel, table=True): + role_id: Optional[int] = Field(foreign_key="role.id", primary_key=True) + permission_id: Optional[int] = Field( + foreign_key="permission.id", primary_key=True) -class Role(SQLModel, table=True): +class Permission(SQLModel, table=True): + """ + Represents a permission that can be assigned to a role. Should not be + modified unless the application logic and ValidPermissions enum change. + """ id: Optional[int] = Field(default=None, primary_key=True) - name: str - organization_id: Optional[int] = Field( - default=None, foreign_key="organization.id") + name: ValidPermissions = Field( + sa_column=Column(SQLAlchemyEnum(ValidPermissions, create_type=False))) created_at: datetime = Field(default_factory=utc_time) updated_at: datetime = Field(default_factory=utc_time) - users: List["User"] = Relationship(back_populates="role") - role_permission_links: List["RolePermissionLink"] = Relationship( - back_populates="role") + roles: Mapped[List["Role"]] = Relationship( + back_populates="permissions", + link_model=RolePermissionLink + ) -class Permission(SQLModel, table=True): +class Organization(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) - name: ValidPermissions = Field( - sa_column=Column(SQLAlchemyEnum(ValidPermissions, create_type=False))) + name: str created_at: datetime = Field(default_factory=utc_time) updated_at: datetime = Field(default_factory=utc_time) - role_permission_links: List["RolePermissionLink"] = Relationship( - back_populates="permission") + roles: Mapped[List["Role"]] = Relationship( + back_populates="organization", + sa_relationship_kwargs={ + "cascade": "all, delete-orphan" + } + ) + @property + def users(self) -> List["User"]: + """ + Returns all users in the organization via their roles. + """ + users = [] + # Track user IDs to ensure uniqueness + user_ids = set() + for role in self.roles: + for user in role.users: + if user.id not in user_ids: + users.append(user) + user_ids.add(user.id) + return users -class RolePermissionLink(SQLModel, table=True): + +class Role(SQLModel, table=True): + """ + Represents a role within an organization. + + Attributes: + id: Primary key. + name: The name of the role. + organization_id: Foreign key to the associated organization. + created_at: Timestamp when the role was created. + updated_at: Timestamp when the role was last updated. + """ id: Optional[int] = Field(default=None, primary_key=True) - role_id: Optional[int] = Field( - default=None, foreign_key="role.id") - permission_id: Optional[int] = Field( - default=None, foreign_key="permission.id") + name: str + organization_id: int = Field( + foreign_key="organization.id") + created_at: datetime = Field(default_factory=utc_time) + updated_at: datetime = Field(default_factory=utc_time) - role: Optional["Role"] = Relationship( - back_populates="role_permission_links") - permission: Optional["Permission"] = Relationship( - back_populates="role_permission_links") + organization: Mapped[Organization] = Relationship(back_populates="roles") + users: Mapped[List["User"]] = Relationship( + back_populates="roles", + link_model=UserRoleLink + ) + permissions: Mapped[List["Permission"]] = Relationship( + back_populates="roles", + link_model=RolePermissionLink + ) class PasswordResetToken(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) - user_id: Optional[int] = Field(default=None, foreign_key="user.id") + user_id: Optional[int] = Field(foreign_key="user.id") token: str = Field(default_factory=lambda: str( uuid4()), index=True, unique=True) expires_at: datetime = Field( default_factory=lambda: datetime.now(UTC) + timedelta(hours=1)) used: bool = Field(default=False) - user: Optional["User"] = Relationship( + user: Mapped[Optional["User"]] = Relationship( back_populates="password_reset_tokens") + def is_expired(self) -> bool: + """ + Check if the token has expired + """ + return datetime.now(UTC) > self.expires_at.replace(tzinfo=UTC) + +class UserPassword(SQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + user_id: Optional[int] = Field(foreign_key="user.id", unique=True) + hashed_password: str + + user: Mapped[Optional["User"]] = Relationship( + back_populates="password", + sa_relationship_kwargs={ + "cascade": "all, delete-orphan", + "single_parent": True + } + ) + + +# TODO: Prevent deleting a user who is sole owner of an organization class User(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) name: str email: str = Field(index=True, unique=True) - hashed_password: str avatar_url: Optional[str] = None - organization_id: Optional[int] = Field( - default=None, foreign_key="organization.id") - role_id: Optional[int] = Field(default=None, foreign_key="role.id") created_at: datetime = Field(default_factory=utc_time) updated_at: datetime = Field(default_factory=utc_time) - organization: Optional["Organization"] = Relationship( - back_populates="users") - role: Optional["Role"] = Relationship(back_populates="users") - password_reset_tokens: List["PasswordResetToken"] = Relationship( + roles: Mapped[List[Role]] = Relationship( + back_populates="users", + link_model=UserRoleLink + ) + password_reset_tokens: Mapped[List["PasswordResetToken"]] = Relationship( back_populates="user", - sa_relationship_kwargs={"cascade": "all, delete-orphan"} + sa_relationship_kwargs={ + "cascade": "all, delete-orphan" + } + ) + password: Mapped[Optional[UserPassword]] = Relationship( + back_populates="user" ) - -class UserOrganizationLink(SQLModel, table=True): - id: Optional[int] = Field(default=None, primary_key=True) - user_id: Optional[int] = Field( - default=None, foreign_key="user.id") - organization_id: Optional[int] = Field( - default=None, foreign_key="organization.id") + @property + def organizations(self) -> List[Organization]: + """ + Returns all organizations the user belongs to via their roles. + """ + organizations = [] + organization_ids = set() + for role in self.roles: + if role.organization_id not in organization_ids: + organizations.append(role.organization) + organization_ids.add(role.organization_id) + return organizations + + def has_permission(self, permission: ValidPermissions, organization: Union[Organization, int]) -> bool: + """ + Check if the user has a specific permission for a given organization. + """ + organization_id: Optional[int] = None + if isinstance(organization, Organization): + organization_id = organization.id + else: + organization_id = organization + + if not organization_id: + raise DataIntegrityError(resource="Organization ID") + + for role in self.roles: + if role.organization_id == organization_id: + return permission in [perm.name for perm in role.permissions] + return False