Skip to content

Commit a373df7

Browse files
Eagerly load roles, orgs, and permissions with user in endpoints that need them
1 parent 91bc397 commit a373df7

File tree

5 files changed

+95
-251
lines changed

5 files changed

+95
-251
lines changed

main.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
from fastapi.exceptions import RequestValidationError, HTTPException, StarletteHTTPException
99
from sqlmodel import Session
1010
from routers import authentication, organization, role, user
11-
from utils.auth import get_authenticated_user, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError, AuthenticationError
11+
from utils.auth import get_authenticated_user, get_user_with_relations, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError, AuthenticationError
1212
from utils.models import User
1313
from utils.db import get_session, set_up_db
14-
from utils.role_org import get_user_organizations, get_organization_roles
1514

1615
logger = logging.getLogger("uvicorn.error")
1716
logger.setLevel(logging.DEBUG)
@@ -236,28 +235,27 @@ async def common_authenticated_parameters(
236235
return {"request": request, "user": user, "error_message": error_message}
237236

238237

238+
async def common_authenticated_parameters_with_organizations(
239+
request: Request,
240+
user: User = Depends(get_user_with_relations),
241+
error_message: Optional[str] = None
242+
) -> dict:
243+
return {"request": request, "user": user, "error_message": error_message}
244+
245+
239246
# Redirect to home if user is not authenticated
240247
@app.get("/dashboard")
241248
async def read_dashboard(
242249
params: dict = Depends(common_authenticated_parameters)
243250
):
244-
if not params["user"]:
245-
return RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND)
246251
return templates.TemplateResponse(params["request"], "dashboard/index.html", params)
247252

248253

249254
@app.get("/profile")
250255
async def read_profile(
251-
params: dict = Depends(common_authenticated_parameters),
252-
session: Session = Depends(get_session)
256+
params: dict = Depends(common_authenticated_parameters_with_organizations)
253257
):
254-
if not params["user"]:
255-
return RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND)
256-
257-
# Get user's organizations
258-
params["organizations"] = get_user_organizations(
259-
params["user"].id, session)
260-
258+
params["organizations"] = params["user"].organizations
261259
return templates.TemplateResponse(params["request"], "users/profile.html", params)
262260

263261

routers/organization.py

Lines changed: 31 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
from pydantic import BaseModel, ConfigDict, field_validator
55
from sqlmodel import Session, select
66
from utils.db import get_session
7-
from utils.auth import get_authenticated_user
8-
from utils.models import Organization, User, Role, Permission, UserOrganizationLink, ValidPermissions, utc_time
7+
from utils.auth import get_authenticated_user, get_user_with_relations
8+
from utils.models import Organization, User, Role, utc_time, default_roles
99
from datetime import datetime
10-
from sqlalchemy import and_
11-
from utils.role_org import get_organization, check_user_permission
1210

1311
logger = getLogger("uvicorn.error")
1412

@@ -23,14 +21,6 @@ def __init__(self):
2321
)
2422

2523

