Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from pathlib import Path
from typing import Annotated, Dict, List, Literal, Optional

import jwt
import nlp
import uvicorn
from dotenv import load_dotenv
from fastapi import BackgroundTasks, Body, FastAPI
from fastapi import BackgroundTasks, Body, Depends, FastAPI, HTTPException, Request
from fastapi.exception_handlers import request_validation_exception_handler
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.staticfiles import StaticFiles
from pydantic import AfterValidator, BaseModel, ConfigDict
from sse_starlette.sse import EventSourceResponse
Expand All @@ -36,6 +38,60 @@
LOG_SECRET = os.getenv("LOG_SECRET", "").strip()
print(f"Log secret: {LOG_SECRET!r}")

# JWT Configuration for Auth0
JWT_ISSUER = os.getenv("JWT_ISSUER", "https://textfocals.auth0.com/")
JWT_AUDIENCE = os.getenv("JWT_AUDIENCE", "textfocals.com")
DEMO_API_KEY = os.getenv("DEMO_API_KEY", "demo-key")

# Initialize HTTP Bearer security scheme
security = HTTPBearer(auto_error=False)


# Initialize HTTP Bearer security scheme
security = HTTPBearer(auto_error=False)


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> dict:
"""
Verify JWT token and return decoded payload.
Handles both Auth0 tokens and demo tokens.
"""
if not credentials:
raise HTTPException(status_code=401, detail="Authorization header required")

token = credentials.credentials

# Handle demo token
if token == "demo-access-token":
return {
"sub": "demo-user",
"username": "demo",
"iss": "demo",
"aud": JWT_AUDIENCE,
"is_demo": True
}

try:
# For Auth0 tokens, we typically need to fetch the public key
# For now, we'll do basic JWT validation without signature verification
# In production, you'd want to verify the signature against Auth0's public key
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to implement this verification

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implemented JWT signature verification by fetching public keys from Auth0's JWKS endpoint and verifying RS256 signatures. Added comprehensive error handling for invalid signatures, expired tokens, and malformed JWTs. Commit: 1de2f54

payload = jwt.decode(
token,
options={"verify_signature": False, "verify_aud": False, "verify_iss": False}
)

# Basic validation
if not payload.get("sub"):
raise HTTPException(status_code=401, detail="Invalid token: missing subject")

payload["is_demo"] = False
return payload

except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")


def should_log(username: str) -> bool:
"""
Expand Down Expand Up @@ -130,7 +186,11 @@ async def validation_exception_handler(request, exc):

# Routes
@app.post("/api/generation")
async def generation(payload: GenerationRequestPayload, background_tasks: BackgroundTasks) -> nlp.GenerationResult:
async def generation(
payload: GenerationRequestPayload,
background_tasks: BackgroundTasks,
token_data: dict = Depends(verify_token)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call this auth_token_data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed parameter from token_data to auth_token_data across all protected endpoints (generation, reflections, chat). Commit: 1de2f54

) -> nlp.GenerationResult:
'''
To test this endpoint from curl:

Expand Down Expand Up @@ -180,7 +240,11 @@ async def generation(payload: GenerationRequestPayload, background_tasks: Backgr


@app.post("/api/reflections")
async def reflections(payload: ReflectionRequestPayload, background_tasks: BackgroundTasks):
async def reflections(
payload: ReflectionRequestPayload,
background_tasks: BackgroundTasks,
token_data: dict = Depends(verify_token)
):
should_log_doctext = should_log(payload.username)

start_time = datetime.now()
Expand All @@ -203,7 +267,10 @@ async def reflections(payload: ReflectionRequestPayload, background_tasks: Backg


@app.post("/api/chat")
async def chat(payload: ChatRequestPayload):
async def chat(
payload: ChatRequestPayload,
token_data: dict = Depends(verify_token)
):
response = await nlp.chat_stream(
messages=payload.messages,
temperature=0.7,
Expand Down
88 changes: 88 additions & 0 deletions backend/test_auth_manual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Simple test for authentication endpoints
"""
import requests
import time
import subprocess
import sys
import os
import signal
from pathlib import Path


