Skip to content

Commit 138c515

Browse files
Passing unit tests for models.py
1 parent 694bb2f commit 138c515

File tree

10 files changed

+95
-48
lines changed

10 files changed

+95
-48
lines changed

main.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
from fastapi.staticfiles import StaticFiles
77
from fastapi.templating import Jinja2Templates
88
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
9-
from routers import authentication, organization, role, user, dashboard, terms_of_service, privacy_policy, about
10-
from utils.auth import (
11-
NeedsNewTokens,
12-
PasswordValidationError,
13-
AuthenticationError,
9+
from routers import about, account, dashboard, organization, privacy_policy, role, terms_of_service, user
10+
from utils.dependencies import (
1411
get_optional_user
1512
)
13+
from exceptions.http_exceptions import (
14+
AuthenticationError,
15+
PasswordValidationError
16+
)
17+
from exceptions.exceptions import (
18+
NeedsNewTokens
19+
)
1620
from utils.db import set_up_db
1721
from utils.models import User
1822

@@ -39,14 +43,14 @@ async def lifespan(app: FastAPI):
3943
# --- Include Routers ---
4044

4145

42-
app.include_router(authentication.router)
46+
app.include_router(account.router)
47+
app.include_router(about.router)
48+
app.include_router(dashboard.router)
4349
app.include_router(organization.router)
50+
app.include_router(privacy_policy.router)
4451
app.include_router(role.router)
45-
app.include_router(user.router)
46-
app.include_router(dashboard.router)
4752
app.include_router(terms_of_service.router)
48-
app.include_router(privacy_policy.router)
49-
app.include_router(about.router)
53+
app.include_router(user.router)
5054

5155

5256
# --- Exception Handling Middlewares ---

routers/about.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22
from fastapi import APIRouter, Depends, Request
33
from fastapi.templating import Jinja2Templates
4-
from utils.auth import get_optional_user
4+
from utils.dependencies import get_optional_user
55
from utils.models import User
66

77
router = APIRouter(prefix="/about", tags=["about"])

routers/dashboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22
from fastapi import APIRouter, Depends, Request
33
from fastapi.templating import Jinja2Templates
4-
from utils.auth import get_user_with_relations
4+
from utils.dependencies import get_user_with_relations
55
from utils.models import User
66

77
router = APIRouter(prefix="/dashboard", tags=["dashboard"])

routers/organization.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from fastapi.templating import Jinja2Templates
66
from pydantic import BaseModel, ConfigDict, field_validator
77
from sqlmodel import Session, select
8-
from utils.db import get_session
9-
from utils.auth import get_authenticated_user, get_user_with_relations
10-
from utils.models import Organization, User, Role, utc_time, default_roles, ValidPermissions
8+
from utils.db import get_session, default_roles
9+
from utils.dependencies import get_authenticated_user, get_user_with_relations
10+
from utils.models import Organization, User, Role, utc_time
11+
from utils.enums import ValidPermissions
1112
from exceptions.http_exceptions import EmptyOrganizationNameError, OrganizationNotFoundError, OrganizationNameTakenError, InsufficientPermissionsError
1213

1314
logger = getLogger("uvicorn.error")

routers/privacy_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22
from fastapi import APIRouter, Depends, Request
33
from fastapi.templating import Jinja2Templates
4-
from utils.auth import get_optional_user
4+
from utils.dependencies import get_optional_user
55
from utils.models import User
66

77
router = APIRouter(prefix="/privacy_policy", tags=["privacy_policy"])

routers/role.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlmodel import Session, select, col
99
from sqlalchemy.orm import selectinload
1010
from utils.db import get_session
11-
from utils.auth import get_authenticated_user
11+
from utils.dependencies import get_authenticated_user
1212
from utils.models import Role, Permission, ValidPermissions, utc_time, User, DataIntegrityError
1313
from exceptions.http_exceptions import InsufficientPermissionsError, InvalidPermissionError, RoleAlreadyExistsError, RoleNotFoundError, RoleHasUsersError
1414

routers/terms_of_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional
22
from fastapi import APIRouter, Depends, Request
33
from fastapi.templating import Jinja2Templates
4-
from utils.auth import get_optional_user
4+
from utils.dependencies import get_optional_user
55
from utils.models import User
66

77
router = APIRouter(prefix="/terms_of_service", tags=["terms_of_service"])

routers/user.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from typing import Optional
55
from fastapi.templating import Jinja2Templates
66
from utils.models import UserBase, User, DataIntegrityError
7-
from utils.auth import get_session, get_authenticated_user
7+
from utils.db import get_session
8+
from utils.dependencies import get_authenticated_user
89
from utils.images import validate_and_process_image, MAX_FILE_SIZE, MIN_DIMENSION, MAX_DIMENSION, ALLOWED_CONTENT_TYPES
910

1011
router = APIRouter(prefix="/user", tags=["user"])

tests/conftest.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
import pytest
22
from typing import Generator
3-
from dotenv import load_dotenv
4-
from sqlmodel import create_engine, Session, select
3+
from sqlmodel import create_engine, Session, select, SQLModel
54
from sqlalchemy import Engine
65
from fastapi.testclient import TestClient
7-
from utils.db import get_connection_url, set_up_db, tear_down_db, get_session
8-
from utils.models import User, PasswordResetToken, Organization, Role, UserPassword
6+
from utils.db import get_session
7+
from utils.models import User, PasswordResetToken, Organization, Role, Account, Permission
98
from utils.auth import get_password_hash, create_access_token, create_refresh_token
9+
from utils.enums import ValidPermissions
1010
from main import app
1111

12-
load_dotenv()
13-
1412

1513
# Define a custom exception for test setup errors
1614
class SetupError(Exception):
@@ -26,9 +24,8 @@ def engine() -> Engine:
2624
Create a new SQLModel engine for the test database.
2725
Use an in-memory SQLite database for testing.
2826
"""
29-
engine = create_engine(
30-
get_connection_url()
31-
)
27+
# Use in-memory SQLite for testing
28+
engine = create_engine("sqlite:///:memory:")
3229
return engine
3330

3431

@@ -38,9 +35,22 @@ def set_up_database(engine) -> Generator[None, None, None]:
3835
Set up the test database before running the test suite.
3936
Drop all tables and recreate them to ensure a clean state.
4037
"""
41-
set_up_db(drop=True)
38+
# Create all tables in the in-memory database
39+
SQLModel.metadata.create_all(engine)
40+
41+
# Create permissions
42+
with Session(engine) as session:
43+
# Check if permissions already exist
44+
existing_permissions = session.exec(select(Permission)).all()
45+
if not existing_permissions:
46+
# Create all permissions from the ValidPermissions enum
47+
for permission in ValidPermissions:
48+
session.add(Permission(name=permission))
49+
session.commit()
50+
4251
yield
43-
tear_down_db()
52+
# Drop all tables
53+
SQLModel.metadata.drop_all(engine)
4454

4555

4656
@pytest.fixture
@@ -57,22 +67,37 @@ def clean_db(session: Session) -> None:
5767
"""
5868
Cleans up the database tables before each test.
5969
"""
60-
for model in (PasswordResetToken, User, Role, Organization):
70+
# Don't delete permissions as they are required for tests
71+
for model in (PasswordResetToken, User, Role, Organization, Account):
6172
for record in session.exec(select(model)).all():
6273
session.delete(record)
6374

6475
session.commit()
6576

6677

6778
@pytest.fixture()
68-
def test_user(session: Session) -> User:
79+
def test_account(session: Session) -> Account:
80+
"""
81+
Creates a test account in the database.
82+
"""
83+
account = Account(
84+
85+
hashed_password=get_password_hash("Test123!@#")
86+
)
87+
session.add(account)
88+
session.commit()
89+
session.refresh(account)
90+
return account
91+
92+
93+
@pytest.fixture()
94+
def test_user(session: Session, test_account: Account) -> User:
6995
"""
7096
Creates a test user in the database.
7197
"""
7298
user = User(
7399
name="Test User",
74-
75-
password=UserPassword(hashed_password=get_password_hash("Test123!@#"))
100+
account_id=test_account.id
76101
)
77102
session.add(user)
78103
session.commit()
@@ -95,7 +120,7 @@ def get_session_override():
95120

96121

97122
@pytest.fixture()
98-
def auth_client(session: Session, test_user: User) -> Generator[TestClient, None, None]:
123+
def auth_client(session: Session, test_account: Account) -> Generator[TestClient, None, None]:
99124
"""
100125
Provides a TestClient instance with valid authentication tokens.
101126
"""
@@ -106,8 +131,8 @@ def get_session_override():
106131
client = TestClient(app)
107132

108133
# Create and set valid tokens
109-
access_token = create_access_token({"sub": test_user.email})
110-
refresh_token = create_refresh_token({"sub": test_user.email})
134+
access_token = create_access_token({"sub": test_account.email})
135+
refresh_token = create_refresh_token({"sub": test_account.email})
111136

112137
client.cookies.set("access_token", access_token)
113138
client.cookies.set("refresh_token", refresh_token)

tests/test_models.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
ValidPermissions,
1010
User,
1111
UserRoleLink,
12-
PasswordResetToken
12+
PasswordResetToken,
13+
Account
1314
)
1415
from .conftest import SetupError
1516

