Skip to content

Commit 2587dcf

Browse files
Tests of cascade delete behaviors pass
1 parent 3f13345 commit 2587dcf

File tree

4 files changed

+178
-34
lines changed

4 files changed

+178
-34
lines changed

docs/customization.qmd

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,9 @@ graph.write_png('static/schema.png')
274274
![Database Schema](static/schema.png)
275275

276276

277-
#### Database operations
277+
#### Database helpers
278278

279-
Database operations are handled by helper functions in `utils/db.py`. Key functions include:
279+
Database operations are facilitated by helper functions in `utils/db.py`. Key functions include:
280280

281281
- `set_up_db()`: Initializes the database schema and default data (which we do on every application start in `main.py`)
282282
- `get_connection_url()`: Creates a database connection URL from environment variables in `.env`
@@ -292,3 +292,30 @@ async def get_users(session: Session = Depends(get_session)):
292292
```
293293

294294
The session automatically handles transaction management, ensuring that database operations are atomic and consistent.
295+
296+
#### Cascade deletes
297+
298+
Cascade deletes (in which deleting a record from one table deletes related records from another table) can be handled at either the ORM level or the database level. This template handles cascade deletes at the ORM level, via SQLModel relationships. Inside a SQLModel `Relationship`, we set:
299+
300+
```python
301+
sa_relationship_kwargs={
302+
"cascade": "all, delete-orphan"
303+
}
304+
```
305+
306+
This tells SQLAlchemy to cascade all operations (e.g., `SELECT`, `INSERT`, `UPDATE`, `DELETE`) to the related table. Since this happens through the ORM, we need to be careful to do all our database operations through the ORM using supported syntax. That generally means loading database records into Python objects and then deleting those objects rather than deleting records in the database directly.
307+
308+
For example,
309+
310+
```python
311+
session.exec(delete(Role))
312+
```
313+
314+
will not trigger the cascade delete. Instead, we need to select the role objects and then delete them:
315+
316+
```python
317+
for role in session.exec(select(Role)).all():
318+
session.delete(role)
319+
```
320+
321+
This is slower than deleting the records directly, but it makes [many-to-many relationships](https://sqlmodel.tiangolo.com/tutorial/many-to-many/create-models-with-link/#create-the-tables) much easier to manage.

tests/conftest.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlmodel import create_engine, Session, select
44
from fastapi.testclient import TestClient
55
from utils.db import get_connection_url, set_up_db, tear_down_db, get_session
6-
from utils.models import User, PasswordResetToken, Organization
6+
from utils.models import User, PasswordResetToken, Organization, Role
77
from utils.auth import get_password_hash, create_access_token, create_refresh_token
88
from main import app
99

@@ -47,15 +47,9 @@ def clean_db(session: Session):
4747
"""
4848
Cleans up the database tables before each test.
4949
"""
50-
# Delete all PasswordResetTokens
51-
tokens = session.exec(select(PasswordResetToken)).all()
52-
for token in tokens:
53-
session.delete(token)
54-
55-
# Delete all Users
56-
users = session.exec(select(User)).all()
57-
for user in users:
58-
session.delete(user)
50+
for model in (PasswordResetToken, User, Role, Organization):
51+
for record in session.exec(select(model)).all():
52+
session.delete(record)
5953

6054
session.commit()
6155

tests/test_models.py

Lines changed: 130 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
RolePermissionLink,
77
Organization,
88
ValidPermissions,
9+
User,
10+
UserRoleLink,
11+
PasswordResetToken,
912
)
13+
from datetime import timedelta, datetime, UTC
1014

1115

1216
def test_permissions_persist_after_role_deletion(session: Session):
@@ -36,13 +40,8 @@ def test_permissions_persist_after_role_deletion(session: Session):
3640
edit_org_permission = next(
3741
p for p in all_permissions if p.name == ValidPermissions.EDIT_ORGANIZATION)
3842

39-
role_permission_link1 = RolePermissionLink(
40-
role_id=role.id, permission_id=delete_org_permission.id
41-
)
42-
role_permission_link2 = RolePermissionLink(
43-
role_id=role.id, permission_id=edit_org_permission.id
44-
)
45-
session.add_all([role_permission_link1, role_permission_link2])
43+
role.permissions.append(delete_org_permission)
44+
role.permissions.append(edit_org_permission)
4645
session.commit()
4746

4847
# Verify that RolePermissionLinks exist before deletion
@@ -62,3 +61,127 @@ def test_permissions_persist_after_role_deletion(session: Session):
6261
# Verify that RolePermissionLinks were cascade deleted
6362
remaining_role_permissions = session.exec(select(RolePermissionLink)).all()
6463
assert len(remaining_role_permissions) == 0
64+
65+
66+
def test_user_organizations_property(session: Session, test_user: User, test_organization: Organization):
67+
"""
68+
Test that User.organizations property correctly returns all organizations
69+
the user belongs to via their roles.
70+
"""
71+
# Create a role in the test organization
72+
role = Role(name="Test Role", organization_id=test_organization.id)
73+
session.add(role)
74+
75+
# Link the user to the role
76+
test_user.roles.append(role)
77+
session.commit()
78+
79+
# Refresh the user to ensure relationships are loaded
80+
session.refresh(test_user)
81+
82+
# Test the organizations property
83+
assert len(test_user.organizations) == 1
84+
assert test_user.organizations[0].id == test_organization.id
85+
86+
87+
def test_organization_users_property(session: Session, test_user: User, test_organization: Organization):
88+
"""
89+
Test that Organization.users property correctly returns all users
90+
in the organization via their roles.
91+
"""
92+
# Create a role in the test organization
93+
role = Role(name="Test Role", organization_id=test_organization.id)
94+
session.add(role)
95+
session.commit()
96+
97+
# Link the user to the role
98+
test_user.roles.append(role)
99+
session.commit()
100+
101+
# Refresh the organization to ensure relationships are loaded
102+
session.refresh(test_organization)
103+
104+
# Test the users property
105+
users_list = test_organization.users
106+
assert len(users_list) == 1
107+
# users_list is a list of lists due to the property implementation
108+
assert test_user in users_list[0]
109+
110+
111+
def test_cascade_delete_organization(session: Session, test_user: User, test_organization: Organization):
112+
"""
113+
Test that deleting an organization cascades properly:
114+
- Deletes associated roles
115+
- Deletes role-user links
116+
- Does not delete users
117+
"""
118+
# Create a role in the test organization
119+
role = Role(name="Test Role", organization_id=test_organization.id)
120+
session.add(role)
121+
test_user.roles.append(role)
122+
session.commit()
123+
124+
# Delete the organization
125+
session.delete(test_organization)
126+
session.commit()
127+
128+
# Verify the role was deleted
129+
remaining_roles = session.exec(select(Role)).all()
130+
assert len(remaining_roles) == 0
131+
132+
# Verify the user-role link was deleted
133+
remaining_links = session.exec(select(UserRoleLink)).all()
134+
assert len(remaining_links) == 0
135+
136+
# Verify the user still exists
137+
remaining_user = session.exec(select(User)).first()
138+
assert remaining_user is not None
139+
assert remaining_user.id == test_user.id
140+
141+
142+
def test_password_reset_token_cascade_delete(session: Session, test_user: User):
143+
"""
144+
Test that password reset tokens are deleted when a user is deleted
145+
"""
146+
# Create reset tokens for the user
147+
token1 = PasswordResetToken(user_id=test_user.id)
148+
token2 = PasswordResetToken(user_id=test_user.id)
149+
session.add(token1)
150+
session.add(token2)
151+
session.commit()
152+
153+
# Verify tokens exist
154+
tokens = session.exec(select(PasswordResetToken)).all()
155+
assert len(tokens) == 2
156+
157+
# Delete the user
158+
session.delete(test_user)
159+
session.commit()
160+
161+
# Verify tokens were cascade deleted
162+
remaining_tokens = session.exec(select(PasswordResetToken)).all()
163+
assert len(remaining_tokens) == 0
164+
165+
166+
def test_password_reset_token_is_expired(session: Session, test_user: User):
167+
"""
168+
Test that password reset token expiration is properly set and checked
169+
"""
170+
# Create an expired token
171+
expired_token = PasswordResetToken(
172+
user_id=test_user.id,
173+
expires_at=datetime.now(UTC) - timedelta(hours=1)
174+
)
175+
session.add(expired_token)
176+
177+
# Create a valid token
178+
valid_token = PasswordResetToken(
179+
user_id=test_user.id,
180+
expires_at=datetime.now(UTC) + timedelta(hours=1)
181+
)
182+
session.add(valid_token)
183+
session.commit()
184+
185+
# Verify expiration states
186+
assert expired_token.is_expired()
187+
assert not valid_token.is_expired()

utils/models.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,14 @@ class UserRoleLink(SQLModel, table=True):
3131
Associates users with roles. This creates a many-to-many relationship
3232
between users and roles.
3333
"""
34-
id: Optional[int] = Field(default=None, primary_key=True)
35-
user_id: int = Field(foreign_key="user.id")
36-
role_id: int = Field(foreign_key="role.id")
37-
created_at: datetime = Field(default_factory=utc_time)
38-
updated_at: datetime = Field(default_factory=utc_time)
34+
user_id: Optional[int] = Field(foreign_key="user.id", primary_key=True)
35+
role_id: Optional[int] = Field(foreign_key="role.id", primary_key=True)
3936

4037

4138
class RolePermissionLink(SQLModel, table=True):
42-
id: Optional[int] = Field(default=None, primary_key=True)
43-
role_id: int = Field(foreign_key="role.id")
44-
permission_id: int = Field(foreign_key="permission.id")
45-
created_at: datetime = Field(default_factory=utc_time)
46-
updated_at: datetime = Field(default_factory=utc_time)
39+
role_id: Optional[int] = Field(foreign_key="role.id", primary_key=True)
40+
permission_id: Optional[int] = Field(
41+
foreign_key="permission.id", primary_key=True)
4742

4843

4944
class Permission(SQLModel, table=True):
@@ -72,8 +67,7 @@ class Organization(SQLModel, table=True):
7267
roles: List["Role"] = Relationship(
7368
back_populates="organization",
7469
sa_relationship_kwargs={
75-
"cascade": "all, delete-orphan",
76-
"passive_deletes": True
70+
"cascade": "all, delete-orphan"
7771
}
7872
)
7973

@@ -98,7 +92,8 @@ class Role(SQLModel, table=True):
9892
"""
9993
id: Optional[int] = Field(default=None, primary_key=True)
10094
name: str
101-
organization_id: int = Field(foreign_key="organization.id")
95+
organization_id: int = Field(
96+
foreign_key="organization.id")
10297
created_at: datetime = Field(default_factory=utc_time)
10398
updated_at: datetime = Field(default_factory=utc_time)
10499

@@ -125,6 +120,12 @@ class PasswordResetToken(SQLModel, table=True):
125120
user: Optional["User"] = Relationship(
126121
back_populates="password_reset_tokens")
127122

123+
def is_expired(self) -> bool:
124+
"""
125+
Check if the token has expired
126+
"""
127+
return datetime.now(UTC) > self.expires_at.replace(tzinfo=UTC)
128+
128129

129130
# TODO: Prevent deleting a user who is sole owner of an organization
130131
class User(SQLModel, table=True):
@@ -143,8 +144,7 @@ class User(SQLModel, table=True):
143144
password_reset_tokens: List["PasswordResetToken"] = Relationship(
144145
back_populates="user",
145146
sa_relationship_kwargs={
146-
"cascade": "all, delete-orphan",
147-
"passive_deletes": True
147+
"cascade": "all, delete-orphan"
148148
}
149149
)
150150

0 commit comments

Comments
 (0)