Skip to content

Commit 11f9fae

Browse files
committed
Authenticate users and use userids to embed documents
1 parent c95f938 commit 11f9fae

File tree

13 files changed

+491
-27
lines changed

13 files changed

+491
-27
lines changed

genai/requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,7 @@ python-multipart==0.0.20
3131
# Testing
3232
pytest==8.4.1
3333
fpdf==1.7.2
34-
pypdf==5.6.0
34+
pypdf==5.6.0
35+
36+
# Authentication
37+
python-jose[cryptography]==3.3.0

genai/routes/routes.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastapi import APIRouter, UploadFile, File, HTTPException, Request
1+
from fastapi import APIRouter, UploadFile, File, HTTPException, Request, Depends
22
from fastapi.responses import JSONResponse
33
import os
44

@@ -15,6 +15,7 @@
1515
prepare_prompt,
1616
process_raw_messages
1717
)
18+
from service.auth_service import get_current_user, UserInfo
1819
from metrics import (
1920
file_upload_request_counter,
2021
file_upload_successfully_counter,
@@ -68,12 +69,16 @@
6869

6970

7071
@router.post("/upload")
71-
async def upload_file(file: UploadFile = File(...)):
72+
async def upload_file(
73+
file: UploadFile = File(...),
74+
current_user: UserInfo = Depends(get_current_user)
75+
):
7276
file_upload_request_counter.inc()
7377
start_time = perf_counter()
7478
logger.info(
75-
"Upload endpoint is called in genai for the file %s",
76-
file.filename
79+
"Upload endpoint is called in genai for the file %s by user %s",
80+
file.filename,
81+
current_user.username
7782
)
7883

7984
if not file.filename.endswith(".pdf"):
@@ -89,7 +94,7 @@ async def upload_file(file: UploadFile = File(...)):
8994
with open(file_path, "wb") as buffer:
9095
buffer.write(await file.read())
9196

92-
collection_name = "recipes"
97+
collection_name = f"recipes_{current_user.user_id}"
9398
if (
9499
qdrant.client.collection_exists(collection_name)
95100
and qdrant.collection_contains_file(
@@ -98,7 +103,7 @@ async def upload_file(file: UploadFile = File(...)):
98103
filename
99104
)
100105
):
101-
logger.info("File already exists in qdrant")
106+
logger.info("File already exists in qdrant for user %s", current_user.username)
102107
return {"message": f"File '{filename}' already uploaded."}
103108

104109
vector_store = qdrant.create_and_get_vector_storage(collection_name)
@@ -130,16 +135,19 @@ async def generate(request: Request):
130135
logger.info("Generate endpoint is called in genai")
131136

132137
body = await request.json()
133-
if "query" not in body or "messages" not in body:
134-
logger.error("Missing 'query' or 'messages' in the request body")
138+
if "query" not in body or "messages" not in body or "user_id" not in body:
139+
logger.error("Missing 'query', 'messages', or 'user_id' in the request body")
135140
raise HTTPException(
136141
status_code=400,
137-
detail="Missing 'query' or 'messages'"
142+
detail="Missing 'query', 'messages', or 'user_id'"
138143
)
139144

140145
query = body["query"]
141146
messages_raw = body["messages"]
142-
collection_name = "recipes"
147+
user_id = body["user_id"]
148+
collection_name = f"recipes_{user_id}"
149+
150+
logger.info("Generate endpoint called for user_id: %s", user_id)
143151

144152
try:
145153
retrieved_docs = ""
@@ -148,8 +156,9 @@ async def generate(request: Request):
148156
collection_name
149157
)
150158
logger.info(
151-
"Vector store is created for the collection %s",
152-
collection_name
159+
"Vector store is created for the collection %s for user_id %s",
160+
collection_name,
161+
user_id
153162
)
154163
retrieved_docs = retrieve_similar_docs(vector_store, query)
155164
logger.info("Similar docs retrieved from the vector store")

genai/service/auth_service.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import requests
2+
from fastapi import HTTPException, Header
3+
from typing import Optional
4+
from logger import logger
5+
6+
7+
class UserInfo:
    """Identity of an authenticated caller.

    Holds the two values this module extracts from the user service
    response: the numeric ``user_id`` and the ``username``.
    """

    def __init__(self, user_id: int, username: str):
        # Values are stored as given; no validation or coercion happens here.
        self.user_id = user_id
        self.username = username

    def __repr__(self) -> str:
        # Readable form for log lines and test failure output instead of the
        # default "<UserInfo object at 0x...>".
        return (
            f"{type(self).__name__}("
            f"user_id={self.user_id!r}, username={self.username!r})"
        )
