Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ SECRET_KEY=
BASE_URL=http://localhost:8000

# Database
DB_USER=
DB_PASSWORD=
DB_HOST=localhost
DB_USER=postgres
DB_PASSWORD=postgres
DB_HOST=127.0.0.1
DB_PORT=5432
DB_NAME=

Expand Down
27 changes: 2 additions & 25 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.12"]
python-version: ["3.13"]
os: [ubuntu-latest]

runs-on: ${{ matrix.os }}
Expand All @@ -20,7 +20,7 @@ jobs:
postgres:
image: postgres:latest
env:
POSTGRES_DB: test_db
POSTGRES_DB: db
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
ports:
Expand All @@ -45,29 +45,6 @@ jobs:
- name: Install project
run: uv sync --all-extras --dev

- name: Set env variables for pytest
run: |
echo "DB_USER=postgres" >> $GITHUB_ENV
echo "DB_PASSWORD=postgres" >> $GITHUB_ENV
echo "DB_HOST=127.0.0.1" >> $GITHUB_ENV
echo "DB_PORT=5432" >> $GITHUB_ENV
echo "DB_NAME=test_db" >> $GITHUB_ENV
echo "SECRET_KEY=$(openssl rand -base64 32)" >> $GITHUB_ENV
echo "BASE_URL=http://localhost:8000" >> $GITHUB_ENV
echo "RESEND_API_KEY=resend_api_key" >> $GITHUB_ENV
echo "[email protected]" >> $GITHUB_ENV

- name: Verify environment variables
run: |
echo "Checking if required environment variables are set..."
[ -n "$DB_USER" ] && \
[ -n "$DB_PASSWORD" ] && \
[ -n "$DB_HOST" ] && \
[ -n "$DB_PORT" ] && \
[ -n "$DB_NAME" ] && \
[ -n "$SECRET_KEY" ] && \
[ -n "$RESEND_API_KEY" ]

- name: Run type checking with mypy
run: uv run mypy .

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ dependencies = [
"pyjwt<3.0.0,>=2.10.1",
"jinja2<4.0.0,>=3.1.4",
"uvicorn<1.0.0,>=0.32.0",
"psycopg2<3.0.0,>=2.9.10",
"pydantic[email]<3.0.0,>=2.9.2",
"python-multipart<1.0.0,>=0.0.17",
"python-dotenv<2.0.0,>=1.0.1",
"resend<3.0.0,>=2.4.0",
"bcrypt<5.0.0,>=4.2.0",
"fastapi<1.0.0,>=0.115.5",
"pillow>=11.0.0",
"psycopg2-binary>=2.9.10",
]

