diff --git a/.env b/.env deleted file mode 100644 index 1d44286e25..0000000000 --- a/.env +++ /dev/null @@ -1,45 +0,0 @@ -# Domain -# This would be set to the production domain with an env var on deployment -# used by Traefik to transmit traffic and aqcuire TLS certificates -DOMAIN=localhost -# To test the local Traefik config -# DOMAIN=localhost.tiangolo.com - -# Used by the backend to generate links in emails to the frontend -FRONTEND_HOST=http://localhost:5173 -# In staging and production, set this env var to the frontend host, e.g. -# FRONTEND_HOST=https://dashboard.example.com - -# Environment: local, staging, production -ENVIRONMENT=local - -PROJECT_NAME="Full Stack FastAPI Project" -STACK_NAME=full-stack-fastapi-project - -# Backend -BACKEND_CORS_ORIGINS="http://localhost,http://localhost:5173,https://localhost,https://localhost:5173,http://localhost.tiangolo.com" -SECRET_KEY=changethis -FIRST_SUPERUSER=admin@example.com -FIRST_SUPERUSER_PASSWORD=changethis - -# Emails -SMTP_HOST= -SMTP_USER= -SMTP_PASSWORD= -EMAILS_FROM_EMAIL=info@example.com -SMTP_TLS=True -SMTP_SSL=False -SMTP_PORT=587 - -# Postgres -POSTGRES_SERVER=localhost -POSTGRES_PORT=5432 -POSTGRES_DB=app -POSTGRES_USER=postgres -POSTGRES_PASSWORD=changethis - -SENTRY_DSN= - -# Configure these with your own Docker registry images -DOCKER_IMAGE_BACKEND=backend -DOCKER_IMAGE_FRONTEND=frontend diff --git a/.gitignore b/.gitignore index a6dd346572..b9db13d147 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,71 @@ -.vscode +# Environment files +.env +.env.* +!.env.example + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual Environment +.venv +venv/ +ENV/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +.DS_Store + +# Testing +.coverage +htmlcov/ +.pytest_cache/ +.tox/ + +# Logs +logs/ +*.log + +# Node.js node_modules/ + +# Local development +.env.local +.env.development.local +.env.test.local +.env.production.local + +# Debug logs +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# Misc +.DS_Store +*.pem + +# Playwright /test-results/ /playwright-report/ /blob-report/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..0bec3a7564 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,123 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
+ +## Project Overview + +Full-stack FastAPI + React application with: +- **Backend**: FastAPI, SQLModel, PostgreSQL, Alembic, JWT auth +- **Frontend**: React, TypeScript, Vite, TanStack Query/Router, Chakra UI +- **Infrastructure**: Docker Compose, Traefik proxy + +## Essential Commands + +### Backend Development +```bash +# Run tests +cd backend && bash ./scripts/test.sh + +# Run specific test +cd backend && python -m pytest app/tests/api/routes/test_users.py::test_read_users -xvs + +# Lint code +cd backend && bash ./scripts/lint.sh + +# Format code +cd backend && bash ./scripts/format.sh + +# Create migration +cd backend && alembic revision -m "migration_name" --autogenerate + +# Apply migrations +cd backend && alembic upgrade head + +# Database setup +cd backend && bash ./scripts/setup_db.sh +``` + +### Frontend Development +```bash +# Install dependencies +cd frontend && npm install + +# Run development server +cd frontend && npm run dev + +# Build for production +cd frontend && npm run build + +# Lint/format code +cd frontend && npm run lint + +# Generate API client from OpenAPI +cd frontend && npm run generate-client + +# Run E2E tests +cd frontend && npx playwright test + +# Run specific test +cd frontend && npx playwright test tests/login.spec.ts +``` + +### Docker Development +```bash +# Start all services +docker compose up -d + +# View logs +docker compose logs -f backend + +# Rebuild specific service +docker compose build backend + +# Run backend tests in Docker +docker compose exec backend bash /app/scripts/test.sh + +# Access database +docker compose exec db psql -U postgres app +``` + +## Architecture Patterns + +### Backend Structure +- **Models**: SQLModel with UUID primary keys and timestamp mixins (`app/models/`) +- **Routes**: Modular route organization (`app/api/routes/`) +- **Auth**: JWT with refresh tokens, OAuth support (`app/api/routes/auth/`) +- **Config**: Pydantic Settings (`app/core/config.py`) +- **Database**: Async SQLAlchemy sessions (`app/db/session.py`) + +### Frontend Structure +- **Routing**: File-based routing with TanStack Router (`src/routes/`) +- **API Client**: Auto-generated from OpenAPI (`src/client/`) +- **Auth**: Token management in localStorage (`src/hooks/useAuth.ts`) +- **Components**: Organized by feature (`src/components/`) +- **UI**: Chakra UI with custom theme (`src/theme.tsx`) + +### Database Migrations +Alembic manages migrations. When modifying models: +1. Make changes to SQLModel classes +2. Generate migration: `alembic revision -m "description" --autogenerate` +3. Review generated migration file +4. Apply: `alembic upgrade head` + +### Authentication Flow +- Login returns access and refresh tokens +- Access token expires in 15 minutes +- Refresh token used to get new access token +- Frontend automatically refreshes tokens + +## Development URLs +- Frontend: http://localhost:5173 +- Backend API: http://localhost:8000 +- API Docs: http://localhost:8000/docs +- Adminer (DB UI): http://localhost:8080 +- MailCatcher: http://localhost:1080 +- Traefik Dashboard: http://localhost:8090 + +## Environment Variables +Create `.env` file from template. 
Key variables: +- `POSTGRES_*`: Database connection +- `FIRST_SUPERUSER*`: Initial admin account +- `BACKEND_CORS_ORIGINS`: CORS configuration +- `SMTP_*`: Email settings +- `VITE_API_URL`: Frontend API endpoint \ No newline at end of file diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000000..6e8813a12d --- /dev/null +++ b/backend/.env.example @@ -0,0 +1,43 @@ +# Database +DATABASE_URL=postgresql://postgres:postgres@localhost:5432/copilot +ASYNC_DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/copilot + +# Security +SECRET_KEY=your-secret-key-here +ALGORITHM=HS256 +ACCESS_TOKEN_EXPIRE_MINUTES=30 +REFRESH_TOKEN_EXPIRE_DAYS=7 + +# First superuser +FIRST_SUPERUSER=admin@example.com +FIRST_SUPERUSER_PASSWORD=changeme + +# Email (for password reset, etc.) +SMTP_TLS=True +SMTP_PORT=587 +SMTP_HOST=smtp.example.com +SMTP_USER=your-email@example.com +SMTP_PASSWORD=your-email-password +EMAILS_FROM_EMAIL=noreply@example.com +EMAILS_FROM_NAME="Copilot" + +# OAuth (Google) +GOOGLE_OAUTH_CLIENT_ID=your-google-client-id +GOOGLE_OAUTH_CLIENT_SECRET=your-google-client-secret +GOOGLE_OAUTH_REDIRECT_URI=http://localhost:8000/auth/google/callback + +# OAuth (Microsoft) +MICROSOFT_OAUTH_CLIENT_ID=your-microsoft-client-id +MICROSOFT_OAUTH_CLIENT_SECRET=your-microsoft-client-secret +MICROSOFT_OAUTH_REDIRECT_URI=http://localhost:8000/auth/microsoft/callback +MICROSOFT_OAUTH_TENANT=common + +# CORS (comma-separated list of origins, or * for all) +BACKEND_CORS_ORIGINS=["http://localhost:3000", "http://localhost:8000"] + +# Logging +LOG_LEVEL=INFO +SQL_ECHO=False + +# Environment (development, staging, production) +ENVIRONMENT=development diff --git a/backend/DB_SETUP.md b/backend/DB_SETUP.md new file mode 100644 index 0000000000..c4ccd248f9 --- /dev/null +++ b/backend/DB_SETUP.md @@ -0,0 +1,118 @@ +# Database Setup and Migrations + +This document explains how to set up the database and manage migrations for the Copilot backend. + +## Prerequisites + +- Python 3.8+ +- PostgreSQL 13+ +- pip +- virtualenv (recommended) + +## Setup + +1. **Create a virtual environment** (recommended): + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + ``` + +2. **Install dependencies**: + ```bash + pip install -e ".[dev]" + ``` + +3. **Set up environment variables**: + Copy `.env.example` to `.env` and update the values: + ```bash + cp .env.example .env + ``` + +## Database Configuration + +Update the following environment variables in your `.env` file: + +``` +DATABASE_URL=postgresql://postgres:postgres@localhost:5432/copilot +ASYNC_DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/copilot +``` + +## Running Migrations + +### Create a new migration + +To create a new migration after making changes to your models: + +```bash +python -m scripts.migrate create --message "your migration message" +``` + +### Apply migrations + +To apply all pending migrations: + +```bash +python -m scripts.migrate upgrade head +``` + +### Rollback a migration + +To rollback to a previous migration: + +```bash +python -m scripts.migrate downgrade +``` + +### Show current migration + +To show the current migration: + +```bash +python -m scripts.migrate current +``` + +### Show migration history + +To show the migration history: + +```bash +python -m scripts.migrate history +``` + +## Initial Setup + +To set up the database and run all migrations: + +```bash +./scripts/setup_db.sh +``` + +This will: +1. 
Check if the database is accessible +2. Run all pending migrations +3. Create an initial admin user if it doesn't exist + +## Database Models + +The database models are defined in `app/models/`: + +- `base.py`: Base model classes and mixins +- `user.py`: User-related models and schemas + +## Common Issues + +### Database Connection Issues + +If you encounter connection issues: + +1. Ensure PostgreSQL is running +2. Check that the database exists and the user has the correct permissions +3. Verify the connection string in your `.env` file + +### Migration Issues + +If you encounter issues with migrations: + +1. Make sure all models are properly imported in `app/models/__init__.py` +2. Check for any syntax errors in your models +3. If needed, you can delete the migration files in `app/alembic/versions/` and create a new initial migration diff --git a/backend/alembic.ini b/backend/alembic.ini index 24841c2bfb..b4a02f15bc 100755 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -5,26 +5,26 @@ script_location = app/alembic # template used to generate migration files -# file_template = %%(rev)s_%%(slug)s +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d_%%(slug)s # timezone to use when rendering the date # within the migration file as well as the filename. # string value is passed to dateutil.tz.gettz() # leave blank for localtime -# timezone = +timezone = UTC # max length of characters to apply to the # "slug" field -#truncate_slug_length = 40 +truncate_slug_length = 40 # set to 'true' to run the environment during # the 'revision' command, regardless of autogenerate -# revision_environment = false +revision_environment = false # set to 'true' to allow .pyc and .pyo files without # a source .py file to be detected as revisions in the # versions/ directory -# sourceless = false +sourceless = false # version location specification; this defaults # to alembic/versions. 
When using multiple version @@ -33,7 +33,10 @@ script_location = app/alembic # the output encoding used when revision files # are written from script.py.mako -# output_encoding = utf-8 +output_encoding = utf-8 + +# Database connection string (overridden by env.py) +sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/copilot # Logging configuration [loggers] @@ -46,18 +49,18 @@ keys = console keys = generic [logger_root] -level = WARN +level = INFO handlers = console qualname = [logger_sqlalchemy] -level = WARN -handlers = +level = WARNING +handlers = qualname = sqlalchemy.engine [logger_alembic] level = INFO -handlers = +handlers = qualname = alembic [handler_console] diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index c2b83c841d..273f2b8ba4 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -8,19 +8,21 @@ from pydantic import ValidationError from sqlmodel import Session -from app.core import security -from app.core.config import settings -from app.core.db import engine +from app.core import security, SessionLocal, settings +from app.core.config import settings as app_settings from app.models import TokenPayload, User reusable_oauth2 = OAuth2PasswordBearer( - tokenUrl=f"{settings.API_V1_STR}/login/access-token" + tokenUrl=f"{app_settings.API_V1_STR}/login/access-token" ) def get_db() -> Generator[Session, None, None]: - with Session(engine) as session: - yield session + db = SessionLocal() + try: + yield db + finally: + db.close() SessionDep = Annotated[Session, Depends(get_db)] diff --git a/backend/app/api/main.py b/backend/app/api/main.py index eac18c8e8f..0a3f572515 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -1,14 +1,20 @@ from fastapi import APIRouter from app.api.routes import items, login, private, users, utils +from app.api.routes.auth.router import router as auth_router from app.core.config import settings api_router = APIRouter() -api_router.include_router(login.router) -api_router.include_router(users.router) -api_router.include_router(utils.router) -api_router.include_router(items.router) +# Include auth routes - already prefixed with /api/v1 in main.py +api_router.include_router(auth_router) +# Include other routes - they will be under /api/v1 prefix added in main.py +api_router.include_router(login.router, prefix="/login", tags=["login"]) +api_router.include_router(users.router, prefix="/users", tags=["users"]) +api_router.include_router(utils.router, prefix="/utils", tags=["utils"]) +api_router.include_router(items.router, prefix="/items", tags=["items"]) + +# Include private routes in local environment if settings.ENVIRONMENT == "local": - api_router.include_router(private.router) + api_router.include_router(private.router, prefix="/private", tags=["private"]) diff --git a/backend/app/api/routes/auth/__init__.py b/backend/app/api/routes/auth/__init__.py new file mode 100644 index 0000000000..8cea025474 --- /dev/null +++ b/backend/app/api/routes/auth/__init__.py @@ -0,0 +1 @@ +# This file makes the auth directory a Python package diff --git a/backend/app/api/routes/auth/auth.py b/backend/app/api/routes/auth/auth.py new file mode 100644 index 0000000000..31ffdde67c --- /dev/null +++ b/backend/app/api/routes/auth/auth.py @@ -0,0 +1,139 @@ +from datetime import timedelta +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordRequestForm +from sqlmodel import Session + +from app.core import security +from app.core.config import 
settings +from app.db import get_db +from app.models import User +from app.schemas.auth import Token, UserLogin, UserRegister, UserOut, PasswordResetRequest, PasswordResetConfirm +from app.crud import create_user, get_user_by_email, update_user + +router = APIRouter() + +@router.post("/signup", response_model=UserOut, status_code=status.HTTP_201_CREATED) +def signup( + *, + db: Session = Depends(get_db), + user_in: UserRegister, +) -> Any: + """ + Create new user with email and password. + """ + # Check if user with this email already exists + db_user = get_user_by_email(db, email=user_in.email) + if db_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="The user with this email already exists in the system.", + ) + + # Create new user + user = create_user(db, user_in) + + # TODO: Send verification email + + return user + +@router.post("/login", response_model=Token) +def login( + db: Session = Depends(get_db), + form_data: OAuth2PasswordRequestForm = Depends(), +) -> Any: + """ + OAuth2 compatible token login, get an access token for future requests. + """ + # Authenticate user + user = security.authenticate( + db, email=form_data.username, password=form_data.password + ) + if not user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Incorrect email or password", + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Inactive user", + ) + + # Generate tokens + tokens = security.generate_token_response(str(user.id)) + + # Store refresh token in database + # TODO: Implement refresh token storage + + return tokens + +@router.post("/refresh", response_model=Token) +def refresh_token( + refresh_token: str, + db: Session = Depends(get_db), +) -> Any: + """ + Refresh access token using a valid refresh token. + """ + # Verify refresh token + try: + token_data = security.verify_refresh_token(refresh_token) + except HTTPException as e: + raise e + + # Get user + user = get_user(db, user_id=token_data["sub"]) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found", + ) + + # Generate new tokens + tokens = security.generate_token_response(str(user.id)) + + # TODO: Update refresh token in database + + return tokens + +@router.post("/forgot-password", status_code=status.HTTP_202_ACCEPTED) +def forgot_password( + password_reset: PasswordResetRequest, + db: Session = Depends(get_db), +) -> Any: + """ + Request password reset. + """ + user = get_user_by_email(db, email=password_reset.email) + if not user: + # Don't reveal that the user doesn't exist + return {"message": "If your email is registered, you will receive a password reset link."} + + # TODO: Generate password reset token and send email + + return {"message": "If your email is registered, you will receive a password reset link."} + +@router.post("/reset-password", status_code=status.HTTP_200_OK) +def reset_password( + reset_data: PasswordResetConfirm, + db: Session = Depends(get_db), +) -> Any: + """ + Reset password using a valid token. + """ + # TODO: Verify reset token + # TODO: Update user password + + return {"message": "Password updated successfully"} + +@router.get("/me", response_model=UserOut) +def read_users_me( + current_user: User = Depends(security.get_current_user), +) -> Any: + """ + Get current user. 
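    A minimal client-side sketch of the flow above (assumes the API is served
    locally on port 8000 as described in CLAUDE.md, that `httpx` is installed,
    and that the example credentials exist; all literal values are illustrative):

        import httpx

        # Obtain a token pair from the form-encoded /login endpoint.
        login = httpx.post(
            "http://localhost:8000/api/v1/auth/login",
            data={"username": "admin@example.com", "password": "changeme"},
        )
        access_token = login.json()["access_token"]

        # Call the protected /me endpoint with the bearer access token.
        me = httpx.get(
            "http://localhost:8000/api/v1/auth/me",
            headers={"Authorization": f"Bearer {access_token}"},
        )
        print(me.json())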
+ """ + return current_user diff --git a/backend/app/api/routes/auth/oauth.py b/backend/app/api/routes/auth/oauth.py new file mode 100644 index 0000000000..a8ec4c01b7 --- /dev/null +++ b/backend/app/api/routes/auth/oauth.py @@ -0,0 +1,175 @@ +from typing import Any, Dict, Optional + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import RedirectResponse +from pydantic import AnyHttpUrl +from sqlmodel import Session + +from app.core import security +from app.core.config import settings +from app.db import get_db +from app.models import User +from app.schemas.auth import OAuthProvider, OAuthTokenRequest, SSOProvider, SSORequest + +router = APIRouter(prefix="/oauth") + +# OAuth providers configuration +OAUTH_PROVIDERS = { + OAuthProvider.GOOGLE: { + "authorization_url": "https://accounts.google.com/o/oauth2/v2/auth", + "token_url": "https://oauth2.googleapis.com/token", + "userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo", + "scopes": ["openid", "email", "profile"], + }, + OAuthProvider.MICROSOFT: { + "authorization_url": "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + "token_url": "https://login.microsoftonline.com/common/oauth2/v2.0/token", + "userinfo_url": "https://graph.microsoft.com/oidc/userinfo", + "scopes": ["openid", "email", "profile"], + }, +} + +def get_oauth_redirect_url(provider: OAuthProvider) -> str: + """Generate OAuth redirect URL for the given provider.""" + if provider not in OAUTH_PROVIDERS: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported OAuth provider: {provider}", + ) + + provider_config = OAUTH_PROVIDERS[provider] + redirect_uri = f"{settings.SERVER_HOST}{settings.API_V1_STR}/auth/oauth/{provider}/callback" + + # TODO: Generate state parameter for CSRF protection + state = "" + + # Build authorization URL + from urllib.parse import urlencode + params = { + "client_id": getattr(settings, f"{provider.upper()}_CLIENT_ID"), + "response_type": "code", + "redirect_uri": redirect_uri, + "scope": " ".join(provider_config["scopes"]), + "state": state, + "access_type": "offline", + "prompt": "consent", + } + + return f"{provider_config['authorization_url']}?{urlencode(params)}" + +@router.get("/{provider}") +async def oauth_login( + provider: OAuthProvider, + redirect_uri: str, +): + """ + Initiate OAuth login flow. + """ + # Store redirect_uri in session or state parameter + # For now, we'll pass it as a query parameter to the OAuth provider + try: + auth_url = get_oauth_redirect_url(provider) + return {"authorization_url": f"{auth_url}&redirect_uri={redirect_uri}"} + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Error initiating OAuth flow", + ) from e + +@router.get("/{provider}/callback") +async def oauth_callback( + provider: OAuthProvider, + code: str, + state: Optional[str] = None, + error: Optional[str] = None, + db: Session = Depends(get_db), +): + """ + OAuth callback endpoint. 
+ """ + if error: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"OAuth error: {error}", + ) + + if provider not in OAUTH_PROVIDERS: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported OAuth provider: {provider}", + ) + + # TODO: Verify state parameter + + try: + # Exchange authorization code for access token + provider_config = OAUTH_PROVIDERS[provider] + token_data = await exchange_code_for_token(provider, code, provider_config) + + # Get user info from provider + user_info = await get_user_info(provider, token_data["access_token"], provider_config) + + # Find or create user + user = await get_or_create_user_from_oauth(db, provider, user_info) + + # Generate JWT tokens + tokens = security.generate_token_response(str(user.id)) + + # TODO: Store refresh token in database + + # Redirect to frontend with tokens + # In production, use secure, http-only cookies + redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={tokens['access_token']}&refresh_token={tokens['refresh_token']}" + return RedirectResponse(url=redirect_url) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Error during OAuth callback: {str(e)}", + ) + +async def exchange_code_for_token(provider: OAuthProvider, code: str, provider_config: Dict) -> Dict: + """Exchange authorization code for access token.""" + # TODO: Implement token exchange + # This will make a POST request to the provider's token endpoint + # with the authorization code and client credentials + return {} + +async def get_user_info(provider: OAuthProvider, access_token: str, provider_config: Dict) -> Dict: + """Get user info from OAuth provider.""" + # TODO: Implement user info retrieval + # This will make a GET request to the provider's userinfo endpoint + # with the access token + return {} + +async def get_or_create_user_from_oauth(db: Session, provider: OAuthProvider, user_info: Dict) -> User: + """Find or create a user from OAuth user info.""" + # TODO: Implement user lookup/creation + # This will find an existing user by email or create a new one + # and update the SSO provider information + return User() # Placeholder + +# SSO Endpoints +@router.post("/sso/{provider}") +async def sso_login( + provider: SSOProvider, + sso_request: SSORequest, +): + """ + Initiate SSO login flow for SAML or OIDC. 
+ """ + if provider == SSOProvider.SAML: + # TODO: Implement SAML SSO + pass + elif provider == SSOProvider.OIDC: + # TODO: Implement OIDC SSO + pass + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported SSO provider: {provider}", + ) + + return {"message": "SSO flow initiated"} diff --git a/backend/app/api/routes/auth/router.py b/backend/app/api/routes/auth/router.py new file mode 100644 index 0000000000..afef34ec45 --- /dev/null +++ b/backend/app/api/routes/auth/router.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +from app.api.routes.auth import auth, oauth + +router = APIRouter() + +# Include auth routes +router.include_router(auth.router, prefix="/auth", tags=["auth"]) + +# Include OAuth/SSO routes +router.include_router(oauth.router, tags=["oauth"]) diff --git a/backend/app/api/routes/items.py b/backend/app/api/routes/items.py index 177dc1e476..dbc81a739c 100644 --- a/backend/app/api/routes/items.py +++ b/backend/app/api/routes/items.py @@ -7,7 +7,7 @@ from app.api.deps import CurrentUser, SessionDep from app.models import Item, ItemCreate, ItemPublic, ItemsPublic, ItemUpdate, Message -router = APIRouter(prefix="/items", tags=["items"]) +router = APIRouter(tags=["items"]) @router.get("/", response_model=ItemsPublic) diff --git a/backend/app/api/routes/login.py b/backend/app/api/routes/login.py index 980c66f86f..a93f8a526d 100644 --- a/backend/app/api/routes/login.py +++ b/backend/app/api/routes/login.py @@ -10,7 +10,7 @@ from app.core import security from app.core.config import settings from app.core.security import get_password_hash -from app.models import Message, NewPassword, Token, UserPublic +from app.models import Message, NewPassword, TokenPair as Token, UserPublic from app.utils import ( generate_password_reset_token, generate_reset_password_email, @@ -21,7 +21,7 @@ router = APIRouter(tags=["login"]) -@router.post("/login/access-token") +@router.post("/access-token") def login_access_token( session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()] ) -> Token: @@ -43,7 +43,7 @@ def login_access_token( ) -@router.post("/login/test-token", response_model=UserPublic) +@router.post("/test-token", response_model=UserPublic) def test_token(current_user: CurrentUser) -> Any: """ Test access token diff --git a/backend/app/api/routes/private.py b/backend/app/api/routes/private.py index 9f33ef1900..e820977b69 100644 --- a/backend/app/api/routes/private.py +++ b/backend/app/api/routes/private.py @@ -10,7 +10,7 @@ UserPublic, ) -router = APIRouter(tags=["private"], prefix="/private") +router = APIRouter(tags=["private"]) class PrivateUserCreate(BaseModel): diff --git a/backend/app/api/routes/users.py b/backend/app/api/routes/users.py index 6429818458..0867849163 100644 --- a/backend/app/api/routes/users.py +++ b/backend/app/api/routes/users.py @@ -26,7 +26,7 @@ ) from app.utils import generate_new_account_email, send_email -router = APIRouter(prefix="/users", tags=["users"]) +router = APIRouter(tags=["users"]) @router.get( diff --git a/backend/app/api/routes/utils.py b/backend/app/api/routes/utils.py index fc093419b3..f711f9f994 100644 --- a/backend/app/api/routes/utils.py +++ b/backend/app/api/routes/utils.py @@ -5,7 +5,7 @@ from app.models import Message from app.utils import generate_test_email, send_email -router = APIRouter(prefix="/utils", tags=["utils"]) +router = APIRouter(tags=["utils"]) @router.post( diff --git a/backend/app/core/__init__.py b/backend/app/core/__init__.py index e69de29bb2..341c5ee459 100644 
--- a/backend/app/core/__init__.py +++ b/backend/app/core/__init__.py @@ -0,0 +1,136 @@ +""" +Core functionality for the Copilot backend application. + +This package contains core components and utilities used throughout the application, +including configuration, logging, database connections, security utilities, and helpers. +""" + +# Import core modules to make them available when importing from app.core +from .config import settings # noqa: F401 +from .logging import logger, setup_logging # noqa: F401 + +# Database +from .db import ( # noqa: F401 + get_db, + get_async_db, + get_sync_session, + get_async_session, + init_db, + async_init_db, + sync_engine, + async_engine, + SessionLocal, + AsyncSessionLocal, +) + +# Security +from .security import ( # noqa: F401 + TokenType, + create_token, + create_access_token, + create_refresh_token, + create_password_reset_token, + create_email_verification_token, + verify_token, + verify_password, + get_password_hash, + generate_password, + check_password_strength, + get_current_user, + get_current_active_user, + get_current_active_superuser, + get_token_from_request, + get_current_user_optional, + generate_token_response, + verify_refresh_token, +) + +# Utils +from .utils import ( # noqa: F401 + generate_uuid, + generate_random_string, + generate_random_number, + get_timestamp, + get_datetime, + format_datetime, + parse_datetime, + is_valid_email, + is_valid_url, + hash_password, + generate_jwt_token, + decode_jwt_token, + encrypt_data, + decrypt_data, + to_camel_case, + to_snake_case, + dict_to_camel_case, + dict_to_snake_case, + get_client_ip, + get_user_agent, + get_domain_from_email, + mask_email, + mask_phone, + paginate, +) + +# Define what's available when importing from app.core +__all__ = [ + # Configuration and logging + 'settings', + 'logger', + 'setup_logging', + + # Database + 'get_db', + 'get_async_db', + 'get_sync_session', + 'get_async_session', + 'init_db', + 'async_init_db', + + # Security + 'TokenType', + 'create_token', + 'create_access_token', + 'create_refresh_token', + 'create_password_reset_token', + 'create_email_verification_token', + 'verify_token', + 'verify_password', + 'get_password_hash', + 'generate_password', + 'check_password_strength', + 'get_current_user', + 'get_current_active_user', + 'get_current_active_superuser', + 'get_token_from_request', + 'get_current_user_optional', + 'generate_token_response', + 'verify_refresh_token', + + # Utils + 'generate_uuid', + 'generate_random_string', + 'generate_random_number', + 'get_timestamp', + 'get_datetime', + 'format_datetime', + 'parse_datetime', + 'is_valid_email', + 'is_valid_url', + 'hash_password', + 'generate_jwt_token', + 'decode_jwt_token', + 'encrypt_data', + 'decrypt_data', + 'to_camel_case', + 'to_snake_case', + 'dict_to_camel_case', + 'dict_to_snake_case', + 'get_client_ip', + 'get_user_agent', + 'get_domain_from_email', + 'mask_email', + 'mask_phone', + 'paginate', +] \ No newline at end of file diff --git a/backend/app/core/config.py b/backend/app/core/config.py index d58e03c87d..0bd4fb40aa 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -1,6 +1,13 @@ +""" +Application configuration settings. + +This module defines the application configuration using Pydantic settings management. +It loads environment variables from a .env file and provides type-safe access to them. 
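Example (illustrative; actual values are resolved from environment variables
and the project-level .env file):

    from app.core.config import settings

    settings.PROJECT_NAME                   # "Copilot API" unless overridden
    str(settings.SQLALCHEMY_DATABASE_URI)   # "postgresql+psycopg://postgres:...@localhost:5432/copilot"
    settings.all_cors_origins               # de-duplicated list of allowed origins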
+""" import secrets import warnings -from typing import Annotated, Any, Literal +from pathlib import Path +from typing import Annotated, Any, Literal, Optional, Union from pydantic import ( AnyUrl, @@ -9,12 +16,152 @@ HttpUrl, PostgresDsn, computed_field, + field_validator, model_validator, + RedisDsn, ) from pydantic_core import MultiHostUrl from pydantic_settings import BaseSettings, SettingsConfigDict from typing_extensions import Self +# Project root directory +PROJECT_ROOT = Path(__file__).parent.parent.parent + + +def parse_cors(v: Any) -> list[str] | str: + """Parse CORS origins from a comma-separated string or list.""" + if isinstance(v, str) and not v.startswith("["): + return [i.strip() for i in v.split(",")] + elif isinstance(v, (list, str)): + return v + raise ValueError(v) + + +class DatabaseSettings(BaseSettings): + """Database configuration settings.""" + model_config = SettingsConfigDict(env_prefix="DATABASE_") + + POSTGRES_SERVER: str = "localhost" + POSTGRES_USER: str = "postgres" + POSTGRES_PASSWORD: str = "postgres" + POSTGRES_DB: str = "copilot" + POSTGRES_PORT: int = 5432 + SQL_ECHO: bool = False + + @computed_field + @property + def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: + """Generate the PostgreSQL database URI.""" + return MultiHostUrl.build( + scheme="postgresql+psycopg", + username=self.POSTGRES_USER, + password=self.POSTGRES_PASSWORD, + host=self.POSTGRES_SERVER, + port=self.POSTGRES_PORT, + path=f"/{self.POSTGRES_DB}", + ) + + @computed_field + @property + def ASYNC_SQLALCHEMY_DATABASE_URI(self) -> str: + """Generate the async PostgreSQL database URI.""" + return str( + MultiHostUrl.build( + scheme="postgresql+asyncpg", + username=self.POSTGRES_USER, + password=self.POSTGRES_PASSWORD, + host=self.POSTGRES_SERVER, + port=self.POSTGRES_PORT, + path=f"/{self.POSTGRES_DB}", + ) + ) + + +class AuthSettings(BaseSettings): + """Authentication and authorization settings.""" + model_config = SettingsConfigDict(env_prefix="AUTH_") + + SECRET_KEY: str = secrets.token_urlsafe(32) + ALGORITHM: str = "HS256" + ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7 # 7 days + REFRESH_TOKEN_EXPIRE_DAYS: int = 30 + PASSWORD_RESET_TOKEN_EXPIRE_HOURS: int = 24 + + # First superuser + FIRST_SUPERUSER: EmailStr = "admin@example.com" + FIRST_SUPERUSER_PASSWORD: str = "changeme" + + # OAuth2 + GOOGLE_OAUTH_CLIENT_ID: Optional[str] = None + GOOGLE_OAUTH_CLIENT_SECRET: Optional[str] = None + GOOGLE_OAUTH_REDIRECT_URI: Optional[HttpUrl] = None + + MICROSOFT_OAUTH_CLIENT_ID: Optional[str] = None + MICROSOFT_OAUTH_CLIENT_SECRET: Optional[str] = None + MICROSOFT_OAUTH_REDIRECT_URI: Optional[HttpUrl] = None + MICROSOFT_OAUTH_TENANT: str = "common" + + # Session + SESSION_SECRET_KEY: str = secrets.token_urlsafe(32) + SESSION_COOKIE_NAME: str = "session" + SESSION_COOKIE_HTTPONLY: bool = True + SESSION_COOKIE_SECURE: bool = False # Set to True in production with HTTPS + SESSION_COOKIE_SAMESITE: str = "lax" + SESSION_COOKIE_DOMAIN: Optional[str] = None + + # CORS + BACKEND_CORS_ORIGINS: list[HttpUrl] = [ + HttpUrl("http://localhost:3000"), + HttpUrl("http://localhost:8000"), + ] + + @property + def all_cors_origins(self) -> list[str]: + """Get all allowed CORS origins.""" + return [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + + +class EmailSettings(BaseSettings): + """Email configuration settings.""" + model_config = SettingsConfigDict(env_prefix="EMAIL_") + + SMTP_TLS: bool = True + SMTP_PORT: int = 587 + SMTP_HOST: Optional[str] = None + SMTP_USER: Optional[str] = None + 
SMTP_PASSWORD: Optional[str] = None + EMAILS_FROM_EMAIL: Optional[EmailStr] = None + EMAILS_FROM_NAME: Optional[str] = None + + @property + def EMAILS_ENABLED(self) -> bool: + """Check if email sending is enabled.""" + return bool(self.SMTP_HOST and self.SMTP_USER and self.SMTP_PASSWORD) + + +class RedisSettings(BaseSettings): + """Redis configuration settings.""" + model_config = SettingsConfigDict(env_prefix="REDIS_") + + REDIS_HOST: str = "localhost" + REDIS_PORT: int = 6379 + REDIS_PASSWORD: Optional[str] = None + REDIS_DB: int = 0 + REDIS_SSL: bool = False + + @computed_field + @property + def REDIS_URL(self) -> RedisDsn: + """Generate the Redis URL.""" + return RedisDsn.build( + scheme="redis", + host=self.REDIS_HOST, + port=self.REDIS_PORT, + username=None, + password=self.REDIS_PASSWORD, + path=f"/{self.REDIS_DB}", + ) + def parse_cors(v: Any) -> list[str] | str: if isinstance(v, str) and not v.startswith("["): @@ -24,39 +171,130 @@ def parse_cors(v: Any) -> list[str] | str: raise ValueError(v) -class Settings(BaseSettings): +class Settings(DatabaseSettings, AuthSettings, EmailSettings, RedisSettings): + """Application settings.""" model_config = SettingsConfigDict( - # Use top level .env file (one level above ./backend/) - env_file="../.env", - env_ignore_empty=True, + env_file=PROJECT_ROOT / ".env", + env_file_encoding="utf-8", + env_nested_delimiter="__", extra="ignore", + case_sensitive=True, ) + + # Application + PROJECT_NAME: str = "Copilot API" API_V1_STR: str = "/api/v1" - SECRET_KEY: str = secrets.token_urlsafe(32) - # 60 minutes * 24 hours * 8 days = 8 days - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 - FRONTEND_HOST: str = "http://localhost:5173" ENVIRONMENT: Literal["local", "staging", "production"] = "local" - - BACKEND_CORS_ORIGINS: Annotated[ - list[AnyUrl] | str, BeforeValidator(parse_cors) - ] = [] - - @computed_field # type: ignore[prop-decorator] + + # Allow development as an alias for local + @model_validator(mode='before') + @classmethod + def validate_environment(cls, data: Any) -> Any: + if isinstance(data, dict) and data.get('ENVIRONMENT') == 'development': + data['ENVIRONMENT'] = 'local' + return data + + DEBUG: bool = False + + # Security + SECRET_KEY: str = secrets.token_urlsafe(32) + + # API + API_PREFIX: str = "/api" + PROJECT_VERSION: str = "1.0.0" + + # Logging + LOG_LEVEL: str = "INFO" + LOG_FORMAT: str = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}" + LOG_FILE: Optional[Path] = None + + # Frontend + FRONTEND_HOST: HttpUrl = "http://localhost:3000" + + # CORS + BACKEND_CORS_ORIGINS: list[AnyUrl] = [ + "http://localhost:3000", + "http://localhost:8000", + ] + + # Rate limiting + RATE_LIMIT: bool = True + RATE_LIMIT_PER_MINUTE: int = 100 + RATE_LIMIT_PER_HOUR: int = 1000 + + # Security headers + SECURITY_HEADERS: bool = True + + # Trusted hosts + ALLOWED_HOSTS: list[str] = ["*"] + + # Application URLs + @computed_field + @property + def BASE_URL(self) -> HttpUrl: + """Get the base URL of the application.""" + if self.ENVIRONMENT == "production": + return HttpUrl("https://api.copilot.example.com") + elif self.ENVIRONMENT == "staging": + return HttpUrl("https://staging-api.copilot.example.com") + return HttpUrl("http://localhost:8000") + + @computed_field + @property + def FRONTEND_URL(self) -> HttpUrl: + """Get the frontend URL.""" + if self.ENVIRONMENT == "production": + return HttpUrl("https://copilot.example.com") + elif self.ENVIRONMENT == "staging": + return HttpUrl("https://staging.copilot.example.com") 
+ return HttpUrl("http://localhost:3000") + + @computed_field @property def all_cors_origins(self) -> list[str]: - return [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + [ - self.FRONTEND_HOST - ] - + """Get all allowed CORS origins.""" + origins = [str(origin).rstrip("/") for origin in self.BACKEND_CORS_ORIGINS] + origins.append(str(self.FRONTEND_URL).rstrip("/")) + return list(set(origins)) # Remove duplicates + + @field_validator("ENVIRONMENT") + def set_debug(cls, v: str) -> str: + """Set DEBUG based on ENVIRONMENT.""" + import os + os.environ["DEBUG"] = str(v == "local").lower() + return v + + @field_validator("BACKEND_CORS_ORIGINS", mode="before") + def assemble_cors_origins(cls, v: Union[str, list[Union[str, HttpUrl]]]) -> list[HttpUrl]: + """Parse CORS origins from a comma-separated string or list.""" + if isinstance(v, str): + if v.startswith("["): + # Handle JSON array string + import json + v = json.loads(v) + else: + # Handle comma-separated string + v = [i.strip() for i in v.split(",")] + + # Convert all items to HttpUrl objects + result = [] + for item in v: + if isinstance(item, str): + result.append(HttpUrl(item)) + elif isinstance(item, HttpUrl): + result.append(item) + else: + raise ValueError(f"Invalid CORS origin: {item}") + return result + PROJECT_NAME: str SENTRY_DSN: HttpUrl | None = None - POSTGRES_SERVER: str + POSTGRES_SERVER: str = "localhost" + POSTGRES_USER: str = "postgres" + POSTGRES_PASSWORD: str = "postgres" + POSTGRES_DB: str = "copilot" POSTGRES_PORT: int = 5432 - POSTGRES_USER: str - POSTGRES_PASSWORD: str = "" - POSTGRES_DB: str = "" - + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: @@ -76,7 +314,7 @@ def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: SMTP_USER: str | None = None SMTP_PASSWORD: str | None = None EMAILS_FROM_EMAIL: EmailStr | None = None - EMAILS_FROM_NAME: EmailStr | None = None + EMAILS_FROM_NAME: str | None = None @model_validator(mode="after") def _set_default_emails_from(self) -> Self: @@ -117,4 +355,48 @@ def _enforce_non_default_secrets(self) -> Self: return self -settings = Settings() # type: ignore +# Initialize settings +settings = Settings() + +# Configure logging based on settings +logging_config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "json": { + "()": "pythonjsonlogger.jsonlogger.JsonFormatter", + "format": "%(asctime)s %(levelname)s %(name)s %(message)s", + }, + "console": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "console", + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "": {"handlers": ["console"], "level": settings.LOG_LEVEL}, + "uvicorn": {"level": "INFO"}, + "uvicorn.error": {"level": "INFO"}, + "uvicorn.access": {"level": "INFO", "propagate": False}, + "sqlalchemy.engine": {"level": "WARNING"}, + "sqlalchemy.pool": {"level": "WARNING"}, + }, +} + +# Apply logging configuration +import logging.config +logging.config.dictConfig(logging_config) + +# Set log level for all loggers +for logger_name in logging.root.manager.loggerDict: + if logger_name in logging_config["loggers"]: + logging.getLogger(logger_name).setLevel( + logging_config["loggers"][logger_name].get("level", settings.LOG_LEVEL) + ) + else: + logging.getLogger(logger_name).setLevel(settings.LOG_LEVEL) diff --git a/backend/app/core/database.py b/backend/app/core/database.py new file mode 100644 index 
0000000000..cdb427e02a --- /dev/null +++ b/backend/app/core/database.py @@ -0,0 +1,54 @@ +from typing import AsyncGenerator, Generator + +from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker + +from app.core.config import settings + +# Create sync engine and session factory +engine = create_engine( + str(settings.SQLALCHEMY_DATABASE_URI), + pool_pre_ping=True, + echo=settings.SQL_ECHO, +) +SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, +) + +# Create async engine and session factory +async_engine = create_async_engine( + str(settings.ASYNC_SQLALCHEMY_DATABASE_URI), + echo=settings.SQL_ECHO, + future=True, +) +AsyncSessionLocal = sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, +) + +# Dependency for sync database sessions +def get_db() -> Generator: + """Get a sync database session.""" + db = SessionLocal() + try: + yield db + finally: + db.close() + +# Dependency for async database sessions +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + """Get an async database session.""" + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() diff --git a/backend/app/core/db.py b/backend/app/core/db.py index ba991fb36d..332604316d 100644 --- a/backend/app/core/db.py +++ b/backend/app/core/db.py @@ -1,33 +1,229 @@ -from sqlmodel import Session, create_engine, select +""" +Database connection and session management. + +This module provides utilities for managing database connections and sessions, +including both synchronous and asynchronous database operations. +""" +import logging +from contextlib import asynccontextmanager, contextmanager +from typing import AsyncGenerator, Generator, Optional + +from sqlalchemy import create_engine +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import Session as SyncSession, sessionmaker +from sqlmodel import SQLModel, select -from app import crud from app.core.config import settings +from app.core.logging import logger from app.models import User, UserCreate -engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) +# Configure logging +logger = logging.getLogger(__name__) + +# Create database engines +sync_engine = create_engine( + str(settings.SQLALCHEMY_DATABASE_URI), + pool_pre_ping=True, + pool_recycle=3600, + echo=settings.SQL_ECHO, +) + +async_engine = create_async_engine( + str(settings.ASYNC_SQLALCHEMY_DATABASE_URI), + pool_pre_ping=True, + pool_recycle=3600, + echo=settings.SQL_ECHO, +) + +# Session factories +SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=sync_engine, + class_=SyncSession, +) + +AsyncSessionLocal = async_sessionmaker( + autocommit=False, + autoflush=False, + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, +) + + +# Context managers for database sessions +@contextmanager +def get_db() -> Generator[SyncSession, None, None]: + """ + Context manager for synchronous database sessions. 
+ + Yields: + SyncSession: A synchronous database session + + Example: + with get_db() as db: + db.query(User).all() + """ + db = SessionLocal() + try: + yield db + db.commit() + except SQLAlchemyError as e: + db.rollback() + logger.error(f"Database error: {str(e)}") + raise + finally: + db.close() + + +@asynccontextmanager +async def get_async_db() -> AsyncGenerator[AsyncSession, None]: + """ + Async context manager for asynchronous database sessions. + + Yields: + AsyncSession: An asynchronous database session + + Example: + async with get_async_db() as db: + result = await db.execute(select(User)) + users = result.scalars().all() + """ + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + except SQLAlchemyError as e: + await session.rollback() + logger.error(f"Async database error: {str(e)}") + raise + finally: + await session.close() + + +def get_sync_session() -> SyncSession: + """ + Get a synchronous database session. + + Returns: + SyncSession: A synchronous database session + + Note: + Remember to close the session when done using session.close() + """ + return SessionLocal() + + +async def get_async_session() -> AsyncSession: + """ + Get an asynchronous database session. + + Returns: + AsyncSession: An asynchronous database session + + Note: + Remember to close the session when done using await session.close() + """ + return AsyncSessionLocal() -# make sure all SQLModel models are imported (app.models) before initializing DB -# otherwise, SQLModel might fail to initialize relationships properly -# for more details: https://github.com/fastapi/full-stack-fastapi-template/issues/28 +# Database initialization +def init_db() -> None: + """ + Initialize the database with default data. + + This function creates the database tables and adds an initial admin user + if it doesn't already exist. + """ + try: + # Create all tables + SQLModel.metadata.create_all(sync_engine) + logger.info("Database tables created successfully") + + # Create default admin user if it doesn't exist + with get_db() as session: + # Check if admin user already exists + stmt = select(User).where(User.email == settings.FIRST_SUPERUSER) + result = session.execute(stmt) + user = result.scalar_one_or_none() + + if not user: + # Create admin user + admin_user = User( + email=settings.FIRST_SUPERUSER, + hashed_password=get_password_hash(settings.FIRST_SUPERUSER_PASSWORD), + is_active=True, + is_superuser=True, + is_verified=True, + ) + session.add(admin_user) + session.commit() + logger.info(f"Created admin user: {settings.FIRST_SUPERUSER}") + else: + logger.info(f"Admin user already exists: {settings.FIRST_SUPERUSER}") + + except Exception as e: + logger.error(f"Error initializing database: {str(e)}") + raise -def init_db(session: Session) -> None: - # Tables should be created with Alembic migrations - # But if you don't want to use migrations, create - # the tables un-commenting the next lines - # from sqlmodel import SQLModel +async def async_init_db() -> None: + """ + Asynchronously initialize the database with default data. + + This function creates the database tables and adds an initial admin user + if it doesn't already exist. 
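    Example (a minimal sketch of calling this at application startup from a
    FastAPI lifespan handler; the handler itself is illustrative):

        from contextlib import asynccontextmanager
        from fastapi import FastAPI

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # Create tables and the first superuser if they are missing.
            await async_init_db()
            yield

        app = FastAPI(lifespan=lifespan)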
+ """ + try: + # Create all tables + async with async_engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + logger.info("Database tables created successfully") + + # Create default admin user if it doesn't exist + async with get_async_db() as session: + # Check if admin user already exists + stmt = select(User).where(User.email == settings.FIRST_SUPERUSER) + result = await session.execute(stmt) + user = result.scalar_one_or_none() + + if not user: + # Create admin user + admin_user = User( + email=settings.FIRST_SUPERUSER, + hashed_password=get_password_hash(settings.FIRST_SUPERUSER_PASSWORD), + is_active=True, + is_superuser=True, + is_verified=True, + ) + session.add(admin_user) + await session.commit() + logger.info(f"Created admin user: {settings.FIRST_SUPERUSER}") + else: + logger.info(f"Admin user already exists: {settings.FIRST_SUPERUSER}") + + except Exception as e: + logger.error(f"Error initializing database: {str(e)}") + raise - # This works because the models are already imported and registered from app.models - # SQLModel.metadata.create_all(engine) - user = session.exec( - select(User).where(User.email == settings.FIRST_SUPERUSER) - ).first() - if not user: - user_in = UserCreate( - email=settings.FIRST_SUPERUSER, - password=settings.FIRST_SUPERUSER_PASSWORD, - is_superuser=True, - ) - user = crud.create_user(session=session, user_create=user_in) +def get_password_hash(password: str) -> str: + """ + Generate a password hash. + + Args: + password: The plain text password + + Returns: + str: The hashed password + """ + from passlib.context import CryptContext + + pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + return pwd_context.hash(password) diff --git a/backend/app/core/logging.py b/backend/app/core/logging.py new file mode 100644 index 0000000000..c8c790d210 --- /dev/null +++ b/backend/app/core/logging.py @@ -0,0 +1,126 @@ +""" +Logging configuration for the application. +""" +import logging +import logging.config +import sys +from pathlib import Path +from typing import Any, Dict, Optional + +from loguru import logger +from pydantic import Field +from pydantic_settings import BaseSettings + + +class LoggingSettings(BaseSettings): + """Logging settings.""" + LOG_LEVEL: str = Field("INFO", env="LOG_LEVEL") + LOG_FORMAT: str = "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}" + LOG_FILE: Optional[Path] = Field(None, env="LOG_FILE") + LOG_ROTATION: str = "10 MB" + LOG_RETENTION: str = "30 days" + LOG_COMPRESSION: str = "zip" + + +def setup_logging( + log_level: Optional[str] = None, + log_file: Optional[Path] = None, + log_format: Optional[str] = None, + log_rotation: Optional[str] = None, + log_retention: Optional[str] = None, + log_compression: Optional[str] = None, +) -> logger: + """ + Configure logging for the application. 
+ + Args: + log_level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + log_file: Path to the log file + log_format: Log message format + log_rotation: Log rotation configuration + log_retention: Log retention configuration + log_compression: Log compression configuration + + Returns: + logger: Configured logger instance + """ + # Get logging settings + settings = LoggingSettings() + + # Apply overrides + log_level = log_level or settings.LOG_LEVEL + log_file = log_file or settings.LOG_FILE + log_format = log_format or settings.LOG_FORMAT + log_rotation = log_rotation or settings.LOG_ROTATION + log_retention = log_retention or settings.LOG_RETENTION + log_compression = log_compression or settings.LOG_COMPRESSION + + # Configure loguru logger + logger.remove() # Remove default handler + + # Add console handler + logger.add( + sys.stderr, + level=log_level, + format=log_format, + colorize=True, + backtrace=True, + diagnose=True, + ) + + # Add file handler if log file is specified + if log_file: + log_file.parent.mkdir(parents=True, exist_ok=True) + logger.add( + str(log_file), + level=log_level, + format=log_format, + rotation=log_rotation, + retention=log_retention, + compression=log_compression, + backtrace=True, + diagnose=True, + ) + + # Configure standard library logging to use loguru + class InterceptHandler(logging.Handler): + def emit(self, record): + # Get corresponding Loguru level if it exists + try: + level = logger.level(record.levelname).name + except ValueError: + level = record.levelno + + # Find caller from where originated the logged message + frame, depth = sys._getframe(6), 6 + while frame and frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + logger.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + # Set up logging to use loguru + logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True) + + # Disable noisy loggers + for name in [ + "asyncio", + "uvicorn", + "uvicorn.error", + "uvicorn.access", + "fastapi", + "sqlalchemy.engine", + ]: + logging.getLogger(name).handlers = [InterceptHandler()] + logging.getLogger(name).propagate = False + + # Set log level for SQLAlchemy + logging.getLogger("sqlalchemy.engine").setLevel("WARNING") + + return logger + + +# Create a default logger instance +logger = setup_logging() diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 7aff7cfb32..7262040dab 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -1,27 +1,627 @@ +""" +Security utilities for authentication and authorization. + +This module provides functions for password hashing, JWT token generation and verification, +and user authentication utilities. 
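Example (illustrative round trip using the helpers defined below; the
subject string is an arbitrary placeholder):

    from app.core.security import TokenType, create_access_token, verify_token

    token = create_access_token(subject="user-id-123", scopes=["me"])
    payload = verify_token(token, token_type=TokenType.ACCESS)
    assert payload["sub"] == "user-id-123"
    assert payload["scopes"] == ["me"]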
+""" +import json +import logging +import secrets from datetime import datetime, timedelta, timezone -from typing import Any +from typing import Any, Optional, Union, Dict, List import jwt +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import ( + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, + SecurityScopes, +) +from jose import JWTError, jwt as jose_jwt from passlib.context import CryptContext +from pydantic import ValidationError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import Session, select from app.core.config import settings +from app.core.logging import logger +from app.db import get_db, get_async_db +from app.models import TokenPayload, User, UserRole + +# Configure logging +logger = logging.getLogger(__name__) +# Password hashing context pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +# OAuth2 scheme for token authentication +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl=f"{settings.API_V1_STR}/auth/login", + scopes={ + "me": "Read information about the current user.", + "users:read": "Read user information.", + "users:write": "Create and update users.", + "users:delete": "Delete users.", + "admin": "Admin access.", + }, + auto_error=False, +) + +# JWT token types +class TokenType: + ACCESS = "access" + REFRESH = "refresh" + RESET_PASSWORD = "reset_password" + VERIFY_EMAIL = "verify_email" -ALGORITHM = "HS256" +def create_token( + subject: Union[str, Any], + token_type: str, + expires_delta: Optional[timedelta] = None, + data: Optional[Dict[str, Any]] = None, +) -> str: + """ + Create a JWT token. + Args: + subject: The subject of the token (usually user ID) + token_type: Type of token (access, refresh, reset_password, verify_email) + expires_delta: Optional timedelta for token expiration + data: Additional data to include in the token -def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: - expire = datetime.now(timezone.utc) + expires_delta - to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + Returns: + str: Encoded JWT token + """ + now = datetime.now(timezone.utc) + + # Set default expiration based on token type + if expires_delta is None: + if token_type == TokenType.ACCESS: + expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + elif token_type == TokenType.REFRESH: + expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS) + elif token_type == TokenType.RESET_PASSWORD: + expires_delta = timedelta(hours=settings.PASSWORD_RESET_TOKEN_EXPIRE_HOURS) + else: # VERIFY_EMAIL and others + expires_delta = timedelta(days=1) + + expire = now + expires_delta + + # Prepare token data + to_encode = { + "exp": expire, + "iat": now, + "sub": str(subject), + "type": token_type, + "jti": secrets.token_urlsafe(16), # Unique token identifier + } + + # Add additional data if provided + if data: + to_encode.update(data) + + # Encode and return the token + encoded_jwt = jose_jwt.encode( + to_encode, + settings.SECRET_KEY, + algorithm=settings.ALGORITHM, + ) + return encoded_jwt +def create_access_token( + subject: Union[str, Any], + expires_delta: Optional[timedelta] = None, + scopes: Optional[List[str]] = None, + **data: Any, +) -> str: + """ + Create an access token. 
+ + Args: + subject: The subject of the token (usually user ID) + expires_delta: Optional timedelta for token expiration + scopes: List of scopes for the token + **data: Additional data to include in the token + + Returns: + str: Encoded JWT access token + """ + token_data = {} + if scopes: + token_data["scopes"] = scopes + if data: + token_data.update(data) + + return create_token( + subject=subject, + token_type=TokenType.ACCESS, + expires_delta=expires_delta, + data=token_data, + ) + +def create_refresh_token( + subject: Union[str, Any], + expires_delta: Optional[timedelta] = None, + **data: Any, +) -> str: + """ + Create a refresh token. + + Args: + subject: The subject of the token (usually user ID) + expires_delta: Optional timedelta for token expiration + **data: Additional data to include in the token + + Returns: + str: Encoded JWT refresh token + """ + return create_token( + subject=subject, + token_type=TokenType.REFRESH, + expires_delta=expires_delta, + data=data, + ) + + +def create_password_reset_token(email: str) -> str: + """ + Create a password reset token. + + Args: + email: User's email address + + Returns: + str: Encoded JWT password reset token + """ + return create_token( + subject=email, + token_type=TokenType.RESET_PASSWORD, + expires_delta=timedelta(hours=settings.PASSWORD_RESET_TOKEN_EXPIRE_HOURS), + ) + + +def create_email_verification_token(email: str) -> str: + """ + Create an email verification token. + + Args: + email: User's email address + + Returns: + str: Encoded JWT email verification token + """ + return create_token( + subject=email, + token_type=TokenType.VERIFY_EMAIL, + expires_delta=timedelta(days=7), # 7 days to verify email + ) + + +def verify_token( + token: str, + token_type: Optional[str] = None, + expected_subject: Optional[str] = None, +) -> Dict[str, Any]: + """ + Verify a JWT token and return its payload. + + Args: + token: The JWT token to verify + token_type: Expected token type (access, refresh, etc.) + expected_subject: Expected subject (user ID or email) + + Returns: + Dict[str, Any]: Decoded token payload + + Raises: + HTTPException: If token is invalid or expired + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = jose_jwt.decode( + token, + settings.SECRET_KEY, + algorithms=[settings.ALGORITHM], + options={"verify_aud": False}, + ) + + subject = payload.get("sub") + if subject is None: + raise credentials_exception + + # Verify token type if specified + if token_type and payload.get("type") != token_type: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Invalid token type. Expected {token_type}", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Verify subject if expected_subject is provided + if expected_subject and str(subject) != str(expected_subject): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token subject", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return payload + + except jose_jwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has expired", + headers={"WWW-Authenticate": "Bearer"}, + ) + except (jose_jwt.JWTError, ValidationError): + raise credentials_exception + def verify_password(plain_password: str, hashed_password: str) -> bool: + """ + Verify a password against a hash. 
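    Example (illustrative round trip with the hashing helpers in this module):

        hashed = get_password_hash("S3cure!pass")
        verify_password("S3cure!pass", hashed)     # True
        verify_password("wrong-password", hashed)  # False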
+ + Args: + plain_password: The plain text password + hashed_password: The hashed password + + Returns: + bool: True if the password matches, False otherwise + """ + if not plain_password or not hashed_password: + return False return pwd_context.verify(plain_password, hashed_password) def get_password_hash(password: str) -> str: + """ + Generate a password hash. + + Args: + password: The plain text password + + Returns: + str: The hashed password + """ return pwd_context.hash(password) + + +def generate_password() -> str: + """ + Generate a random password. + + Returns: + str: A random password + """ + # Generate a random password with letters, digits, and special characters + import random + import string + + length = 12 + chars = string.ascii_letters + string.digits + "!@#$%^&*()" + return ''.join(random.choice(chars) for _ in range(length)) + + +def check_password_strength(password: str) -> dict: + """ + Check the strength of a password. + + Args: + password: The password to check + + Returns: + dict: A dictionary with password strength information + """ + import re + + # Initialize result + result = { + 'length': len(password) >= 8, + 'has_uppercase': bool(re.search(r'[A-Z]', password)), + 'has_lowercase': bool(re.search(r'[a-z]', password)), + 'has_digit': bool(re.search(r'[0-9]', password)), + 'has_special': bool(re.search(r'[^A-Za-z0-9]', password)), + 'is_strong': True, + } + + # Check if all conditions are met + result['is_strong'] = all([ + result['length'], + result['has_uppercase'], + result['has_lowercase'], + result['has_digit'], + result['has_special'], + ]) + + return result + +async def get_current_user( + security_scopes: SecurityScopes, + db: AsyncSession = Depends(get_async_db), + token: str = Depends(oauth2_scheme), +) -> User: + """ + Get the current user from the JWT token. + + Args: + security_scopes: Security scopes required for the endpoint + db: Database session + token: JWT token from the Authorization header + + Returns: + User: The authenticated user + + Raises: + HTTPException: If the token is invalid or user not found + """ + if security_scopes.scopes: + authenticate_value = f'Bearer scope=\"{security_scopes.scope_str}\"' + else: + authenticate_value = "Bearer" + + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": authenticate_value}, + ) + + # Verify the token + try: + payload = verify_token(token, TokenType.ACCESS) + token_data = TokenPayload(**payload) + + # Get user from database + result = await db.execute(select(User).where(User.id == token_data.sub)) + user = result.scalar_one_or_none() + + if user is None: + raise credentials_exception + + # Check if user is active + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Inactive user", + ) + + # Check scopes + if security_scopes.scopes: + token_scopes = payload.get("scopes", []) + for scope in security_scopes.scopes: + if scope not in token_scopes: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not enough permissions", + headers={"WWW-Authenticate": authenticate_value}, + ) + + return user + + except (jose_jwt.JWTError, ValidationError) as e: + logger.error(f"JWT validation error: {str(e)}") + raise credentials_exception from e + + +async def get_current_active_user( + current_user: User = Depends(get_current_user), +) -> User: + """ + Get the current active user. 
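# Editor's sketch (not part of the diff): the password helpers above, assuming
# app.core.security from this change is importable.
from app.core.security import (
    check_password_strength,
    get_password_hash,
    verify_password,
)

hashed = get_password_hash("SecurePass123!")
assert verify_password("SecurePass123!", hashed)
assert not verify_password("wrong-password", hashed)

# check_password_strength reports each rule plus an overall is_strong flag.
report = check_password_strength("SecurePass123!")
assert report["is_strong"]

# Note: generate_password() above draws from random.choice; for
# security-sensitive values the stdlib secrets module is the usual choice.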
+ + Args: + current_user: The current authenticated user + + Returns: + User: The active user + + Raises: + HTTPException: If the user is inactive + """ + if not current_user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + +async def get_current_active_superuser( + current_user: User = Depends(get_current_user), +) -> User: + """ + Get the current active superuser. + + Args: + current_user: The current authenticated user + + Returns: + User: The superuser + + Raises: + HTTPException: If the user is not a superuser + """ + if not current_user.is_superuser: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="The user doesn't have enough privileges", + ) + return current_user + + +def generate_token_response( + user: User, + access_token_expires: Optional[timedelta] = None, + refresh_token_expires: Optional[timedelta] = None, +) -> dict: + """ + Generate access and refresh tokens for a user. + + Args: + user: The user to generate tokens for + access_token_expires: Optional expiration time for access token + refresh_token_expires: Optional expiration time for refresh token + + Returns: + dict: Dictionary containing access and refresh tokens + """ + # Define user scopes based on role + scopes = ["me"] + if user.is_superuser: + scopes.extend(["users:read", "users:write", "users:delete", "admin"]) + + # Create tokens + access_token = create_access_token( + subject=str(user.id), + scopes=scopes, + expires_delta=access_token_expires, + is_superuser=user.is_superuser, + email=user.email, + ) + + refresh_token = create_refresh_token( + subject=str(user.id), + expires_delta=refresh_token_expires, + ) + + return { + "access_token": access_token, + "token_type": "bearer", + "refresh_token": refresh_token, + "expires_in": ( + access_token_expires.total_seconds() + if access_token_expires + else settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60 + ), + } + +async def verify_refresh_token( + token: str, db: AsyncSession = Depends(get_async_db) +) -> User: + """ + Verify a refresh token and return the associated user. + + Args: + token: The refresh token to verify + db: Database session + + Returns: + User: The user associated with the refresh token + + Raises: + HTTPException: If the token is invalid or user not found + """ + try: + payload = verify_token(token, TokenType.REFRESH) + token_data = TokenPayload(**payload) + + # Get user from database + result = await db.execute(select(User).where(User.id == token_data.sub)) + user = result.scalar_one_or_none() + + if user is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found", + ) + + # Check if user is active + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Inactive user", + ) + + return user + + except (jose_jwt.JWTError, ValidationError) as e: + logger.error(f"Refresh token validation error: {str(e)}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + ) from e + + +def get_authorization_scheme_param(authorization_header_value: str) -> tuple[str, str]: + """ + Parse the authorization header and return the scheme and token. 
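# Editor's sketch (not part of the diff): wiring the dependencies and token
# helpers above into routes. The router and paths are illustrative only.
from fastapi import APIRouter, Depends, Security
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.security import (
    generate_token_response,
    get_current_active_superuser,
    get_current_user,
    verify_refresh_token,
)
from app.db import get_async_db
from app.models import User

router = APIRouter()


@router.get("/users/me")
async def read_own_profile(
    current_user: User = Security(get_current_user, scopes=["me"]),
) -> User:
    # get_current_user checks the "me" scope carried by the access token.
    return current_user


@router.post("/auth/refresh")
async def refresh_tokens(
    refresh_token: str, db: AsyncSession = Depends(get_async_db)
) -> dict:
    # Exchanges a valid refresh token for a new access/refresh pair; the
    # scopes on the new access token derive from user.is_superuser.
    user = await verify_refresh_token(refresh_token, db)
    return generate_token_response(user)


@router.get("/admin/ping")
async def admin_ping(_: User = Depends(get_current_active_superuser)) -> dict:
    return {"message": "admin access confirmed"}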
+ + Args: + authorization_header_value: The value of the Authorization header + + Returns: + tuple: A tuple of (scheme, token) + """ + if not authorization_header_value: + return "", "" + + parts = authorization_header_value.split() + if len(parts) != 2: + return "", "" + + scheme, token = parts + return scheme, token + + +def get_token_from_request(request: Request) -> Optional[str]: + """ + Extract token from request headers or cookies. + + Args: + request: The incoming request + + Returns: + Optional[str]: The token if found, None otherwise + """ + # Try to get token from Authorization header + auth_header = request.headers.get("Authorization") + if auth_header: + scheme, token = get_authorization_scheme_param(auth_header) + if scheme.lower() == "bearer": + return token + + # Try to get token from cookie + token = request.cookies.get("access_token") + if token: + return token + + return None + + +def get_current_user_optional( + request: Request, + db: AsyncSession = Depends(get_async_db), +) -> Optional[User]: + """ + Get the current user if authenticated, otherwise return None. + + This is useful for endpoints that work for both authenticated and unauthenticated users. + + Args: + request: The incoming request + db: Database session + + Returns: + Optional[User]: The current user if authenticated, None otherwise + """ + token = get_token_from_request(request) + if not token: + return None + + try: + payload = verify_token(token, TokenType.ACCESS) + token_data = TokenPayload(**payload) + + # Get user from database + result = db.execute(select(User).where(User.id == token_data.sub)) + user = result.scalar_one_or_none() + + if user is None or not user.is_active: + return None + + return user + + except (jose_jwt.JWTError, ValidationError): + return None diff --git a/backend/app/core/utils.py b/backend/app/core/utils.py new file mode 100644 index 0000000000..c8d5d0f66c --- /dev/null +++ b/backend/app/core/utils.py @@ -0,0 +1,447 @@ +""" +Utility functions and helpers for the application. + +This module contains various utility functions that are used throughout the application +for common tasks such as string manipulation, data validation, and more. +""" +import base64 +import hashlib +import json +import logging +import os +import random +import re +import string +import time +import uuid +from datetime import datetime, timezone +from enum import Enum +from functools import wraps +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, cast + +import jwt +from fastapi import HTTPException, Request, status +from jose import jwe +from passlib.context import CryptContext +from pydantic import BaseModel, EmailStr, ValidationError + +from app.core.config import settings +from app.core.logging import logger + +# Type variable for generic function return type +T = TypeVar('T') + +# Password hashing context +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# Email validation regex +EMAIL_REGEX = re.compile( + r'^[a-zA-Z0-9.!#$%&\'*+/=?^_`{|}~-]+@[a-zA-Z0-9]' + r'(?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?' + r'(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+$' +) + +# URL validation regex +URL_REGEX = re.compile( + r'^(?:http|ftp)s?://' # http:// or https:// + r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain... + r'localhost|' # localhost... + r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip + r'(?::\d+)?' 
# optional port + r'(?:/?|[/?]\S+)$', re.IGNORECASE +) + + +def generate_uuid() -> str: + """Generate a UUID4 string. + + Returns: + str: A UUID4 string + """ + return str(uuid.uuid4()) + + +def generate_random_string(length: int = 32) -> str: + """Generate a random alphanumeric string of specified length. + + Args: + length: Length of the string to generate (default: 32) + + Returns: + str: A random alphanumeric string + """ + chars = string.ascii_letters + string.digits + return ''.join(random.choice(chars) for _ in range(length)) + + +def generate_random_number(length: int = 6) -> str: + """Generate a random numeric string of specified length. + + Args: + length: Length of the number to generate (default: 6) + + Returns: + str: A random numeric string + """ + return ''.join(random.choice(string.digits) for _ in range(length)) + + +def get_timestamp() -> int: + """Get the current Unix timestamp. + + Returns: + int: Current Unix timestamp in seconds + """ + return int(time.time()) + + +def get_datetime() -> datetime: + """Get the current UTC datetime. + + Returns: + datetime: Current UTC datetime + """ + return datetime.now(timezone.utc) + + +def format_datetime(dt: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: + """Format a datetime object as a string. + + Args: + dt: Datetime object to format + format_str: Format string (default: "%Y-%m-%d %H:%M:%S") + + Returns: + str: Formatted datetime string + """ + return dt.strftime(format_str) + + +def parse_datetime(datetime_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime: + """Parse a datetime string into a datetime object. + + Args: + datetime_str: Datetime string to parse + format_str: Format string (default: "%Y-%m-%d %H:%M:%S") + + Returns: + datetime: Parsed datetime object + + Raises: + ValueError: If the datetime string cannot be parsed + """ + return datetime.strptime(datetime_str, format_str).replace(tzinfo=timezone.utc) + + +def is_valid_email(email: str) -> bool: + """Check if a string is a valid email address. + + Args: + email: Email address to validate + + Returns: + bool: True if the email is valid, False otherwise + """ + return bool(EMAIL_REGEX.match(email)) + + +def is_valid_url(url: str) -> bool: + """Check if a string is a valid URL. + + Args: + url: URL to validate + + Returns: + bool: True if the URL is valid, False otherwise + """ + return bool(URL_REGEX.match(url)) + + +def hash_password(password: str) -> str: + """Hash a password using bcrypt. + + Args: + password: Plain text password + + Returns: + str: Hashed password + """ + return pwd_context.hash(password) + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against a hash. + + Args: + plain_password: Plain text password + hashed_password: Hashed password + + Returns: + bool: True if the password matches, False otherwise + """ + if not plain_password or not hashed_password: + return False + return pwd_context.verify(plain_password, hashed_password) + + +def generate_jwt_token( + data: dict, + expires_delta: Optional[int] = None, + secret_key: Optional[str] = None, + algorithm: str = "HS256" +) -> str: + """Generate a JWT token. 
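# Editor's sketch (not part of the diff): the general-purpose helpers defined
# above, assuming app.core.utils from this change is importable.
from app.core.utils import (
    format_datetime,
    generate_random_string,
    generate_uuid,
    get_datetime,
    is_valid_email,
    is_valid_url,
    parse_datetime,
)

print(generate_uuid())            # random UUID4 string
print(generate_random_string(8))  # 8-character alphanumeric string

now = get_datetime()              # timezone-aware UTC datetime
stamp = format_datetime(now)      # "YYYY-MM-DD HH:MM:SS" by default
assert parse_datetime(stamp).tzinfo is not None

assert is_valid_email("user@example.com")
assert not is_valid_email("not-an-email")
assert is_valid_url("https://example.com/path?q=1")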
+ + Args: + data: Data to include in the token + expires_delta: Expiration time in seconds (default: 1 hour) + secret_key: Secret key for signing the token (default: settings.SECRET_KEY) + algorithm: Algorithm to use for signing (default: HS256) + + Returns: + str: Encoded JWT token + """ + secret_key = secret_key or settings.SECRET_KEY + expires_delta = expires_delta or 3600 # 1 hour default + + to_encode = data.copy() + expire = datetime.utcnow() + timedelta(seconds=expires_delta) + to_encode.update({"exp": expire}) + + return jwt.encode(to_encode, secret_key, algorithm=algorithm) + + +def decode_jwt_token( + token: str, + secret_key: Optional[str] = None, + algorithms: List[str] = ["HS256"] +) -> dict: + """Decode a JWT token. + + Args: + token: JWT token to decode + secret_key: Secret key used for signing (default: settings.SECRET_KEY) + algorithms: List of allowed algorithms (default: ["HS256"]) + + Returns: + dict: Decoded token payload + + Raises: + HTTPException: If the token is invalid or expired + """ + secret_key = secret_key or settings.SECRET_KEY + + try: + payload = jwt.decode(token, secret_key, algorithms=algorithms) + return payload + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has expired", + headers={"WWW-Authenticate": "Bearer"}, + ) + except jwt.JWTError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) from e + + +def encrypt_data(data: Union[dict, str], key: Optional[str] = None) -> str: + """Encrypt data using JWE. + + Args: + data: Data to encrypt (dict or JSON string) + key: Encryption key (default: settings.SECRET_KEY) + + Returns: + str: Encrypted data as a string + """ + key = key or settings.SECRET_KEY + if isinstance(data, dict): + data = json.dumps(data) + return jwe.encrypt(data.encode(), key).decode() + + +def decrypt_data(encrypted_data: str, key: Optional[str] = None) -> str: + """Decrypt data using JWE. + + Args: + encrypted_data: Encrypted data as a string + key: Decryption key (default: settings.SECRET_KEY) + + Returns: + str: Decrypted data as a string + """ + key = key or settings.SECRET_KEY + return jwe.decrypt(encrypted_data.encode(), key).decode() + + +def to_camel_case(snake_str: str) -> str: + """Convert a snake_case string to camelCase. + + Args: + snake_str: Snake case string to convert + + Returns: + str: Camel case string + """ + components = snake_str.split('_') + return components[0] + ''.join(x.title() for x in components[1:]) + + +def to_snake_case(camel_str: str) -> str: + """Convert a camelCase string to snake_case. + + Args: + camel_str: Camel case string to convert + + Returns: + str: Snake case string + """ + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', camel_str).lower() + + +def dict_to_camel_case(data: dict) -> dict: + """Convert all keys in a dictionary to camelCase. + + Args: + data: Dictionary with snake_case keys + + Returns: + dict: Dictionary with camelCase keys + """ + return {to_camel_case(k): v for k, v in data.items()} + + +def dict_to_snake_case(data: dict) -> dict: + """Convert all keys in a dictionary to snake_case. + + Args: + data: Dictionary with camelCase keys + + Returns: + dict: Dictionary with snake_case keys + """ + return {to_snake_case(k): v for k, v in data.items()} + + +def get_client_ip(request: Request) -> str: + """Get the client's IP address from the request. 
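# Editor's sketch (not part of the diff): the key-style converters above,
# assuming app.core.utils from this change is importable. Note that
# generate_jwt_token relies on timedelta, which is not among this module's
# datetime imports as written.
from app.core.utils import dict_to_camel_case, to_camel_case, to_snake_case

assert to_camel_case("full_name") == "fullName"
assert to_snake_case("fullName") == "full_name"

payload = dict_to_camel_case({"full_name": "Ada Lovelace", "is_active": True})
assert payload == {"fullName": "Ada Lovelace", "isActive": True}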
+ + Args: + request: FastAPI Request object + + Returns: + str: Client's IP address + """ + x_forwarded_for = request.headers.get("X-Forwarded-For") + if x_forwarded_for: + return x_forwarded_for.split(",")[0] + return request.client.host if request.client else "unknown" + + +def get_user_agent(request: Request) -> str: + """Get the user agent from the request. + + Args: + request: FastAPI Request object + + Returns: + str: User agent string + """ + return request.headers.get("User-Agent", "unknown") + + +def get_domain_from_email(email: str) -> str: + """Extract the domain from an email address. + + Args: + email: Email address + + Returns: + str: Domain part of the email + """ + if "@" not in email: + return "" + return email.split("@")[1].lower() + + +def mask_email(email: str) -> str: + """Mask an email address for display. + + Example: test@example.com -> t***@e******.com + + Args: + email: Email address to mask + + Returns: + str: Masked email address + """ + if "@" not in email: + return email + + username, domain = email.split("@") + masked_username = f"{username[0]}{'*' * (len(username) - 1)}" if len(username) > 1 else username + + if "." in domain: + domain_parts = domain.split(".") + masked_domain = f"{domain_parts[0][0]}{'*' * (len(domain_parts[0]) - 1)}.{'.'.join(domain_parts[1:])}" + else: + masked_domain = domain + + return f"{masked_username}@{masked_domain}" + + +def mask_phone(phone: str) -> str: + """Mask a phone number for display. + + Example: +1234567890 -> +1******890 + + Args: + phone: Phone number to mask + + Returns: + str: Masked phone number + """ + if not phone: + return "" + + # Keep country code and last 3 digits + if len(phone) > 4: + return f"{phone[:2]}******{phone[-3:]}" + return "*" * len(phone) + + +def paginate( + items: List[T], + page: int = 1, + page_size: int = 10, + total: Optional[int] = None +) -> Dict[str, Any]: + """Paginate a list of items. + + Args: + items: List of items to paginate + page: Current page number (1-based) + page_size: Number of items per page + total: Total number of items (if None, uses len(items)) + + Returns: + dict: Dictionary with pagination metadata and items + """ + if total is None: + total = len(items) + + total_pages = (total + page_size - 1) // page_size if page_size > 0 else 1 + + return { + "items": items, + "page": page, + "page_size": page_size, + "total": total, + "total_pages": total_pages, + "has_next": page < total_pages, + "has_previous": page > 1, + } diff --git a/backend/app/crud.py b/backend/app/crud.py index 905bf48724..2df414d648 100644 --- a/backend/app/crud.py +++ b/backend/app/crud.py @@ -33,7 +33,8 @@ def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any: def get_user_by_email(*, session: Session, email: str) -> User | None: statement = select(User).where(User.email == email) - session_user = session.exec(statement).first() + # Use execute() instead of exec() for SQLAlchemy compatibility + session_user = session.execute(statement).scalars().first() return session_user diff --git a/backend/app/db/__init__.py b/backend/app/db/__init__.py new file mode 100644 index 0000000000..86ba0ce227 --- /dev/null +++ b/backend/app/db/__init__.py @@ -0,0 +1,24 @@ +""" +Database package. + +This package provides database session management and utilities. 
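# Editor's sketch (not part of the diff): masking and pagination helpers from
# app.core.utils, assuming the module from this change is importable.
from app.core.utils import mask_email, paginate

assert mask_email("test@example.com") == "t***@e******.com"

page = paginate(items=list(range(10)), page=2, page_size=3, total=25)
assert page["total_pages"] == 9
assert page["has_next"] and page["has_previous"]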
+""" +from app.core.db import ( + get_db, + get_async_db, + get_sync_session, + get_async_session, + init_db, + async_init_db, + get_password_hash, +) + +__all__ = [ + 'get_db', + 'get_async_db', + 'get_sync_session', + 'get_async_session', + 'init_db', + 'async_init_db', + 'get_password_hash', +] diff --git a/backend/app/db/session.py b/backend/app/db/session.py new file mode 100644 index 0000000000..96ef4d1743 --- /dev/null +++ b/backend/app/db/session.py @@ -0,0 +1,53 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession + +from app.core.config import settings + +# Create SQLAlchemy engine +engine = create_engine( + str(settings.SQLALCHEMY_DATABASE_URI), + pool_pre_ping=True, + echo=settings.SQL_ECHO, +) + +# Create async engine for async database operations +async_engine = create_async_engine( + str(settings.ASYNC_SQLALCHEMY_DATABASE_URI), + echo=settings.SQL_ECHO, +) + +# Session factory +SessionLocal = sessionmaker( + autocommit=False, + autoflush=False, + bind=engine, +) + +# Async session factory +AsyncSessionLocal = sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False +) + +def get_db(): + """Dependency for getting a synchronous database session""" + db = SessionLocal() + try: + yield db + finally: + db.close() + +async def get_async_db(): + """Dependency for getting an asynchronous database session""" + async with AsyncSessionLocal() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() diff --git a/backend/app/models.py b/backend/app/models.py index 2389b4a532..f94fd1a33f 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -1,15 +1,24 @@ import uuid +from datetime import datetime -from pydantic import EmailStr +from pydantic import EmailStr, Field as PydanticField +from sqlalchemy import Column, DateTime, func from sqlmodel import Field, Relationship, SQLModel # Shared properties class UserBase(SQLModel): email: EmailStr = Field(unique=True, index=True, max_length=255) - is_active: bool = True + is_active: bool = False # Changed to False to require email verification is_superuser: bool = False + is_verified: bool = False full_name: str | None = Field(default=None, max_length=255) + sso_provider: str | None = Field(default=None, max_length=50) + sso_sub: str | None = Field(default=None, max_length=255, index=True) + last_login: datetime | None = Field( + default=None, + sa_column=Column(DateTime(timezone=True)) + ) # Properties to receive via API on creation @@ -42,8 +51,18 @@ class UpdatePassword(SQLModel): # Database model, database table inferred from class name class User(UserBase, table=True): id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - hashed_password: str + hashed_password: str | None = None # Nullable for SSO users + email_verified: bool = Field(default=False) + created_at: datetime = Field( + default_factory=datetime.utcnow, + sa_column=Column(DateTime(timezone=True), server_default=func.now()) + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + sa_column=Column(DateTime(timezone=True), onupdate=func.now()) + ) items: list["Item"] = Relationship(back_populates="owner", cascade_delete=True) + refresh_tokens: list["RefreshToken"] = Relationship(back_populates="user", cascade_delete=True) # Properties to return via API, id is always required @@ -111,3 +130,41 @@ 
class TokenPayload(SQLModel): class NewPassword(SQLModel): token: str new_password: str = Field(min_length=8, max_length=40) + + +class RefreshTokenBase(SQLModel): + token: str = Field(index=True) + expires_at: datetime + is_revoked: bool = Field(default=False) + user_agent: str | None = Field(default=None, max_length=255) + ip_address: str | None = Field(default=None, max_length=45) + + +class RefreshToken(RefreshTokenBase, table=True): + id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) + user_id: uuid.UUID = Field(foreign_key="user.id", nullable=False) + created_at: datetime = Field( + default_factory=datetime.utcnow, + sa_column=Column(DateTime(timezone=True), server_default=func.now()) + ) + user: "User" = Relationship(back_populates="refresh_tokens") + + +class RefreshTokenCreate(RefreshTokenBase): + user_id: uuid.UUID + + +class RefreshTokenUpdate(SQLModel): + is_revoked: bool = True + + +class RefreshTokenPublic(RefreshTokenBase): + id: uuid.UUID + user_id: uuid.UUID + created_at: datetime + + +class TokenPair(SQLModel): + access_token: str + refresh_token: str + token_type: str = "bearer" diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 0000000000..2c92497fc9 --- /dev/null +++ b/backend/app/models/__init__.py @@ -0,0 +1,71 @@ +# Import all models here so they're properly registered with SQLAlchemy +from sqlmodel import SQLModel + +from app.models.base import BaseDBModel, TimestampMixin +from app.models.user import ( + UserRole, + OAuthProvider, + UserBase, + UserCreate, + UserUpdate, + UserInDB as User, + UserPublic, + UserLogin, + TokenPayload, + TokenPair, + RefreshTokenBase, + RefreshToken, + RefreshTokenCreate, + RefreshTokenPublic, + PasswordResetRequest, + PasswordResetConfirm, + NewPassword, + UpdatePassword, + UserRegister, + UserUpdateMe, + UsersPublic, +) + +from app.models.item import ( + Item, + ItemBase, + ItemCreate, + ItemUpdate, + ItemPublic, + ItemsPublic, + Message, +) + +# This ensures that SQLModel knows about all models for migrations +__all__ = [ + 'BaseDBModel', + 'TimestampMixin', + 'UserRole', + 'OAuthProvider', + 'UserBase', + 'UserCreate', + 'UserUpdate', + 'User', + 'UserPublic', + 'UserLogin', + 'TokenPayload', + 'TokenPair', + 'RefreshTokenBase', + 'RefreshToken', + 'RefreshTokenCreate', + 'RefreshTokenPublic', + 'PasswordResetRequest', + 'PasswordResetConfirm', + 'NewPassword', + 'UpdatePassword', + 'UserRegister', + 'UserUpdateMe', + 'UsersPublic', + 'Item', + 'ItemBase', + 'ItemCreate', + 'ItemUpdate', + 'ItemPublic', + 'ItemsPublic', + 'Message', +] diff --git a/backend/app/models/base.py b/backend/app/models/base.py new file mode 100644 index 0000000000..8ed6240d70 --- /dev/null +++ b/backend/app/models/base.py @@ -0,0 +1,79 @@ +from datetime import datetime +from typing import Any, Dict, Optional +from uuid import UUID, uuid4 + +from sqlalchemy import Column, DateTime, func +from sqlalchemy.dialects.postgresql import UUID as PG_UUID +from sqlmodel import Field, SQLModel, text + +class BaseDBModel(SQLModel): + """Base model for all database models.""" + + id: UUID = Field( + default_factory=uuid4, + sa_column=Column( + PG_UUID(as_uuid=True), + primary_key=True, + server_default=text("gen_random_uuid()"), + + index=True, + ), + ) + created_at: datetime = Field( + default_factory=datetime.utcnow, + + sa_column=Column( + DateTime(timezone=True), + server_default=func.now(), + + ), + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + + sa_column=Column( + 
DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + + ), + ) + + class Config: + arbitrary_types_allowed = True + json_encoders = { + datetime: lambda v: v.isoformat(), + UUID: str, + } + orm_mode = True + + def dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Convert model to dictionary, excluding unset fields by default.""" + kwargs.setdefault("exclude_unset", True) + return super().dict(*args, **kwargs) + + +class TimestampMixin(SQLModel): + """ + Mixin to add created_at and updated_at timestamp fields to models. + These fields will be automatically managed by the database. + """ + created_at: datetime = Field( + default_factory=datetime.utcnow, + + sa_column=Column( + DateTime(timezone=True), + server_default=func.now(), + + ), + ) + updated_at: datetime = Field( + default_factory=datetime.utcnow, + + sa_column=Column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + + ), + ) diff --git a/backend/app/models/item.py b/backend/app/models/item.py new file mode 100644 index 0000000000..3e506e6493 --- /dev/null +++ b/backend/app/models/item.py @@ -0,0 +1,67 @@ +"""Item model for the application.""" +from datetime import datetime +from typing import Optional +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field +from sqlmodel import SQLModel, Field as SQLModelField + +from app.models.base import BaseDBModel + + +class ItemBase(SQLModel): + """Base model for Item with common attributes.""" + title: str = SQLModelField(..., max_length=100, description="The title of the item") + description: Optional[str] = SQLModelField( + None, max_length=500, description="A description of the item" + ) + price: float = SQLModelField(..., gt=0, description="The price of the item in USD") + tax: Optional[float] = SQLModelField(None, ge=0, description="Tax applied to the item") + + +class ItemCreate(ItemBase): + """Model for creating a new item.""" + pass + + +class ItemUpdate(SQLModel): + """Model for updating an existing item.""" + title: Optional[str] = None + description: Optional[str] = None + price: Optional[float] = None + tax: Optional[float] = None + + +class Item(ItemBase, BaseDBModel, table=True): + """Database model for items.""" + __tablename__ = "items" + + id: UUID = SQLModelField( + default_factory=uuid4, + primary_key=True, + index=True, + nullable=False, + sa_column_kwargs={"server_default": "gen_random_uuid()"}, + ) + owner_id: UUID = SQLModelField( + ..., foreign_key="users.id", description="ID of the user who owns this item" + ) + + +class ItemPublic(ItemBase): + """Public representation of an item.""" + id: UUID + owner_id: UUID + created_at: datetime + updated_at: datetime + + +class ItemsPublic(SQLModel): + """Response model for a list of items.""" + data: list[ItemPublic] + count: int + + +class Message(SQLModel): + """Generic message response model.""" + message: str diff --git a/backend/app/models/user.py b/backend/app/models/user.py new file mode 100644 index 0000000000..ea8a481e7e --- /dev/null +++ b/backend/app/models/user.py @@ -0,0 +1,356 @@ +from datetime import datetime +from enum import Enum +from typing import List, Optional +from uuid import UUID, uuid4 + +from pydantic import EmailStr, Field, validator +from sqlalchemy import Column, DateTime, Enum as SQLEnum, String, func +from sqlmodel import Field as SQLModelField, Relationship, SQLModel + +from app.models.base import BaseDBModel, TimestampMixin + + +class UserRole(str, Enum): + """User roles for role-based access control.""" + USER 
= "user" + ADMIN = "admin" + SUPERUSER = "superuser" + + +class OAuthProvider(str, Enum): + """Supported OAuth providers.""" + GOOGLE = "google" + MICROSOFT = "microsoft" + GITHUB = "github" + + +class UserBase(SQLModel): + """Base user model with common fields.""" + email: EmailStr = SQLModelField( + ..., + sa_column=Column(String(255), unique=True, nullable=False, index=True), + ) + is_active: bool = Field(default=True, nullable=False) # Changed from False to True so all new users are active by default + is_verified: bool = Field(default=False, nullable=False) + email_verified: bool = Field(default=False, nullable=False) + full_name: Optional[str] = Field(default=None, max_length=255) + role: UserRole = Field( + default=UserRole.USER, + nullable=False, + sa_column=Column(SQLEnum(UserRole), server_default=UserRole.USER.value, nullable=False), + ) + last_login: Optional[datetime] = Field( + default=None, sa_column=Column(DateTime(timezone=True)) + ) + + # OAuth fields + sso_provider: Optional[OAuthProvider] = Field( + default=None, + sa_column=Column(SQLEnum(OAuthProvider)), + ) + sso_id: Optional[str] = Field( + default=None, + max_length=255, + index=True, + sa_column=Column(String(255), index=True), + ) + + # Password hash (nullable for OAuth users) + hashed_password: Optional[str] = Field( + default=None, + sa_column=Column(String(255), nullable=True), + ) + + class Config: + orm_mode = True + arbitrary_types_allowed = True + json_encoders = { + datetime: lambda v: v.isoformat() if v else None, + } + schema_extra = { + "example": { + "email": "user@example.com", + "full_name": "John Doe", + "is_active": True, + "is_verified": True, + "role": "user", + } + } + + +class UserCreate(SQLModel): + """Schema for creating a new user.""" + email: EmailStr + password: str = Field(..., min_length=8, max_length=100) + full_name: Optional[str] = Field(None, max_length=255) + + @validator('password') + def password_strength(cls, v): + if len(v) < 8: + raise ValueError('Password must be at least 8 characters long') + if not any(c.isupper() for c in v): + raise ValueError('Password must contain at least one uppercase letter') + if not any(c.islower() for c in v): + raise ValueError('Password must contain at least one lowercase letter') + if not any(c.isdigit() for c in v): + raise ValueError('Password must contain at least one number') + return v + + +class UserUpdate(SQLModel): + """Schema for updating a user.""" + email: Optional[EmailStr] = None + full_name: Optional[str] = None + is_active: Optional[bool] = None + is_verified: Optional[bool] = None + role: Optional[UserRole] = None + + class Config: + schema_extra = { + "example": { + "email": "new.email@example.com", + "full_name": "New Name", + "is_active": True, + "is_verified": True, + "role": "user", + } + } + + +class UserInDB(UserBase, BaseDBModel, table=True): + """User model for database representation.""" + __tablename__ = "users" + + # Relationships + refresh_tokens: List["RefreshToken"] = Relationship(back_populates="user") + + # Override the base fields to make them compatible with SQLModel + id: UUID = SQLModelField( + default_factory=uuid4, + primary_key=True, + index=True, + ) + created_at: datetime = SQLModelField( + default_factory=datetime.utcnow, + sa_column=Column(DateTime(timezone=True), server_default=func.now(), nullable=False), + ) + updated_at: datetime = SQLModelField( + default_factory=datetime.utcnow, + sa_column=Column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + 
), + ) + + +class UserPublic(UserBase): + """Public user schema (excludes sensitive data).""" + id: UUID + created_at: datetime + updated_at: datetime + + class Config: + schema_extra = { + "example": { + "id": "123e4567-e89b-12d3-a456-426614174000", + "email": "user@example.com", + "full_name": "John Doe", + "is_active": True, + "is_verified": True, + "email_verified": True, + "role": "user", + "created_at": "2023-01-01T00:00:00Z", + "updated_at": "2023-01-01T00:00:00Z", + } + } + + +class UserLogin(SQLModel): + """Schema for user login.""" + email: EmailStr + password: str = Field(..., min_length=1) + remember_me: bool = False + + +class TokenPayload(SQLModel): + """Payload for JWT tokens.""" + sub: UUID + exp: int + iat: int + type: str + jti: Optional[str] = None + scopes: List[str] = [] + + class Config: + orm_mode = True + json_encoders = { + UUID: str, + } + + +class TokenPair(SQLModel): + """Schema for access and refresh token pair.""" + access_token: str + refresh_token: str + token_type: str = "bearer" + + +class RefreshTokenBase(SQLModel): + """Base schema for refresh tokens.""" + token: str = SQLModelField(..., index=True) + expires_at: datetime + is_revoked: bool = Field(default=False, nullable=False) + user_agent: Optional[str] = Field(None, max_length=255) + ip_address: Optional[str] = Field(None, max_length=45) + + +class RefreshToken(RefreshTokenBase, BaseDBModel, table=True): + """Refresh token model for database storage.""" + __tablename__ = "refresh_tokens" + + user_id: UUID = SQLModelField( + foreign_key="users.id", + index=True, + ) + user: UserInDB = Relationship(back_populates="refresh_tokens") + + # Override the base fields to make them compatible with SQLModel + id: UUID = SQLModelField( + default_factory=uuid4, + primary_key=True, + index=True, + ) + created_at: datetime = SQLModelField( + default_factory=datetime.utcnow, + sa_column=Column(DateTime(timezone=True), server_default=func.now(), nullable=False), + ) + updated_at: datetime = SQLModelField( + default_factory=datetime.utcnow, + sa_column=Column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ), + ) + + +class RefreshTokenCreate(RefreshTokenBase): + """Schema for creating a new refresh token.""" + user_id: UUID + + +class RefreshTokenPublic(RefreshTokenBase): + """Public schema for refresh tokens.""" + id: UUID + user_id: UUID + created_at: datetime + updated_at: datetime + + class Config: + orm_mode = True + json_encoders = { + datetime: lambda v: v.isoformat() if v else None, + UUID: str, + } + + +class PasswordResetRequest(SQLModel): + """Schema for requesting a password reset.""" + email: EmailStr + + +class PasswordResetConfirm(SQLModel): + """Schema for confirming a password reset.""" + token: str + new_password: str = Field(..., min_length=8, max_length=100) + + @validator("new_password") + def password_strength(cls, v): + if len(v) < 8: + raise ValueError("Password must be at least 8 characters long") + if not any(char.isdigit() for char in v): + raise ValueError("Password must contain at least one number") + if not any(char.isupper() for char in v): + raise ValueError("Password must contain at least one uppercase letter") + if not any(char.islower() for char in v): + raise ValueError("Password must contain at least one lowercase letter") + return v + + +class NewPassword(SQLModel): + """Schema for setting a new password.""" + current_password: str = Field(..., min_length=1, description="Current password") + new_password: str = Field(..., 
min_length=8, max_length=100, description="New password") + + @validator("new_password") + def password_strength(cls, v): + if len(v) < 8: + raise ValueError("Password must be at least 8 characters long") + if not any(char.isdigit() for char in v): + raise ValueError("Password must contain at least one number") + if not any(char.isupper() for char in v): + raise ValueError("Password must contain at least one uppercase letter") + if not any(char.islower() for char in v): + raise ValueError("Password must contain at least one lowercase letter") + return v + + +class UpdatePassword(SQLModel): + """Schema for updating a user's password.""" + current_password: str = Field(..., min_length=1, description="Current password") + new_password: str = Field(..., min_length=8, max_length=100, description="New password") + + @validator("new_password") + def password_strength(cls, v): + if len(v) < 8: + raise ValueError("Password must be at least 8 characters long") + if not any(char.isdigit() for char in v): + raise ValueError("Password must contain at least one number") + if not any(char.isupper() for char in v): + raise ValueError("Password must contain at least one uppercase letter") + if not any(char.islower() for char in v): + raise ValueError("Password must contain at least one lowercase letter") + return v + + +class UserRegister(UserCreate): + """Schema for user registration.""" + full_name: Optional[str] = Field(None, max_length=255, description="User's full name") + + class Config: + schema_extra = { + "example": { + "email": "user@example.com", + "password": "SecurePass123", + "full_name": "John Doe" + } + } + + +class UsersPublic(SQLModel): + """Schema for returning a paginated list of users.""" + count: int = Field(..., description="Total number of users") + data: List[UserPublic] = Field(..., description="List of users") + + class Config: + json_encoders = { + UUID: str, + datetime: lambda v: v.isoformat() if v else None + } + + +class UserUpdateMe(SQLModel): + """Schema for updating the current user's profile.""" + email: Optional[EmailStr] = None + full_name: Optional[str] = Field(None, max_length=255) + + class Config: + schema_extra = { + "example": { + "email": "new.email@example.com", + "full_name": "New Name" + } + } diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py new file mode 100644 index 0000000000..05150efb78 --- /dev/null +++ b/backend/app/schemas/auth.py @@ -0,0 +1,87 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, EmailStr, Field, validator + + +class TokenType(str, Enum): + BEARER = "bearer" + + +class TokenBase(BaseModel): + access_token: str + token_type: str = TokenType.BEARER + refresh_token: str + + +class Token(TokenBase): + pass + + +class UserLogin(BaseModel): + email: EmailStr + password: str = Field(..., min_length=8, max_length=100) + + +class UserRegister(BaseModel): + email: EmailStr + password: str = Field(..., min_length=8, max_length=100) + full_name: Optional[str] = Field(None, max_length=255) + + @validator('password') + def password_strength(cls, v): + if len(v) < 8: + raise ValueError('Password must be at least 8 characters long') + if not any(c.isupper() for c in v): + raise ValueError('Password must contain at least one uppercase letter') + if not any(c.islower() for c in v): + raise ValueError('Password must contain at least one lowercase letter') + if not any(c.isdigit() for c in v): + raise ValueError('Password must contain at least one number') + return v + + +class UserOut(BaseModel): 
+ id: str + email: EmailStr + full_name: Optional[str] + is_active: bool + is_verified: bool + is_superuser: bool + created_at: str + updated_at: Optional[str] + last_login: Optional[str] + + class Config: + orm_mode = True + + +class PasswordResetRequest(BaseModel): + email: EmailStr + + +class PasswordResetConfirm(BaseModel): + token: str + new_password: str = Field(..., min_length=8, max_length=100) + + +class OAuthProvider(str, Enum): + GOOGLE = "google" + MICROSOFT = "microsoft" + + +class OAuthTokenRequest(BaseModel): + code: str + redirect_uri: str + + +class SSOProvider(str, Enum): + SAML = "saml" + OIDC = "oidc" + + +class SSORequest(BaseModel): + metadata_url: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None + redirect_uris: Optional[list[str]] = None diff --git a/backend/scripts/__init__.py b/backend/scripts/__init__.py new file mode 100644 index 0000000000..d085fe4d46 --- /dev/null +++ b/backend/scripts/__init__.py @@ -0,0 +1 @@ +# This file makes the scripts directory a Python package diff --git a/backend/scripts/check_db.py b/backend/scripts/check_db.py new file mode 100644 index 0000000000..fff6b60df0 --- /dev/null +++ b/backend/scripts/check_db.py @@ -0,0 +1,115 @@ +""" +Database connection check script. +This script verifies that the database is accessible and properly configured. +""" +import sys +from pathlib import Path +from sqlalchemy import create_engine, text +from sqlalchemy.exc import OperationalError, SQLAlchemyError + +# Add the backend directory to the path so we can import our app +sys.path.append(str(Path(__file__).parent.parent)) + +# Import settings after ensuring the app is in the path +from app.core.config import settings +from app.core.logging import setup_logging + +# Set up logging +logger = setup_logging() + +def check_database_config() -> bool: + """Check if the database configuration is valid.""" + required_vars = [ + 'SQLALCHEMY_DATABASE_URI', + 'ASYNC_SQLALCHEMY_DATABASE_URI', + 'FIRST_SUPERUSER', + 'FIRST_SUPERUSER_PASSWORD', + ] + + missing_vars = [var for var in required_vars if not getattr(settings, var, None)] + + if missing_vars: + logger.error(f"❌ Missing required environment variables: {', '.join(missing_vars)}") + return False + + logger.info("✅ Database configuration is valid") + return True + +def check_database_connection() -> bool: + """Check if the database is accessible.""" + try: + # Create an engine and connect to the database + engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) + with engine.connect() as conn: + # Execute a simple query to verify the connection + result = conn.execute(text("SELECT 1")) + if result.scalar() == 1: + logger.info("✅ Database connection successful!") + return True + else: + logger.error("❌ Database connection check failed: Unexpected result") + return False + except OperationalError as e: + logger.error(f"❌ Database connection failed (OperationalError): {e}") + return False + except SQLAlchemyError as e: + logger.error(f"❌ Database connection failed (SQLAlchemyError): {e}") + return False + except Exception as e: + logger.error(f"❌ Database connection failed (Unexpected error): {e}") + return False + +def check_database_version() -> bool: + """Check the database version and compatibility.""" + try: + engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) + with engine.connect() as conn: + # Check PostgreSQL version + if 'postgresql' in str(settings.SQLALCHEMY_DATABASE_URI).lower(): + result = conn.execute(text("SELECT version()")) + version 
= result.scalar() + logger.info(f"📊 Database version: {version}") + + # Check if the required extensions are installed + try: + result = conn.execute(text("SELECT extname FROM pg_extension WHERE extname IN ('uuid-ossp', 'pgcrypto')")) + extensions = [row[0] for row in result.fetchall()] + logger.info(f"📦 Installed extensions: {', '.join(extensions) if extensions else 'None'}") + + # Install required extensions if missing + if 'uuid-ossp' not in extensions: + logger.warning("⚠️ Extension 'uuid-ossp' is not installed. Some features may not work correctly.") + if 'pgcrypto' not in extensions: + logger.warning("⚠️ Extension 'pgcrypto' is not installed. Some features may not work correctly.") + except Exception as e: + logger.warning(f"⚠️ Could not check database extensions: {e}") + + return True + except Exception as e: + logger.warning(f"⚠️ Could not check database version: {e}") + return False + +def main(): + """Main function to run database checks.""" + print("🔍 Checking database configuration...") + if not check_database_config(): + print("❌ Database configuration is invalid. Please check your .env file.") + sys.exit(1) + + print("\n🔍 Testing database connection...") + if not check_database_connection(): + print("\n❌ Database connection failed. Please check the following:") + print(f" 1. Is the database server running?") + print(f" 2. Does the database '{settings.SQLALCHEMY_DATABASE_URI.split('/')[-1]}' exist?") + print(f" 3. Are the database credentials in your .env file correct?") + print(f" 4. Is the database server accessible from this machine?") + sys.exit(1) + + print("\n🔍 Checking database version and extensions...") + check_database_version() + + print("\n✅ All database checks passed!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/backend/scripts/init_db.py b/backend/scripts/init_db.py new file mode 100644 index 0000000000..2e18451ef8 --- /dev/null +++ b/backend/scripts/init_db.py @@ -0,0 +1,96 @@ +""" +Database initialization script. +This script initializes the database and runs migrations. 
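# Editor's sketch (not part of the diff): running the checks above
# programmatically rather than as a script, assuming the backend/scripts
# package layout from this change.
from scripts.check_db import (
    check_database_config,
    check_database_connection,
    check_database_version,
)

if check_database_config() and check_database_connection():
    check_database_version()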
+""" +import sys +import subprocess +from pathlib import Path + +# Add the backend directory to the path so we can import our app +sys.path.append(str(Path(__file__).parent.parent)) + +def run_command(command: str) -> bool: + """Run a shell command and return True if successful.""" + try: + print(f"🚀 Running: {command}") + result = subprocess.run(command, shell=True, check=True, text=True, capture_output=True) + print(result.stdout) + if result.stderr: + print(f"⚠️ {result.stderr}", file=sys.stderr) + return True + except subprocess.CalledProcessError as e: + print(f"❌ Command failed with error: {e}", file=sys.stderr) + print(f"Command output: {e.output}", file=sys.stderr) + print(f"Command stderr: {e.stderr}", file=sys.stderr) + return False + +def check_database() -> bool: + """Check if the database is accessible.""" + print("🔍 Checking database connection...") + return run_command("python -m scripts.check_db") + +def run_migrations() -> bool: + """Run database migrations.""" + print("🚀 Running migrations...") + return run_command("python -m scripts.migrate upgrade head") + +def create_initial_data() -> bool: + """Create initial data in the database.""" + print("✨ Creating initial data...") + return run_command("python -c \" +import sys +from app.db.session import SessionLocal +from app.core.config import settings +from app.models.user import UserInDB +from app.core.security import get_password_hash + +try: + db = SessionLocal() + # Create initial admin user if it doesn't exist + admin = db.query(UserInDB).filter(UserInDB.email == settings.FIRST_SUPERUSER).first() + if not admin and settings.FIRST_SUPERUSER and settings.FIRST_SUPERUSER_PASSWORD: + admin = UserInDB( + email=settings.FIRST_SUPERUSER, + hashed_password=get_password_hash(settings.FIRST_SUPERUSER_PASSWORD), + is_active=True, + is_verified=True, + email_verified=True, + role='superuser', + full_name='Admin User' + ) + db.add(admin) + db.commit() + print('✅ Created initial admin user') + else: + print('ℹ️ Admin user already exists') + + db.close() + sys.exit(0) +except Exception as e: + print(f'❌ Error setting up initial data: {e}', file=sys.stderr) + sys.exit(1) +\""") + +def main(): + """Main function to run database initialization.""" + print("🚀 Starting database initialization...") + + # Check if the database is accessible + if not check_database(): + print("❌ Database connection check failed. Please check your database configuration.") + sys.exit(1) + + # Run migrations + if not run_migrations(): + print("❌ Database migrations failed. Please check the error messages above.") + sys.exit(1) + + # Create initial data + if not create_initial_data(): + print("⚠️ Failed to create initial data. 
Continuing anyway...") + + print("✅ Database initialization complete!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/backend/scripts/migrate.py b/backend/scripts/migrate.py new file mode 100644 index 0000000000..b7ecb4e298 --- /dev/null +++ b/backend/scripts/migrate.py @@ -0,0 +1,113 @@ +import os +import sys +import subprocess +from pathlib import Path +from typing import Optional + +from alembic.config import Config +from alembic import command + +# Add the backend directory to the path so we can import our app +sys.path.append(str(Path(__file__).parent.parent)) + +def run_alembic_command(command_name: str, *args) -> bool: + """Run an alembic command and return True if successful.""" + try: + # Get the directory containing this script + base_dir = Path(__file__).parent.parent + + # Set up the Alembic configuration + config = Config(str(base_dir / 'alembic.ini')) + config.set_main_option('script_location', str(base_dir / 'app' / 'alembic')) + + # Import the models to ensure they're registered with SQLAlchemy + from app.models import * # noqa + from app.core.config import settings + + # Set the database URL + config.set_main_option('sqlalchemy.url', str(settings.SQLALCHEMY_DATABASE_URI)) + + # Run the command + getattr(command, command_name)(config, *args) + return True + except Exception as e: + print(f"Error running {command_name}: {e}", file=sys.stderr) + return False + +def create_migration(message: Optional[str] = None) -> bool: + """Create a new migration.""" + if not message: + print("Please provide a message for the migration with --message") + return False + + print(f"Creating migration: {message}") + return run_alembic_command('revision', '--autogenerate', '-m', message) + +def upgrade_database(revision: str = 'head') -> bool: + """Upgrade the database to the specified revision.""" + print(f"Upgrading database to revision: {revision}") + return run_alembic_command('upgrade', revision) + +def downgrade_database(revision: str) -> bool: + """Downgrade the database to the specified revision.""" + print(f"Downgrading database to revision: {revision}") + return run_alembic_command('downgrade', revision) + +def show_current() -> bool: + """Show the current revision.""" + print("Current database revision:") + return run_alembic_command('current') + +def show_history() -> bool: + """Show migration history.""" + print("Migration history:") + return run_alembic_command('history') + +def main(): + import argparse + + parser = argparse.ArgumentParser(description='Database migration utility') + subparsers = parser.add_subparsers(dest='command', help='Command to run') + + # Create migration command + create_parser = subparsers.add_parser('create', help='Create a new migration') + create_parser.add_argument('--message', '-m', required=True, help='Migration message') + + # Upgrade command + upgrade_parser = subparsers.add_parser('upgrade', help='Upgrade database') + upgrade_parser.add_argument('--revision', '-r', default='head', help='Target revision') + + # Downgrade command + downgrade_parser = subparsers.add_parser('downgrade', help='Downgrade database') + downgrade_parser.add_argument('revision', help='Target revision') + + # Show current revision + subparsers.add_parser('current', help='Show current revision') + + # Show history + subparsers.add_parser('history', help='Show migration history') + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return 1 + + if args.command == 'create': + success = create_migration(args.message) + elif 
args.command == 'upgrade': + success = upgrade_database(args.revision) + elif args.command == 'downgrade': + success = downgrade_database(args.revision) + elif args.command == 'current': + success = show_current() + elif args.command == 'history': + success = show_history() + else: + print(f"Unknown command: {args.command}") + return 1 + + return 0 if success else 1 + +if __name__ == '__main__': + sys.exit(main()) diff --git a/backend/scripts/setup_db.sh b/backend/scripts/setup_db.sh new file mode 100755 index 0000000000..d1acc0bd38 --- /dev/null +++ b/backend/scripts/setup_db.sh @@ -0,0 +1,87 @@ +#!/bin/bash +set -e + +# Change to the backend directory +cd "$(dirname "$0")/.." + +# Activate virtual environment if it exists +if [ -d ".venv" ]; then + if [ -f ".venv/bin/activate" ]; then + source .venv/bin/activate + elif [ -f ".venv/Scripts/activate" ]; then + source .venv/Scripts/activate + fi +fi + +# Check if Python is available +if ! command -v python &> /dev/null; then + echo "Python is not installed or not in PATH" + exit 1 +fi + +# Install dependencies if not already installed +pip install -e ".[dev]" + +# Check if database is accessible +if ! python -c " +import sys +from sqlalchemy import create_engine +from app.core.config import settings + +try: + engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) + with engine.connect() as conn: + conn.execute('SELECT 1') + print('✅ Database connection successful!') + sys.exit(0) +except Exception as e: + print(f'❌ Database connection failed: {e}') + print('\nPlease ensure that:') + print('1. PostgreSQL is installed and running') + print('2. The database exists and is accessible') + print(f'3. The connection string is correct: {settings.SQLALCHEMY_DATABASE_URI}') + sys.exit(1) +"; then + exit 1 +fi + +# Run migrations +echo "🚀 Running migrations..." +python -m scripts.migrate upgrade head + +# Create initial data if needed +echo "✨ Setting up initial data..." +python -c " +import sys +from app.db.session import SessionLocal +from app.core.config import settings +from app.models import UserInDB +from app.core.security import get_password_hash + +try: + db = SessionLocal() + # Create initial admin user if it doesn't exist + admin = db.query(UserInDB).filter(UserInDB.email == settings.FIRST_SUPERUSER).first() + if not admin and settings.FIRST_SUPERUSER and settings.FIRST_SUPERUSER_PASSWORD: + admin = UserInDB( + email=settings.FIRST_SUPERUSER, + hashed_password=get_password_hash(settings.FIRST_SUPERUSER_PASSWORD), + is_active=True, + is_verified=True, + email_verified=True, + role='superuser', + full_name='Admin User' + ) + db.add(admin) + db.commit() + print('✅ Created initial admin user') + else: + print('ℹ️ Admin user already exists') + + db.close() +except Exception as e: + print(f'❌ Error setting up initial data: {e}') + sys.exit(1) +" + +echo "🎉 Database setup complete!" 
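For reference, the migration helpers in `backend/scripts/migrate.py` can also be driven programmatically rather than through the CLI defined above; a minimal sketch, assuming the backend layout introduced in this diff:

```python
# Equivalent to:
#   python -m scripts.migrate create -m "add refresh_tokens table"
#   python -m scripts.migrate upgrade --revision head
#   python -m scripts.migrate current
from scripts.migrate import create_migration, show_current, upgrade_database

create_migration("add refresh_tokens table")
upgrade_database("head")
show_current()
```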
diff --git a/create_tables.py b/create_tables.py new file mode 100644 index 0000000000..442c9999e1 --- /dev/null +++ b/create_tables.py @@ -0,0 +1,50 @@ +import psycopg2 +from backend.app.core.config import settings + +# Create database connection +POSTGRES_SERVER = settings.POSTGRES_SERVER +POSTGRES_USER = settings.POSTGRES_USER +POSTGRES_PASSWORD = settings.POSTGRES_PASSWORD +POSTGRES_DB = settings.POSTGRES_DB +POSTGRES_PORT = settings.POSTGRES_PORT + +def create_tables(): + print("Creating database tables...") + conn = psycopg2.connect( + host=POSTGRES_SERVER, + user=POSTGRES_USER, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DB, + port=POSTGRES_PORT + ) + conn.autocommit = True + cursor = conn.cursor() + + # Create the uuid-ossp extension if it doesn't exist + cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + + # Create users table + cursor.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + created_at TIMESTAMP WITH TIME ZONE DEFAULT now() NOT NULL, + updated_at TIMESTAMP WITH TIME ZONE DEFAULT now() NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL, + is_active BOOLEAN NOT NULL DEFAULT true, + is_verified BOOLEAN NOT NULL DEFAULT false, + email_verified BOOLEAN NOT NULL DEFAULT false, + full_name VARCHAR(255), + role VARCHAR(20) NOT NULL DEFAULT 'user', + last_login TIMESTAMP WITH TIME ZONE, + sso_provider VARCHAR(20), + sso_id VARCHAR(255), + hashed_password VARCHAR(255) + ); + """) + + print("Users table created successfully!") + cursor.close() + conn.close() + +if __name__ == "__main__": + create_tables() diff --git a/frontend/.env b/frontend/.env deleted file mode 100644 index 27fcbfe8c8..0000000000 --- a/frontend/.env +++ /dev/null @@ -1,2 +0,0 @@ -VITE_API_URL=http://localhost:8000 -MAILCATCHER_HOST=http://localhost:1080 diff --git a/init_db.py b/init_db.py new file mode 100644 index 0000000000..214e0e7a79 --- /dev/null +++ b/init_db.py @@ -0,0 +1,32 @@ +from sqlmodel import SQLModel +from sqlalchemy import create_engine, text +from backend.app.core.config import settings + +# Create database URL +POSTGRES_SERVER = settings.POSTGRES_SERVER +POSTGRES_USER = settings.POSTGRES_USER +POSTGRES_PASSWORD = settings.POSTGRES_PASSWORD +POSTGRES_DB = settings.POSTGRES_DB +POSTGRES_PORT = settings.POSTGRES_PORT + +# Use psycopg2 driver explicitly +SQLALCHEMY_DATABASE_URL = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_SERVER}:{POSTGRES_PORT}/{POSTGRES_DB}" + +# Import all models to ensure they're registered with SQLModel +from backend.app.models import User, Item, RefreshToken + +def init_db(): + print("Creating database tables...") + engine = create_engine(SQLALCHEMY_DATABASE_URL, echo=True) + + # First create a connection to enable the uuid-ossp extension + with engine.connect() as conn: + conn.execute(text("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"")) + conn.commit() + + # Then create the tables + SQLModel.metadata.create_all(engine) + print("Database tables created successfully!") + +if __name__ == "__main__": + init_db()
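After running `init_db.py` or `create_tables.py`, one quick way to confirm the schema landed is to inspect the database with SQLAlchemy; a minimal sketch, reusing the same settings those scripts read:

```python
from sqlalchemy import create_engine, inspect

from backend.app.core.config import settings

url = (
    f"postgresql+psycopg2://{settings.POSTGRES_USER}:{settings.POSTGRES_PASSWORD}"
    f"@{settings.POSTGRES_SERVER}:{settings.POSTGRES_PORT}/{settings.POSTGRES_DB}"
)
engine = create_engine(url)
print(inspect(engine).get_table_names())  # should list the tables created above
```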