26-
class OrganizationExistsError(HTTPException):
27-
def __init__(self):
28-
super().__init__(
29-
status_code=400,
30-
detail="Organization already exists"
31-
)
32-
33-
3424
class OrganizationNotFoundError(HTTPException):
3525
def __init__(self):
3626
super().__init__(
@@ -109,64 +99,49 @@ def create_organization(
10999
user: User = Depends(get_authenticated_user),
110100
session: Session = Depends(get_session)
111101
) -> RedirectResponse:
102+
# Check if organization already exists
112103
db_org = session.exec(select(Organization).where(
113104
Organization.name == org.name)).first()
114105
if db_org:
115-
raise OrganizationExistsError()
106+
raise OrganizationNameTakenError()
116107

108+
# Create organization first
117109
db_org = Organization(name=org.name)
118110
session.add(db_org)
119-
session.commit()
120-
session.refresh(db_org)
111+
# This gets us the org ID without committing
112+
session.flush()
121113

122-
# Create default roles
123-
default_role_names = ["Owner", "Administrator", "Member"]
124-
default_roles = []
125-
for role_name in default_role_names:
126-
role = Role(name=role_name, organization_id=db_org.id)
127-
session.add(role)
128-
default_roles.append(role)
129-
session.commit()
114+
# Create default roles with organization_id
115+
initial_roles = [
116+
Role(name=name, organization_id=db_org.id)
117+
for name in default_roles
118+
]
119+
session.add_all(initial_roles)
120+
session.flush()
130121

131-
owner_role = session.exec(
132-
select(Role).where(
133-
and_(
134-
Role.organization_id == db_org.id,
135-
Role.name == "Owner"
136-
)
137-
)
138-
).first()
122+
# Get owner role for user assignment
123+
owner_role = next(role for role in db_org.roles if role.name == "Owner")
139124

140-
if not owner_role:
141-
owner_role = Role(
142-
name="Owner",
143-
organization_id=db_org.id
144-
)
145-
session.add(owner_role)
146-
session.commit()
147-
session.refresh(owner_role)
148-
149-
user_org_link = UserOrganizationLink(
150-
user_id=user.id,
151-
organization_id=db_org.id,
152-
role_id=owner_role.id
153-
)
154-
session.add(user_org_link)
125+
# Assign user to owner role
126+
user.roles.append(owner_role)
127+
128+
# Commit changes
155129
session.commit()
130+
session.refresh(db_org)
156131

157132
return RedirectResponse(url=f"/profile", status_code=303)
158133

159134

160135
@router.put("/{org_id}", response_class=RedirectResponse)
161136
def update_organization(
162137
org: OrganizationUpdate = Depends(OrganizationUpdate.as_form),
163-
user: User = Depends(get_authenticated_user),
138+
user: User = Depends(get_user_with_relations),
164139
session: Session = Depends(get_session)
165140
) -> RedirectResponse:
166141
# This will raise appropriate exceptions if org doesn't exist or user lacks access
167-
organization = get_organization(org.id, user.id, session)
142+
organization: Organization = user.organizations.get(org.id)
168143

169-
if not check_user_permission(user.id, org.id, ValidPermissions.EDIT_ORGANIZATION, session):
144+
if not organization or not any(role.permissions.EDIT_ORGANIZATION for role in organization.roles):
170145
raise InsufficientPermissionsError()
171146

172147
# Check if new name already exists for another organization
@@ -178,24 +153,27 @@ def update_organization(
178153
if existing_org:
179154
raise OrganizationNameTakenError()
180155

156+
# Update organization name
181157
organization.name = org.name
182158
organization.updated_at = utc_time()
183159
session.add(organization)
184160
session.commit()
185-
session.refresh(organization)
186161

187162
return RedirectResponse(url=f"/profile", status_code=303)
188163

189164

190165
@router.delete("/{org_id}", response_class=RedirectResponse)
191166
def delete_organization(
192167
org_id: int,
193-
user: User = Depends(get_authenticated_user),
168+
user: User = Depends(get_user_with_relations),
194169
session: Session = Depends(get_session)
195170
) -> RedirectResponse:
196-
# This will raise appropriate exceptions if org doesn't exist or user lacks access
197-
organization = get_organization(org_id, user.id, session)
171+
# Check if user has permission to delete organization
172+
organization: Organization = user.organizations.get(org_id)
173+
if not organization or not any(role.permissions.DELETE_ORGANIZATION for role in organization.roles):
174+
raise InsufficientPermissionsError()
198175

176+
# Delete organization
199177
session.delete(organization)
200178
session.commit()
201179

utils/auth.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from dotenv import load_dotenv
99
from pydantic import field_validator, ValidationInfo
1010
from sqlmodel import Session, select
11+
from sqlalchemy.orm import selectinload
1112
from bcrypt import gensalt, hashpw, checkpw
1213
from datetime import UTC, datetime, timedelta
1314
from typing import Optional
1415
from fastapi import Depends, Cookie, HTTPException, status
1516
from utils.db import get_session
16-
from utils.models import User, PasswordResetToken
17+
from utils.models import User, Role, PasswordResetToken
1718

1819
load_dotenv()
1920
logger = logging.getLogger("uvicorn.error")
@@ -339,3 +340,23 @@ def get_user_from_reset_token(email: str, token: str, session: Session) -> tuple
339340

340341
user, reset_token = result
341342
return user, reset_token
343+
344+
345+
def get_user_with_relations(
346+
user: User = Depends(get_authenticated_user),
347+
session: Session = Depends(get_session),
348+
) -> User:
349+
"""
350+
Returns an authenticated user with fully loaded role and organization relationships.
351+
"""
352+
# Refresh the user instance with eagerly loaded relationships
353+
eager_user = session.exec(
354+
select(User)
355+
.where(User.id == user.id)
356+
.options(
357+
selectinload(User.roles).selectinload(Role.organization),
358+
selectinload(User.roles).selectinload(Role.permissions)
359+
)
360+
).one()
361+
362+
return eager_user

utils/models.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime, UTC, timedelta
44
from typing import Optional, List
55
from sqlmodel import SQLModel, Field, Relationship
6-
from sqlalchemy import Column, Enum as SQLAlchemyEnum, ForeignKey
6+
from sqlalchemy import Column, Enum as SQLAlchemyEnum
77

88

99
def utc_time():
@@ -13,6 +13,8 @@ def utc_time():
1313
default_roles = ["Owner", "Administrator", "Member"]
1414

1515

16+
# TODO: User with permission to create/edit roles can only assign permissions
17+
# they themselves have.
1618
class ValidPermissions(Enum):
1719
DELETE_ORGANIZATION = "Delete Organization"
1820
EDIT_ORGANIZATION = "Edit Organization"
@@ -24,24 +26,17 @@ class ValidPermissions(Enum):
2426
EDIT_ROLE = "Edit Role"
2527

2628

27-
class UserOrganizationLink(SQLModel, table=True):
29+
class UserRoleLink(SQLModel, table=True):
30+
"""
31+
Associates users with roles. This creates a many-to-many relationship
32+
between users and roles.
33+
"""
2834
id: Optional[int] = Field(default=None, primary_key=True)
2935
user_id: int = Field(foreign_key="user.id")
30-
organization_id: int = Field(foreign_key="organization.id")
3136
role_id: int = Field(foreign_key="role.id")
3237
created_at: datetime = Field(default_factory=utc_time)
3338
updated_at: datetime = Field(default_factory=utc_time)
3439

35-
user: "User" = Relationship(
36-
back_populates="organization_links"
37-
)
38-
organization: "Organization" = Relationship(
39-
back_populates="user_links"
40-
)
41-
role: "Role" = Relationship(
42-
back_populates="user_links"
43-
)
44-
4540

4641
class RolePermissionLink(SQLModel, table=True):
4742
id: Optional[int] = Field(default=None, primary_key=True)
@@ -74,14 +69,20 @@ class Organization(SQLModel, table=True):
7469
created_at: datetime = Field(default_factory=utc_time)
7570
updated_at: datetime = Field(default_factory=utc_time)
7671

77-
user_links: List[UserOrganizationLink] = Relationship(
72+
roles: List["Role"] = Relationship(
7873
back_populates="organization",
7974
sa_relationship_kwargs={
8075
"cascade": "all, delete-orphan",
8176
"passive_deletes": True
8277
}
8378
)
84-
roles: List["Role"] = Relationship(back_populates="organization")
79+
80+
@property
81+
def users(self) -> List["User"]:
82+
"""
83+
Returns all users in the organization via their roles.
84+
"""
85+
return [role.users for role in self.roles]
8586

8687

8788
class Role(SQLModel, table=True):
@@ -101,9 +102,11 @@ class Role(SQLModel, table=True):
101102
created_at: datetime = Field(default_factory=utc_time)
102103
updated_at: datetime = Field(default_factory=utc_time)
103104

104-
organization: "Organization" = Relationship(back_populates="roles")
105-
user_links: List[UserOrganizationLink] = Relationship(
106-
back_populates="role")
105+
organization: Organization = Relationship(back_populates="roles")
106+
users: List["User"] = Relationship(
107+
back_populates="roles",
108+
link_model=UserRoleLink
109+
)
107110
permissions: List["Permission"] = Relationship(
108111
back_populates="roles",
109112
link_model=RolePermissionLink
@@ -133,12 +136,9 @@ class User(SQLModel, table=True):
133136
created_at: datetime = Field(default_factory=utc_time)
134137
updated_at: datetime = Field(default_factory=utc_time)
135138

136-
organization_links: List[UserOrganizationLink] = Relationship(
137-
back_populates="user",
138-
sa_relationship_kwargs={
139-
"cascade": "all, delete-orphan",
140-
"passive_deletes": True
141-
}
139+
roles: List[Role] = Relationship(
140+
back_populates="users",
141+
link_model=UserRoleLink
142142
)
143143
password_reset_tokens: List["PasswordResetToken"] = Relationship(
144144
back_populates="user",
@@ -147,3 +147,10 @@ class User(SQLModel, table=True):
147147
"passive_deletes": True
148148
}
149149
)
150+
151+
@property
152+
def organizations(self) -> List[Organization]:
153+
"""
154+
Returns all organizations the user belongs to via their roles.
155+
"""
156+
return [role.organization for role in self.roles]

0 commit comments

Comments
 (0)