11+
12+
13+
async def get_current_user(authorization: Optional[str] = Header(None)) -> UserInfo:
    """
    Resolve the calling user from the ``Authorization`` request header.

    The bearer token is forwarded to the user service's ``/user/info``
    endpoint. Candidate URLs are tried in order; the first one that answers
    with HTTP 200 and a JSON object containing an ``id`` determines the
    returned user.

    Args:
        authorization: Raw ``Authorization`` header value, expected in the
            form ``Bearer <token>``.

    Returns:
        A ``UserInfo`` built from the ``id`` and ``username`` fields of the
        user service response.

    Raises:
        HTTPException: 401 when the header is missing, does not start with
            ``Bearer ``, carries no token, or the user service answers 401
            for the token; 500 when no candidate URL produces a usable
            response.
    """
    # Function-scope import so the block does not depend on a new
    # module-level import being added to the file.
    import asyncio

    if not authorization:
        logger.error("Authorization header is missing")
        raise HTTPException(status_code=401, detail="Authorization header required")

    if not authorization.startswith("Bearer "):
        logger.error("Invalid authorization header format")
        raise HTTPException(status_code=401, detail="Invalid authorization header format")

    # Whitespace-splitting tolerates repeated spaces and lets us reject a
    # header that is just "Bearer " instead of forwarding an empty token.
    parts = authorization.split()
    if len(parts) < 2:
        logger.error("Invalid authorization header format")
        raise HTTPException(status_code=401, detail="Invalid authorization header format")
    token = parts[1]

    user_service_urls = [
        "http://localhost:8081/user/info",  # dev
        "http://user-service:8081/user/info"  # prod
    ]

    for url in user_service_urls:
        try:
            # requests is synchronous; run it in a worker thread so a slow or
            # unreachable user service does not block the event loop.
            response = await asyncio.to_thread(
                requests.get,
                url,
                headers={"Authorization": f"Bearer {token}"},
                timeout=10
            )
        except requests.exceptions.RequestException as e:
            logger.warning("Failed to reach %s: %s", url, e)
            continue

        if response.status_code == 401:
            # The service was reachable and rejected the token: this is a
            # client credential problem, not an outage.
            logger.error("%s returned status %s", url, response.status_code)
            raise HTTPException(status_code=401, detail="Invalid or expired token")

        if response.status_code != 200:
            logger.error("%s returned status %s", url, response.status_code)
            continue

        try:
            user_data = response.json()
        except ValueError as e:
            # Covers json.JSONDecodeError and requests' own JSONDecodeError.
            logger.error("%s returned a non-JSON body: %s", url, e)
            continue

        # Without an id the caller would end up with user_id=None, which
        # downstream code would treat as a real (shared) user identifier.
        if not isinstance(user_data, dict) or user_data.get("id") is None:
            logger.error("%s returned a response without a user id", url)
            continue

        logger.info("User authenticated: %s", user_data.get("username"))
        return UserInfo(
            user_id=user_data["id"],
            username=user_data.get("username")
        )

    raise HTTPException(status_code=500, detail="Authentication service unavailable")

genai/tests/integration/generation_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def test_generate_endpoint_success(_mock_exists, mock_invoke):
1717
"messages": [
1818
{"role": "USER", "content": "I have rice and eggs"},
1919
{"role": "ASSISTANT", "content": "How about fried rice?"}
20-
]
20+
],
21+
"user_id": 123
2122
}
2223

2324
response = client.post("/genai/generate", json=payload)
@@ -39,7 +40,8 @@ def test_generate_endpoint_empty_messages(_mock_exists, mock_invoke):
3940

4041
payload = {
4142
"query": "Can I cook with lentils?",
42-
"messages": []
43+
"messages": [],
44+
"user_id": 123
4345
}
4446

4547
response = client.post("/genai/generate", json=payload)
@@ -59,4 +61,4 @@ def test_generate_endpoint_missing_fields():
5961
response = client.post("/genai/generate", json=payload)
6062

6163
assert response.status_code == 400
62-
assert response.json() == {"detail": "Missing 'query' or 'messages'"}
64+
assert response.json() == {"detail": "Missing 'query', 'messages', or 'user_id'"}

genai/tests/integration/upload_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import patch, MagicMock
33
from fastapi.testclient import TestClient
44
from main import app
5+
from genai.service.auth_service import UserInfo, get_current_user
56

67
client = TestClient(app)
78

@@ -14,6 +15,10 @@ def test_upload_file_success(
1415
_mock_vector_store,
1516
_mock_exists
1617
):
18+
# Mock the authentication by overriding the dependency
19+
mock_user = UserInfo(user_id=123, username="test_user")
20+
app.dependency_overrides[get_current_user] = lambda: mock_user
21+
1722
mock_pipeline = MagicMock()
1823
mock_pipeline_class.return_value = mock_pipeline
1924

@@ -31,9 +36,16 @@ def test_upload_file_success(
3136

3237
mock_pipeline_class.assert_called_once()
3338
mock_pipeline.ingest.assert_called_once()
39+
40+
# Clean up the dependency override
41+
app.dependency_overrides.clear()
3442

3543

3644
def test_upload_file_invalid_type():
45+
# Mock the authentication by overriding the dependency
46+
mock_user = UserInfo(user_id=123, username="test_user")
47+
app.dependency_overrides[get_current_user] = lambda: mock_user
48+
3749
file = io.BytesIO(b"just some text")
3850
file.name = "notes.txt"
3951

@@ -45,11 +57,18 @@ def test_upload_file_invalid_type():
4557
assert response.status_code == 400
4658
assert (response.json()["detail"] ==
4759
"Invalid file type. Only PDF files are allowed.")
60+
61+
# Clean up the dependency override
62+
app.dependency_overrides.clear()
4863

4964

5065
@patch("routes.routes.qdrant.client.collection_exists", return_value=True)
5166
@patch("routes.routes.qdrant.collection_contains_file", return_value=True)
5267
def test_upload_file_already_exists(_mock_contains, _mock_exists):
68+
# Mock the authentication
69+
mock_user = UserInfo(user_id=123, username="test_user")
70+
app.dependency_overrides[get_current_user] = lambda: mock_user
71+
5372
file = io.BytesIO(b"%PDF-1.4")
5473
file.name = "existing.pdf"
5574

@@ -60,3 +79,6 @@ def test_upload_file_already_exists(_mock_contains, _mock_exists):
6079

6180
assert response.status_code == 200
6281
assert "already uploaded" in response.json()["message"]
82+
83+
# Clean up the dependency override
84+
app.dependency_overrides.clear()

recipe_pdfs/basic_recipes.pdf

10.5 KB
Binary file not shown.

recipe_pdfs/basic_recipes2.pdf

10.8 KB
Binary file not shown.

0 commit comments

Comments
 (0)