Skip to content

Commit 8b3e3f7

Browse files
Merge pull request #19 from Promptly-Technologies-LLC/18-add-a-test-suite
Starter unit test suite
2 parents 94c1726 + 9072ac4 commit 8b3e3f7

File tree

13 files changed

+592
-176
lines changed

13 files changed

+592
-176
lines changed

main.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from sqlmodel import Session
1212
from routers import authentication, organization, role, user
1313
from utils.auth import get_authenticated_user, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError
14-
from utils.db import User, get_session
14+
from utils.models import User
15+
from utils.db import get_session, set_up_db
1516

1617

1718
logger = logging.getLogger("uvicorn.error")
@@ -21,6 +22,7 @@
2122
@asynccontextmanager
2223
async def lifespan(app: FastAPI):
2324
# Optional startup logic
25+
set_up_db(drop=False)
2426
yield
2527
# Optional shutdown logic
2628

@@ -63,9 +65,9 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens):
6365
@app.exception_handler(PasswordValidationError)
6466
async def password_validation_exception_handler(request: Request, exc: PasswordValidationError):
6567
return templates.TemplateResponse(
68+
request,
6669
"errors/validation_error.html",
6770
{
68-
"request": request,
6971
"status_code": 422,
7072
"errors": {"error": exc.detail}
7173
},
@@ -91,9 +93,9 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
9193
errors[field_name] = error["msg"]
9294

9395
return templates.TemplateResponse(
96+
request,
9497
"errors/validation_error.html",
9598
{
96-
"request": request,
9799
"status_code": 422,
98100
"errors": errors
99101
},
@@ -109,8 +111,9 @@ async def http_exception_handler(request: Request, exc: StarletteHTTPException):
109111
raise exc
110112

111113
return templates.TemplateResponse(
114+
request,
112115
"errors/error.html",
113-
{"request": request, "status_code": exc.status_code, "detail": exc.detail},
116+
{"status_code": exc.status_code, "detail": exc.detail},
114117
status_code=exc.status_code,
115118
)
116119

@@ -122,9 +125,9 @@ async def general_exception_handler(request: Request, exc: Exception):
122125
logger.error(f"Unhandled exception: {exc}", exc_info=True)
123126

124127
return templates.TemplateResponse(
128+
request,
125129
"errors/error.html",
126130
{
127-
"request": request,
128131
"status_code": 500,
129132
"detail": "Internal Server Error"
130133
},
@@ -150,7 +153,7 @@ async def read_home(
150153
):
151154
if params["user"]:
152155
return RedirectResponse(url="/dashboard", status_code=302)
153-
return templates.TemplateResponse("index.html", params)
156+
return templates.TemplateResponse(params["request"], "index.html", params)
154157

155158

156159
@app.get("/login")
@@ -159,7 +162,7 @@ async def read_login(
159162
):
160163
if params["user"]:
161164
return RedirectResponse(url="/dashboard", status_code=302)
162-
return templates.TemplateResponse("authentication/login.html", params)
165+
return templates.TemplateResponse(params["request"], "authentication/login.html", params)
163166

164167

165168
@app.get("/register")
@@ -168,7 +171,7 @@ async def read_register(
168171
):
169172
if params["user"]:
170173
return RedirectResponse(url="/dashboard", status_code=302)
171-
return templates.TemplateResponse("authentication/register.html", params)
174+
return templates.TemplateResponse(params["request"], "authentication/register.html", params)
172175

173176

174177
@app.get("/forgot_password")
@@ -180,22 +183,22 @@ async def read_forgot_password(
180183
return RedirectResponse(url="/dashboard", status_code=302)
181184
params["show_form"] = show_form
182185

183-
return templates.TemplateResponse("authentication/forgot_password.html", params)
186+
return templates.TemplateResponse(params["request"], "authentication/forgot_password.html", params)
184187

185188

186189
@app.get("/about")
187190
async def read_about(params: dict = Depends(common_unauthenticated_parameters)):
188-
return templates.TemplateResponse("about.html", params)
191+
return templates.TemplateResponse(params["request"], "about.html", params)
189192

190193

191194
@app.get("/privacy_policy")
192195
async def read_privacy_policy(params: dict = Depends(common_unauthenticated_parameters)):
193-
return templates.TemplateResponse("privacy_policy.html", params)
196+
return templates.TemplateResponse(params["request"], "privacy_policy.html", params)
194197

195198

196199
@app.get("/terms_of_service")
197200
async def read_terms_of_service(params: dict = Depends(common_unauthenticated_parameters)):
198-
return templates.TemplateResponse("terms_of_service.html", params)
201+
return templates.TemplateResponse(params["request"], "terms_of_service.html", params)
199202

200203

201204
@app.get("/reset_password")
@@ -214,7 +217,7 @@ async def read_reset_password(
214217
params["email"] = email
215218
params["token"] = token
216219

217-
return templates.TemplateResponse("authentication/reset_password.html", params)
220+
return templates.TemplateResponse(params["request"], "authentication/reset_password.html", params)
218221

219222

220223
# -- Authenticated Routes --
@@ -236,7 +239,7 @@ async def read_dashboard(
236239
):
237240
if not params["user"]:
238241
return RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND)
239-
return templates.TemplateResponse("dashboard/index.html", params)
242+
return templates.TemplateResponse(params["request"], "dashboard/index.html", params)
240243

241244

242245
@app.get("/profile")
@@ -246,7 +249,7 @@ async def read_profile(
246249
if not params["user"]:
247250
# Changed to 302
248251
return RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND)
249-
return templates.TemplateResponse("users/profile.html", params)
252+
return templates.TemplateResponse(params["request"], "users/profile.html", params)
250253

251254

252255
# -- Include Routers --

poetry.lock

Lines changed: 83 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pyjwt = "^2.10.0"
1313
jinja2 = "^3.1.4"
1414
uvicorn = "^0.32.0"
1515
psycopg2 = "^2.9.10"
16-
pydantic = "^2.9.2"
16+
pydantic = {extras = ["email"], version = "^2.9.2"}
1717
python-multipart = "^0.0.17"
1818
python-dotenv = "^1.0.1"
1919
resend = "^2.4.0"
@@ -27,6 +27,7 @@ quarto = "^0.1.0"
2727
mypy = "^1.11.2"
2828
jupyter = "^1.1.1"
2929
notebook = "^7.2.2"
30+
pytest = "^8.3.3"
3031

3132
[build-system]
3233
requires = ["poetry-core"]

routers/authentication.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastapi.responses import RedirectResponse
77
from pydantic import BaseModel, EmailStr, ConfigDict
88
from sqlmodel import Session, select
9-
from utils.db import User
9+
from utils.models import User
1010
from utils.auth import (
1111
get_session,
1212
get_user_from_reset_token,

routers/organization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from fastapi.responses import RedirectResponse
44
from pydantic import BaseModel, ConfigDict
55
from sqlmodel import Session, select
6-
from utils.db import Organization, get_session
6+
from utils.db import get_session
7+
from utils.models import Organization
78
from datetime import datetime
89

910
logger = getLogger("uvicorn.error")

routers/role.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from fastapi import APIRouter, Depends, HTTPException, Form
55
from fastapi.responses import RedirectResponse
66
from pydantic import BaseModel, ConfigDict
7-
from sqlmodel import Session, select, delete
8-
from utils.db import Role, RolePermissionLink, ValidPermissions, get_session, utc_time
7+
from sqlmodel import Session, select
8+
from utils.db import get_session
9+
from utils.models import Role, RolePermissionLink, ValidPermissions, utc_time
910

1011
logger = getLogger("uvicorn.error")
1112

routers/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from fastapi.responses import RedirectResponse
33
from pydantic import BaseModel, EmailStr
44
from sqlmodel import Session
5-
from utils.db import User
5+
from utils.models import User
66
from utils.auth import get_session, get_authenticated_user, verify_password
77

88
router = APIRouter(prefix="/user", tags=["user"])

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
from sqlmodel import create_engine, Session, delete
3+
from utils.db import get_connection_url, set_up_db, tear_down_db
4+
from utils.models import User, PasswordResetToken
5+
from dotenv import load_dotenv
6+
7+
load_dotenv()
8+
9+
10+
@pytest.fixture(scope="session")
11+
def engine():
12+
"""
13+
Create a new SQLModel engine for the test database.
14+
Use an in-memory SQLite database for testing.
15+
"""
16+
engine = create_engine(
17+
get_connection_url()
18+
)
19+
return engine
20+
21+
22+
@pytest.fixture(scope="session", autouse=True)
23+
def set_up_database(engine):
24+
"""
25+
Set up the test database before running the test suite.
26+
Drop all tables and recreate them to ensure a clean state.
27+
"""
28+
set_up_db(drop=True)
29+
yield
30+
tear_down_db()
31+
32+
33+
@pytest.fixture
34+
def session(engine):
35+
"""
36+
Provide a session for database operations in tests.
37+
"""
38+
with Session(engine) as session:
39+
yield session
40+
41+
42+
@pytest.fixture(autouse=True)
43+
def clean_db(session: Session):
44+
"""
45+
Cleans up the database tables before each test.
46+
"""
47+
# Exempt from mypy until SQLModel overload properly supports delete()
48+
session.exec(delete(PasswordResetToken)) # type: ignore
49+
session.exec(delete(User)) # type: ignore
50+
51+
session.commit()

0 commit comments

Comments
 (0)