[dependency-groups]
Expand All @@ -31,5 +31,5 @@ dev = [
"notebook<8.0.0,>=7.2.2",
"pytest<9.0.0,>=8.3.3",
"sqlalchemy-schemadisplay<3.0,>=2.0",
"mypy>=1.15.0",
"mypy>=1.18.2",
]
30 changes: 17 additions & 13 deletions routers/core/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,20 +375,24 @@ async def login(

# Process the invitation
try:
logger.info(f"Processing invitation {invitation.id} for user {account.user.id} during login.")
process_invitation(invitation, account.user, session)
session.commit()
# Set redirect to the organization page
redirect_url = org_router.url_path_for("read_organization", org_id=invitation.organization_id)
logger.info(f"Redirecting user {account.user.id} to organization {invitation.organization_id} after accepting invitation {invitation.id}.")
if account.user and account.user.id:
logger.info(f"Processing invitation {invitation.id} for user {account.user.id} during login.")
process_invitation(invitation, account.user, session)
session.commit()
# Set redirect to the organization page
redirect_url = org_router.url_path_for("read_organization", org_id=invitation.organization_id)
logger.info(f"Redirecting user {account.user.id} to organization {invitation.organization_id} after accepting invitation {invitation.id}.")
else:
logger.error("User has no ID during invitation processing.")
raise DataIntegrityError(resource="User ID")
except Exception as e:
logger.error(
f"Error processing invitation {invitation.id} for user {account.user.id} during login: {e}",
exc_info=True
)
session.rollback()
# Raise the specific invitation processing error
raise InvitationProcessingError()
logger.error(
f"Error processing invitation during login: {e}",
exc_info=True
)
session.rollback()
# Raise the specific invitation processing error
raise InvitationProcessingError()

else:
logger.info(f"Standard login for account {account.email}. Redirecting to dashboard.")
Expand Down
82 changes: 39 additions & 43 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import pytest
import os
from typing import Generator
from sqlmodel import create_engine, Session, select
from sqlalchemy import Engine
from fastapi.testclient import TestClient
from dotenv import load_dotenv
from utils.core.db import get_connection_url, tear_down_db, set_up_db, create_default_roles
from utils.core.models import User, PasswordResetToken, EmailUpdateToken, Organization, Role, Account, Invitation
from utils.core.db import get_connection_url, tear_down_db, set_up_db, create_default_roles, ensure_database_exists
from utils.core.models import User, Organization, Role, Account, Invitation
from utils.core.auth import get_password_hash, create_access_token, create_refresh_token
from main import app
from datetime import datetime, UTC, timedelta

# Load environment variables
load_dotenv(override=True)

# Define a custom exception for test setup errors
class SetupError(Exception):
"""Exception raised for errors in the test setup process."""
Expand All @@ -21,29 +18,41 @@ def __init__(self, message="An error occurred during test setup"):
super().__init__(self.message)


@pytest.fixture(scope="session")
def engine() -> Engine:
@pytest.fixture
def env_vars(monkeypatch):
load_dotenv()

# monkeypatch remaining env vars
with monkeypatch.context() as m:
# Get valid db user, password, host, and port from env
m.setenv("DB_HOST", os.getenv("DB_HOST", "127.0.0.1"))
m.setenv("DB_PORT", os.getenv("DB_PORT", "5432"))
m.setenv("DB_USER", os.getenv("DB_USER", "postgres"))
m.setenv("DB_PASSWORD", os.getenv("DB_PASSWORD", "postgres"))
m.setenv("SECRET_KEY", "testsecretkey")
m.setenv("HOST_NAME", "Test Organization")
m.setenv("DB_NAME", "qual2db4-test-db")
m.setenv("RESEND_API_KEY", "test")
m.setenv("EMAIL_FROM", "[email protected]")
m.setenv("QUALTRICS_BASE_URL", "test")
m.setenv("QUALTRICS_API_TOKEN", "test")
m.setenv("BASE_URL", "http://localhost:8000")
yield


@pytest.fixture
def engine(env_vars):
"""
Create a new SQLModel engine for the test database.
Use PostgreSQL for testing to match production environment.
"""
# Use PostgreSQL for testing to match production environment
ensure_database_exists(get_connection_url())
engine = create_engine(get_connection_url())
return engine
set_up_db(drop=True)

yield engine

@pytest.fixture(scope="session", autouse=True)
def set_up_database(engine) -> Generator[None, None, None]:
"""
Set up the test database before running the test suite.
Drop all tables and recreate them to ensure a clean state.
"""
# Drop and recreate all tables using the helpers from db.py
tear_down_db()
set_up_db(drop=False)

yield

# Clean up after tests
tear_down_db()

Expand All @@ -57,20 +66,7 @@ def session(engine) -> Generator[Session, None, None]:
yield session


@pytest.fixture(autouse=True)
def clean_db(session: Session) -> None:
"""
Cleans up the database tables before each test.
"""
# Don't delete permissions as they are required for tests
for model in (PasswordResetToken, EmailUpdateToken, User, Role, Organization, Account):
for record in session.exec(select(model)).all():
session.delete(record)

session.commit()


@pytest.fixture()
@pytest.fixture
def test_account(session: Session) -> Account:
"""
Creates a test account in the database.
Expand All @@ -85,7 +81,7 @@ def test_account(session: Session) -> Account:
return account


@pytest.fixture()
@pytest.fixture
def test_user(session: Session, test_account: Account) -> User:
"""
Creates a test user in the database linked to the test account.
Expand All @@ -103,7 +99,7 @@ def test_user(session: Session, test_account: Account) -> User:
return user


@pytest.fixture()
@pytest.fixture
def unauth_client(session: Session) -> Generator[TestClient, None, None]:
"""
Provides a TestClient instance without authentication.
Expand All @@ -112,7 +108,7 @@ def unauth_client(session: Session) -> Generator[TestClient, None, None]:
yield client


@pytest.fixture()
@pytest.fixture
def auth_client(session: Session, test_account: Account, test_user: User) -> Generator[TestClient, None, None]:
"""
Provides a TestClient instance with valid authentication tokens.
Expand All @@ -136,13 +132,13 @@ def test_organization(session: Session) -> Organization:
session.add(organization)
session.flush()

if organization.id is None:
if organization.id:
# Use the utility function to create default roles and assign permissions
# This function handles the commit internally
create_default_roles(session, organization.id, check_first=False)
else:
pytest.fail("Failed to get organization ID after flush")

# Use the utility function to create default roles and assign permissions
# This function handles the commit internally
create_default_roles(session, organization.id, check_first=False)

return organization


Expand Down
8 changes: 2 additions & 6 deletions tests/routers/core/test_account.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from fastapi.testclient import TestClient
from starlette.datastructures import URLPath
from sqlmodel import Session, select
Expand All @@ -16,9 +15,6 @@
get_password_hash
)