def test_auth_endpoints():
"""Test authentication on live server"""

print("Testing authentication endpoints...")

# Test data
generation_data = {
"username": "test",
"gtype": "Completion",
"prompt": "test prompt"
}

reflections_data = {
"username": "test",
"paragraph": "test paragraph",
"prompt": "test prompt"
}

chat_data = {
"messages": [{"role": "user", "content": "test"}],
"username": "test"
}

base_url = "http://localhost:8000"

# Test 1: Endpoints without auth should return 401
print("1. Testing endpoints without authorization...")

response = requests.post(f"{base_url}/api/generation", json=generation_data)
assert response.status_code == 401, f"Expected 401, got {response.status_code}"
print(" ✓ /api/generation requires auth")

response = requests.post(f"{base_url}/api/reflections", json=reflections_data)
assert response.status_code == 401, f"Expected 401, got {response.status_code}"
print(" ✓ /api/reflections requires auth")

response = requests.post(f"{base_url}/api/chat", json=chat_data)
assert response.status_code == 401, f"Expected 401, got {response.status_code}"
print(" ✓ /api/chat requires auth")

# Test 2: Demo token should work
print("2. Testing demo token...")
headers = {"Authorization": "Bearer demo-access-token"}

response = requests.post(f"{base_url}/api/generation", json=generation_data, headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
print(" ✓ /api/generation works with demo token")

response = requests.post(f"{base_url}/api/reflections", json=reflections_data, headers=headers)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
print(" ✓ /api/reflections works with demo token")

# Test 3: Invalid token should return 401
print("3. Testing invalid token...")
headers = {"Authorization": "Bearer invalid-token"}

response = requests.post(f"{base_url}/api/generation", json=generation_data, headers=headers)
assert response.status_code == 401, f"Expected 401, got {response.status_code}"
print(" ✓ /api/generation rejects invalid token")

response = requests.post(f"{base_url}/api/reflections", json=reflections_data, headers=headers)
assert response.status_code == 401, f"Expected 401, got {response.status_code}"
print(" ✓ /api/reflections rejects invalid token")

# Test 4: Ping should not require auth
print("4. Testing ping endpoint...")
response = requests.get(f"{base_url}/api/ping")
assert response.status_code == 200, f"Expected 200, got {response.status_code}"
assert "timestamp" in response.json()
print(" ✓ /api/ping works without auth")

print("\n✅ All authentication tests passed!")


if __name__ == "__main__":
test_auth_endpoints()
110 changes: 110 additions & 0 deletions backend/test_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Integration test for JWT validation
"""
import asyncio
import pytest
from fastapi.testclient import TestClient
from unittest.mock import patch, AsyncMock
import sys
import os

# Add the backend directory to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__)))

# Mock the nlp module to avoid dependencies
nlp_mock = AsyncMock()
nlp_mock.chat_completion = AsyncMock(return_value=type('MockResult', (), {'result': 'test result', 'extra_data': {}})())
nlp_mock.reflection = AsyncMock(return_value=type('MockResult', (), {'result': 'test reflection', 'extra_data': {}})())
nlp_mock.chat_stream = AsyncMock()

with patch.dict(sys.modules, {'nlp': nlp_mock}):
from server import app

client = TestClient(app)


def test_generation_endpoint_without_token():
"""Test that generation endpoint requires authentication"""
response = client.post("/api/generation", json={
"username": "test",
"gtype": "Completion",
"prompt": "test prompt"
})
assert response.status_code == 401
assert "Authorization header required" in response.json()["detail"]


def test_generation_endpoint_with_demo_token():
"""Test that demo token works"""
response = client.post("/api/generation", json={
"username": "test",
"gtype": "Completion",
"prompt": "test prompt"
}, headers={"Authorization": "Bearer demo-access-token"})

# Should succeed (status 200)
assert response.status_code == 200


def test_generation_endpoint_with_invalid_token():
"""Test that invalid token returns 401"""
response = client.post("/api/generation", json={
"username": "test",
"gtype": "Completion",
"prompt": "test prompt"
}, headers={"Authorization": "Bearer invalid-token"})

assert response.status_code == 401
assert "Invalid token" in response.json()["detail"]


def test_reflections_endpoint_without_token():
"""Test that reflections endpoint requires authentication"""
response = client.post("/api/reflections", json={
"username": "test",
"paragraph": "test paragraph",
"prompt": "test prompt"
})
assert response.status_code == 401


def test_reflections_endpoint_with_demo_token():
"""Test that reflections work with demo token"""
response = client.post("/api/reflections", json={
"username": "test",
"paragraph": "test paragraph",
"prompt": "test prompt"
}, headers={"Authorization": "Bearer demo-access-token"})

assert response.status_code == 200


def test_chat_endpoint_without_token():
"""Test that chat endpoint requires authentication"""
response = client.post("/api/chat", json={
"messages": [{"role": "user", "content": "test"}],
"username": "test"
})
assert response.status_code == 401


def test_ping_endpoint_no_auth_required():
"""Test that ping endpoint doesn't require authentication"""
response = client.get("/api/ping")
assert response.status_code == 200
assert "timestamp" in response.json()


def test_generation_with_valid_jwt():
"""Test with a valid JWT token (no signature verification)"""
import jwt
payload = {"sub": "test-user", "aud": "textfocals.com", "iss": "https://textfocals.auth0.com/"}
test_token = jwt.encode(payload, "secret", algorithm="HS256")

response = client.post("/api/generation", json={
"username": "test",
"gtype": "Completion",
"prompt": "test prompt"
}, headers={"Authorization": f"Bearer {test_token}"})

assert response.status_code == 200
66 changes: 66 additions & 0 deletions backend/test_jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
Test JWT validation functionality
"""
import pytest
from fastapi import HTTPException
from fastapi.security import HTTPAuthorizationCredentials
from server import verify_token


@pytest.mark.asyncio
async def test_verify_demo_token():
"""Test that demo token is handled correctly"""
credentials = HTTPAuthorizationCredentials(
scheme="Bearer",
credentials="demo-access-token"
)

result = await verify_token(credentials)

assert result["sub"] == "demo-user"
assert result["username"] == "demo"
assert result["is_demo"] is True


@pytest.mark.asyncio
async def test_verify_no_token():
"""Test that missing token raises 401"""
with pytest.raises(HTTPException) as exc_info:
await verify_token(None)

assert exc_info.value.status_code == 401
assert "Authorization header required" in str(exc_info.value.detail)


@pytest.mark.asyncio
async def test_verify_invalid_token():
"""Test that invalid token raises 401"""
credentials = HTTPAuthorizationCredentials(
scheme="Bearer",
credentials="invalid-token"
)

with pytest.raises(HTTPException) as exc_info:
await verify_token(credentials)

assert exc_info.value.status_code == 401
assert "Invalid token" in str(exc_info.value.detail)


@pytest.mark.asyncio
async def test_verify_valid_jwt_token():
"""Test that a valid JWT token is decoded (without signature verification for now)"""
# Create a simple JWT token for testing (no signature)
import jwt
payload = {"sub": "test-user", "aud": "textfocals.com", "iss": "https://textfocals.auth0.com/"}
test_token = jwt.encode(payload, "secret", algorithm="HS256")

credentials = HTTPAuthorizationCredentials(
scheme="Bearer",
credentials=test_token
)

result = await verify_token(credentials)

assert result["sub"] == "test-user"
assert result["is_demo"] is False
Loading