Skip to content

Commit f99af68

Browse files
Refactored db.py
1 parent eb58d4d commit f99af68

File tree

1 file changed

+106
-40
lines changed

1 file changed

+106
-40
lines changed

utils/db.py

Lines changed: 106 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,34 @@
11
import os
22
import logging
3+
from typing import Generator
34
from dotenv import load_dotenv
45
from sqlalchemy.engine import URL
56
from sqlmodel import create_engine, Session, SQLModel, select
67
from utils.models import Role, Permission, RolePermissionLink, default_roles, ValidPermissions
78

9+
# Load environment variables from a .env file
810
load_dotenv()
911

12+
# Set up a logger for error reporting
1013
logger = logging.getLogger("uvicorn.error")
1114

1215

13-
# --- Database connection ---
16+
# --- Database connection functions ---
1417

1518

1619
def get_connection_url() -> URL:
1720
"""
18-
Creates a SQLModel URL object containing the connection URL to the Postgres database.
19-
The connection details are obtained from environment variables.
20-
Returns the URL object.
21+
Constructs a SQLModel URL object for connecting to the PostgreSQL database.
22+
23+
The connection details are sourced from environment variables, which should include:
24+
- DB_USER: Database username
25+
- DB_PASSWORD: Database password
26+
- DB_HOST: Database host address
27+
- DB_PORT: Database port (default is 5432)
28+
- DB_NAME: Database name
29+
30+
Returns:
31+
URL: A SQLModel URL object containing the connection details.
2132
"""
2233
database_url: URL = URL.create(
2334
drivername="postgresql",
@@ -31,59 +42,111 @@ def get_connection_url() -> URL:
3142
return database_url
3243

3344

45+
# Create the database engine using the connection URL
3446
engine = create_engine(get_connection_url())
3547

3648

37-
def get_session():
49+
def get_session() -> Generator[Session, None, None]:
50+
"""
51+
Provides a database session for executing queries.
52+
53+
Yields:
54+
Session: A SQLModel session object for database operations.
55+
"""
3856
with Session(engine) as session:
3957
yield session
4058

4159

42-
def create_default_roles(session, organization_id: int, check_first: bool = True):
60+
def assign_permissions_to_role(session: Session, role: Role, permissions: list[Permission], check_first: bool = False) -> None:
61+
"""
62+
Assigns permissions to a role in the database.
63+
64+
Args:
65+
session (Session): The database session to use for operations.
66+
role (Role): The role to assign permissions to.
67+
permissions (list[Permission]): The list of permissions to assign.
68+
check_first (bool): If True, checks if the role already has the permission before assigning it.
69+
"""
70+
71+
for permission in permissions:
72+
# Check if the role already has the permission
73+
if check_first:
74+
db_role_permission_link: RolePermissionLink | None = session.exec(
75+
select(RolePermissionLink).where(
76+
RolePermissionLink.role_id == role.id,
77+
RolePermissionLink.permission_id == permission.id
78+
)
79+
).first()
80+
else:
81+
db_role_permission_link = None
82+
83+
# Skip granting DELETE_ORGANIZATION permission to the Administrator role
84+
if not db_role_permission_link:
85+
role_permission_link = RolePermissionLink(
86+
role_id=role.id,
87+
permission_id=permission.id
88+
)
89+
session.add(role_permission_link)
90+
91+
92+
def create_default_roles(session: Session, organization_id: int, check_first: bool = True) -> list:
4393
"""
44-
Create default roles for an organization in the database if they do not exist.
94+
Creates default roles for a specified organization in the database if they do not already exist,
95+
and assigns permissions to the Owner and Administrator roles.
96+
97+
Args:
98+
session (Session): The database session to use for operations.
99+
organization_id (int): The ID of the organization for which to create roles.
100+
check_first (bool): If True, checks if the role already exists before creating it.
101+
102+
Returns:
103+
list: A list of roles that were created or already existed in the database.
45104
"""
105+
46106
roles_in_db = []
47107
for role_name in default_roles:
48-
db_role = session.exec(select(Role).where(
49-
Role.name == role_name,
50-
Role.organization_id == organization_id
51-
)).first()
108+
db_role = session.exec(
109+
select(Role).where(
110+
Role.name == role_name,
111+
Role.organization_id == organization_id
112+
)
113+
).first()
52114
if not db_role:
53115
db_role = Role(name=role_name, organization_id=organization_id)
54116
session.add(db_role)
55117
roles_in_db.append(db_role)
56118

57-
# Create RolePermissionLink for Owner and Administrator roles
58-
for role in roles_in_db[:2]:
59-
permissions = session.exec(select(Permission)).all()
60-
for permission in permissions:
61-
# Check if the role already has the permission
62-
if check_first:
63-
db_role_permission_link: RolePermissionLink | None = session.exec(select(RolePermissionLink).where(
64-
RolePermissionLink.role_id == role.id,
65-
RolePermissionLink.permission_id == permission.id
66-
)).first()
67-
else:
68-
db_role_permission_link = None
69-
70-
# Skip giving DELETE_ORGANIZATION permission to Administrator
71-
if not db_role_permission_link and not (
72-
permission == ValidPermissions.DELETE_ORGANIZATION and
73-
role.name == "Administrator"
74-
):
75-
role_permission_link = RolePermissionLink(
76-
role_id=role.id,
77-
permission_id=permission.id
78-
)
79-
session.add(role_permission_link)
119+
# TODO: Construct this role-permission mapping once at app startup and use as constant
120+
# Fetch all permissions once
121+
owner_permissions = session.exec(select(Permission)).all()
122+
admin_permissions = [
123+
permission for permission in owner_permissions
124+
if permission.name != ValidPermissions.DELETE_ORGANIZATION
125+
]
126+
127+
# Get Owner and Administrator roles by name
128+
owner_role = next(role for role in roles_in_db if role.name == "Owner")
129+
admin_role = next(
130+
role for role in roles_in_db if role.name == "Administrator")
131+
132+
# Assign all permissions to Owner
133+
assign_permissions_to_role(
134+
session, owner_role, owner_permissions, check_first=check_first)
80135

136+
# Assign filtered permissions to Administrator
137+
assign_permissions_to_role(
138+
session, admin_role, admin_permissions, check_first=check_first)
139+
140+
session.commit()
81141
return roles_in_db
82142

83143

84-
def create_permissions(session):
144+
def create_permissions(session: Session) -> None:
85145
"""
86-
Create default permissions.
146+
Creates default permissions in the database if they do not already exist.
147+
148+
Args:
149+
session (Session): The database session to use for operations.
87150
"""
88151
for permission in ValidPermissions:
89152
db_permission = session.exec(select(Permission).where(
@@ -93,9 +156,12 @@ def create_permissions(session):
93156
session.add(db_permission)
94157

95158

96-
def set_up_db(drop: bool = False):
159+
def set_up_db(drop: bool = False) -> None:
97160
"""
98-
Set up the database by creating tables and populating them with default roles and permissions.
161+
Sets up the database by creating tables and populating them with default permissions.
162+
163+
Args:
164+
drop (bool): If True, drops all existing tables before creating new ones.
99165
"""
100166
engine = create_engine(get_connection_url())
101167
if drop:
@@ -108,9 +174,9 @@ def set_up_db(drop: bool = False):
108174
engine.dispose()
109175

110176

111-
def tear_down_db():
177+
def tear_down_db() -> None:
112178
"""
113-
Tear down the database by dropping all tables.
179+
Tears down the database by dropping all tables.
114180
"""
115181
engine = create_engine(get_connection_url())
116182
SQLModel.metadata.drop_all(engine)

0 commit comments

Comments
 (0)