@@ -139,13 +140,13 @@ def test_cascade_delete_organization(session: Session, test_user: User, test_org
139140
assert remaining_user.id == test_user.id
140141

141142

142-
def test_password_reset_token_cascade_delete(session: Session, test_user: User):
143+
def test_password_reset_token_cascade_delete(session: Session, test_account: Account):
143144
"""
144-
Test that password reset tokens are deleted when a user is deleted
145+
Test that password reset tokens are deleted when an account is deleted
145146
"""
146-
# Create reset tokens for the user
147-
token1 = PasswordResetToken(user_id=test_user.id)
148-
token2 = PasswordResetToken(user_id=test_user.id)
147+
# Create reset tokens for the account
148+
token1 = PasswordResetToken(account_id=test_account.id)
149+
token2 = PasswordResetToken(account_id=test_account.id)
149150
session.add(token1)
150151
session.add(token2)
151152
session.commit()
@@ -154,29 +155,29 @@ def test_password_reset_token_cascade_delete(session: Session, test_user: User):
154155
tokens = session.exec(select(PasswordResetToken)).all()
155156
assert len(tokens) == 2
156157

157-
# Delete the user
158-
session.delete(test_user)
158+
# Delete the account
159+
session.delete(test_account)
159160
session.commit()
160161

161162
# Verify tokens were cascade deleted
162163
remaining_tokens = session.exec(select(PasswordResetToken)).all()
163164
assert len(remaining_tokens) == 0
164165

165166

166-
def test_password_reset_token_is_expired(session: Session, test_user: User):
167+
def test_password_reset_token_is_expired(session: Session, test_account: Account):
167168
"""
168169
Test that password reset token expiration is properly set and checked
169170
"""
170171
# Create an expired token
171172
expired_token = PasswordResetToken(
172-
user_id=test_user.id,
173+
account_id=test_account.id,
173174
expires_at=datetime.now(UTC) - timedelta(hours=1)
174175
)
175176
session.add(expired_token)
176177

177178
# Create a valid token
178179
valid_token = PasswordResetToken(
179-
user_id=test_user.id,
180+
account_id=test_account.id,
180181
expires_at=datetime.now(UTC) + timedelta(hours=1)
181182
)
182183
session.add(valid_token)
@@ -228,3 +229,18 @@ def test_user_has_permission(session: Session, test_user: User, test_organizatio
228229
ValidPermissions.EDIT_ORGANIZATION, test_organization) is True
229230
assert test_user.has_permission(
230231
ValidPermissions.INVITE_USER, test_organization) is False
232+
233+
234+
def test_cascade_delete_account_deletes_user(session: Session, test_account: Account, test_user: User):
235+
"""
236+
Test that deleting an account cascades to delete the associated user
237+
"""
238+
# Verify the user exists
239+
assert session.exec(select(User).where(User.account_id == test_account.id)).first() is not None
240+
241+
# Delete the account
242+
session.delete(test_account)
243+
session.commit()
244+
245+
# Verify the user was cascade deleted
246+
assert session.exec(select(User).where(User.account_id == test_account.id)).first() is None

0 commit comments

Comments
 (0)