# --- Fixture setup ---


# --- API Endpoint Tests ---


Expand Down Expand Up @@ -124,7 +120,7 @@ def test_password_reset_flow(unauth_client: TestClient, session: Session, test_a

# Verify content
assert call_args["to"] == [test_account.email]
assert call_args["from"] == "noreply@promptlytechnologies.com"
assert call_args["from"] == "test@example.com"
assert "Password Reset Request" in call_args["subject"]
assert "reset_password" in call_args["html"]

Expand Down Expand Up @@ -259,7 +255,7 @@ def test_request_email_update_success(auth_client: TestClient, test_account: Acc

# Verify email content
assert call_args["to"] == [test_account.email]
assert call_args["from"] == "noreply@promptlytechnologies.com"
assert call_args["from"] == "test@example.com"
assert "Confirm Email Update" in call_args["subject"]
assert "confirm_email_update" in call_args["html"]
assert new_email in call_args["html"]
Expand Down
10 changes: 5 additions & 5 deletions tests/utils/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_password_hashing() -> None:
assert not verify_password("wrong_password", hashed)


def test_token_creation_and_validation() -> None:
def test_token_creation_and_validation(env_vars) -> None:
data = {"sub": "[email protected]"}

# Test access token
Expand All @@ -55,21 +55,21 @@ def test_token_creation_and_validation() -> None:
assert decoded["type"] == "refresh"


def test_expired_token() -> None:
def test_expired_token(env_vars) -> None:
data = {"sub": "[email protected]"}
expired_delta = timedelta(minutes=-10)
expired_token = create_access_token(data, expired_delta)
decoded = validate_token(expired_token, "access")
assert decoded is None


def test_invalid_token_type() -> None:
def test_invalid_token_type(env_vars) -> None:
data = {"sub": "[email protected]"}
access_token = create_access_token(data)
decoded = validate_token(access_token, "refresh")
assert decoded is None

def test_password_reset_url_generation() -> None:
def test_password_reset_url_generation(env_vars) -> None:
"""
Tests that the password reset URL is correctly formatted and contains
the required query parameters.
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_password_pattern() -> None:
password = "aA1" * 3
assert re.match(COMPILED_PASSWORD_PATTERN, password) is None

def test_email_update_url_generation() -> None:
def test_email_update_url_generation(env_vars) -> None:
"""
Tests that the email update confirmation URL is correctly formatted and contains
the required query parameters.
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from utils.core.models import Role, Permission, Organization, RolePermissionLink, ValidPermissions
from tests.conftest import SetupError

def test_get_connection_url():
def test_get_connection_url(env_vars):
"""Test that get_connection_url returns a valid URL object"""
url = get_connection_url()
assert url.drivername == "postgresql"
Expand Down
Loading