-
Notifications
You must be signed in to change notification settings - Fork 0
Implement JWT token validation on backend and auth error handling on frontend #232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
b26a4a0
1e9751b
d21e777
72c134c
1de2f54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,14 +7,16 @@ | |
| from pathlib import Path | ||
| from typing import Annotated, Dict, List, Literal, Optional | ||
|
|
||
| import jwt | ||
| import nlp | ||
| import uvicorn | ||
| from dotenv import load_dotenv | ||
| from fastapi import BackgroundTasks, Body, FastAPI | ||
| from fastapi import BackgroundTasks, Body, Depends, FastAPI, HTTPException, Request | ||
| from fastapi.exception_handlers import request_validation_exception_handler | ||
| from fastapi.exceptions import RequestValidationError | ||
| from fastapi.middleware.cors import CORSMiddleware | ||
| from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | ||
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | ||
| from fastapi.staticfiles import StaticFiles | ||
| from pydantic import AfterValidator, BaseModel, ConfigDict | ||
| from sse_starlette.sse import EventSourceResponse | ||
|
|
@@ -36,6 +38,60 @@ | |
| LOG_SECRET = os.getenv("LOG_SECRET", "").strip() | ||
| print(f"Log secret: {LOG_SECRET!r}") | ||
|
|
||
| # JWT Configuration for Auth0 | ||
| JWT_ISSUER = os.getenv("JWT_ISSUER", "https://textfocals.auth0.com/") | ||
| JWT_AUDIENCE = os.getenv("JWT_AUDIENCE", "textfocals.com") | ||
| DEMO_API_KEY = os.getenv("DEMO_API_KEY", "demo-key") | ||
|
|
||
| # Initialize HTTP Bearer security scheme | ||
| security = HTTPBearer(auto_error=False) | ||
|
|
||
|
|
||
| # Initialize HTTP Bearer security scheme | ||
| security = HTTPBearer(auto_error=False) | ||
|
|
||
|
|
||
| async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict: | ||
| """ | ||
| Verify JWT token and return decoded payload. | ||
| Handles both Auth0 tokens and demo tokens. | ||
| """ | ||
| if not credentials: | ||
| raise HTTPException(status_code=401, detail="Authorization header required") | ||
|
|
||
| token = credentials.credentials | ||
|
|
||
| # Handle demo token | ||
| if token == "demo-access-token": | ||
| return { | ||
| "sub": "demo-user", | ||
| "username": "demo", | ||
| "iss": "demo", | ||
| "aud": JWT_AUDIENCE, | ||
| "is_demo": True | ||
| } | ||
|
|
||
| try: | ||
| # For Auth0 tokens, we typically need to fetch the public key | ||
| # For now, we'll do basic JWT validation without signature verification | ||
| # In production, you'd want to verify the signature against Auth0's public key | ||
| payload = jwt.decode( | ||
| token, | ||
| options={"verify_signature": False, "verify_aud": False, "verify_iss": False} | ||
| ) | ||
|
|
||
| # Basic validation | ||
| if not payload.get("sub"): | ||
| raise HTTPException(status_code=401, detail="Invalid token: missing subject") | ||
|
|
||
| payload["is_demo"] = False | ||
| return payload | ||
|
|
||
| except jwt.ExpiredSignatureError: | ||
| raise HTTPException(status_code=401, detail="Token expired") | ||
| except jwt.InvalidTokenError: | ||
| raise HTTPException(status_code=401, detail="Invalid token") | ||
|
|
||
|
|
||
| def should_log(username: str) -> bool: | ||
| """ | ||
|
|
@@ -130,7 +186,11 @@ async def validation_exception_handler(request, exc): | |
|
|
||
| # Routes | ||
| @app.post("/api/generation") | ||
| async def generation(payload: GenerationRequestPayload, background_tasks: BackgroundTasks) -> nlp.GenerationResult: | ||
| async def generation( | ||
| payload: GenerationRequestPayload, | ||
| background_tasks: BackgroundTasks, | ||
| token_data: dict = Depends(verify_token) | ||
|
||
| ) -> nlp.GenerationResult: | ||
| ''' | ||
| To test this endpoint from curl: | ||
|
|
||
|
|
@@ -180,7 +240,11 @@ async def generation(payload: GenerationRequestPayload, background_tasks: Backgr | |
|
|
||
|
|
||
| @app.post("/api/reflections") | ||
| async def reflections(payload: ReflectionRequestPayload, background_tasks: BackgroundTasks): | ||
| async def reflections( | ||
| payload: ReflectionRequestPayload, | ||
| background_tasks: BackgroundTasks, | ||
| token_data: dict = Depends(verify_token) | ||
| ): | ||
| should_log_doctext = should_log(payload.username) | ||
|
|
||
| start_time = datetime.now() | ||
|
|
@@ -203,7 +267,10 @@ async def reflections(payload: ReflectionRequestPayload, background_tasks: Backg | |
|
|
||
|
|
||
| @app.post("/api/chat") | ||
| async def chat(payload: ChatRequestPayload): | ||
| async def chat( | ||
| payload: ChatRequestPayload, | ||
| token_data: dict = Depends(verify_token) | ||
| ): | ||
| response = await nlp.chat_stream( | ||
| messages=payload.messages, | ||
| temperature=0.7, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| """ | ||
| Simple test for authentication endpoints | ||
| """ | ||
| import requests | ||
| import time | ||
| import subprocess | ||
| import sys | ||
| import os | ||
| import signal | ||
| from pathlib import Path | ||
|
|
||
|
|
||
| def test_auth_endpoints(): | ||
| """Test authentication on live server""" | ||
|
|
||
| print("Testing authentication endpoints...") | ||
|
|
||
| # Test data | ||
| generation_data = { | ||
| "username": "test", | ||
| "gtype": "Completion", | ||
| "prompt": "test prompt" | ||
| } | ||
|
|
||
| reflections_data = { | ||
| "username": "test", | ||
| "paragraph": "test paragraph", | ||
| "prompt": "test prompt" | ||
| } | ||
|
|
||
| chat_data = { | ||
| "messages": [{"role": "user", "content": "test"}], | ||
| "username": "test" | ||
| } | ||
|
|
||
| base_url = "http://localhost:8000" | ||
|
|
||
| # Test 1: Endpoints without auth should return 401 | ||
| print("1. Testing endpoints without authorization...") | ||
|
|
||
| response = requests.post(f"{base_url}/api/generation", json=generation_data) | ||
| assert response.status_code == 401, f"Expected 401, got {response.status_code}" | ||
| print(" ✓ /api/generation requires auth") | ||
|
|
||
| response = requests.post(f"{base_url}/api/reflections", json=reflections_data) | ||
| assert response.status_code == 401, f"Expected 401, got {response.status_code}" | ||
| print(" ✓ /api/reflections requires auth") | ||
|
|
||
| response = requests.post(f"{base_url}/api/chat", json=chat_data) | ||
| assert response.status_code == 401, f"Expected 401, got {response.status_code}" | ||
| print(" ✓ /api/chat requires auth") | ||
|
|
||
| # Test 2: Demo token should work | ||
| print("2. Testing demo token...") | ||
| headers = {"Authorization": "Bearer demo-access-token"} | ||
|
|
||
| response = requests.post(f"{base_url}/api/generation", json=generation_data, headers=headers) | ||
| assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}" | ||
| print(" ✓ /api/generation works with demo token") | ||
|
|
||
| response = requests.post(f"{base_url}/api/reflections", json=reflections_data, headers=headers) | ||
| assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}" | ||
| print(" ✓ /api/reflections works with demo token") | ||
|
|
||
| # Test 3: Invalid token should return 401 | ||
| print("3. Testing invalid token...") | ||
| headers = {"Authorization": "Bearer invalid-token"} | ||
|
|
||
| response = requests.post(f"{base_url}/api/generation", json=generation_data, headers=headers) | ||
| assert response.status_code == 401, f"Expected 401, got {response.status_code}" | ||
| print(" ✓ /api/generation rejects invalid token") | ||
|
|
||
| response = requests.post(f"{base_url}/api/reflections", json=reflections_data, headers=headers) | ||
| assert response.status_code == 401, f"Expected 401, got {response.status_code}" | ||
| print(" ✓ /api/reflections rejects invalid token") | ||
|
|
||
| # Test 4: Ping should not require auth | ||
| print("4. Testing ping endpoint...") | ||
| response = requests.get(f"{base_url}/api/ping") | ||
| assert response.status_code == 200, f"Expected 200, got {response.status_code}" | ||
| assert "timestamp" in response.json() | ||
| print(" ✓ /api/ping works without auth") | ||
|
|
||
| print("\n✅ All authentication tests passed!") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_auth_endpoints() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| """ | ||
| Integration test for JWT validation | ||
| """ | ||
| import asyncio | ||
| import pytest | ||
| from fastapi.testclient import TestClient | ||
| from unittest.mock import patch, AsyncMock | ||
| import sys | ||
| import os | ||
|
|
||
| # Add the backend directory to the path | ||
| sys.path.insert(0, os.path.join(os.path.dirname(__file__))) | ||
|
|
||
| # Mock the nlp module to avoid dependencies | ||
| nlp_mock = AsyncMock() | ||
| nlp_mock.chat_completion = AsyncMock(return_value=type('MockResult', (), {'result': 'test result', 'extra_data': {}})()) | ||
| nlp_mock.reflection = AsyncMock(return_value=type('MockResult', (), {'result': 'test reflection', 'extra_data': {}})()) | ||
| nlp_mock.chat_stream = AsyncMock() | ||
|
|
||
| with patch.dict(sys.modules, {'nlp': nlp_mock}): | ||
| from server import app | ||
|
|
||
| client = TestClient(app) | ||
|
|
||
|
|
||
| def test_generation_endpoint_without_token(): | ||
| """Test that generation endpoint requires authentication""" | ||
| response = client.post("/api/generation", json={ | ||
| "username": "test", | ||
| "gtype": "Completion", | ||
| "prompt": "test prompt" | ||
| }) | ||
| assert response.status_code == 401 | ||
| assert "Authorization header required" in response.json()["detail"] | ||
|
|
||
|
|
||
| def test_generation_endpoint_with_demo_token(): | ||
| """Test that demo token works""" | ||
| response = client.post("/api/generation", json={ | ||
| "username": "test", | ||
| "gtype": "Completion", | ||
| "prompt": "test prompt" | ||
| }, headers={"Authorization": "Bearer demo-access-token"}) | ||
|
|
||
| # Should succeed (status 200) | ||
| assert response.status_code == 200 | ||
|
|
||
|
|
||
| def test_generation_endpoint_with_invalid_token(): | ||
| """Test that invalid token returns 401""" | ||
| response = client.post("/api/generation", json={ | ||
| "username": "test", | ||
| "gtype": "Completion", | ||
| "prompt": "test prompt" | ||
| }, headers={"Authorization": "Bearer invalid-token"}) | ||
|
|
||
| assert response.status_code == 401 | ||
| assert "Invalid token" in response.json()["detail"] | ||
|
|
||
|
|
||
| def test_reflections_endpoint_without_token(): | ||
| """Test that reflections endpoint requires authentication""" | ||
| response = client.post("/api/reflections", json={ | ||
| "username": "test", | ||
| "paragraph": "test paragraph", | ||
| "prompt": "test prompt" | ||
| }) | ||
| assert response.status_code == 401 | ||
|
|
||
|
|
||
| def test_reflections_endpoint_with_demo_token(): | ||
| """Test that reflections work with demo token""" | ||
| response = client.post("/api/reflections", json={ | ||
| "username": "test", | ||
| "paragraph": "test paragraph", | ||
| "prompt": "test prompt" | ||
| }, headers={"Authorization": "Bearer demo-access-token"}) | ||
|
|
||
| assert response.status_code == 200 | ||
|
|
||
|
|
||
| def test_chat_endpoint_without_token(): | ||
| """Test that chat endpoint requires authentication""" | ||
| response = client.post("/api/chat", json={ | ||
| "messages": [{"role": "user", "content": "test"}], | ||
| "username": "test" | ||
| }) | ||
| assert response.status_code == 401 | ||
|
|
||
|
|
||
| def test_ping_endpoint_no_auth_required(): | ||
| """Test that ping endpoint doesn't require authentication""" | ||
| response = client.get("/api/ping") | ||
| assert response.status_code == 200 | ||
| assert "timestamp" in response.json() | ||
|
|
||
|
|
||
| def test_generation_with_valid_jwt(): | ||
| """Test with a valid JWT token (no signature verification)""" | ||
| import jwt | ||
| payload = {"sub": "test-user", "aud": "textfocals.com", "iss": "https://textfocals.auth0.com/"} | ||
| test_token = jwt.encode(payload, "secret", algorithm="HS256") | ||
|
|
||
| response = client.post("/api/generation", json={ | ||
| "username": "test", | ||
| "gtype": "Completion", | ||
| "prompt": "test prompt" | ||
| }, headers={"Authorization": f"Bearer {test_token}"}) | ||
|
|
||
| assert response.status_code == 200 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| """ | ||
| Test JWT validation functionality | ||
| """ | ||
| import pytest | ||
| from fastapi import HTTPException | ||
| from fastapi.security import HTTPAuthorizationCredentials | ||
| from server import verify_token | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_verify_demo_token(): | ||
| """Test that demo token is handled correctly""" | ||
| credentials = HTTPAuthorizationCredentials( | ||
| scheme="Bearer", | ||
| credentials="demo-access-token" | ||
| ) | ||
|
|
||
| result = await verify_token(credentials) | ||
|
|
||
| assert result["sub"] == "demo-user" | ||
| assert result["username"] == "demo" | ||
| assert result["is_demo"] is True | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_verify_no_token(): | ||
| """Test that missing token raises 401""" | ||
| with pytest.raises(HTTPException) as exc_info: | ||
| await verify_token(None) | ||
|
|
||
| assert exc_info.value.status_code == 401 | ||
| assert "Authorization header required" in str(exc_info.value.detail) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_verify_invalid_token(): | ||
| """Test that invalid token raises 401""" | ||
| credentials = HTTPAuthorizationCredentials( | ||
| scheme="Bearer", | ||
| credentials="invalid-token" | ||
| ) | ||
|
|
||
| with pytest.raises(HTTPException) as exc_info: | ||
| await verify_token(credentials) | ||
|
|
||
| assert exc_info.value.status_code == 401 | ||
| assert "Invalid token" in str(exc_info.value.detail) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_verify_valid_jwt_token(): | ||
| """Test that a valid JWT token is decoded (without signature verification for now)""" | ||
| # Create a simple JWT token for testing (no signature) | ||
| import jwt | ||
| payload = {"sub": "test-user", "aud": "textfocals.com", "iss": "https://textfocals.auth0.com/"} | ||
| test_token = jwt.encode(payload, "secret", algorithm="HS256") | ||
|
|
||
| credentials = HTTPAuthorizationCredentials( | ||
| scheme="Bearer", | ||
| credentials=test_token | ||
| ) | ||
|
|
||
| result = await verify_token(credentials) | ||
|
|
||
| assert result["sub"] == "test-user" | ||
| assert result["is_demo"] is False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to implement this verification
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Implemented JWT signature verification by fetching public keys from Auth0's JWKS endpoint and verifying RS256 signatures. Added comprehensive error handling for invalid signatures, expired tokens, and malformed JWTs. Commit: 1de2f54