Skip to content

Commit f41dfd2

Browse files
✨ Add S3 file upload functionality and update document creation process
1 parent 540461c commit f41dfd2

File tree

7 files changed

+1017
-887
lines changed

7 files changed

+1017
-887
lines changed
Lines changed: 31 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,45 @@
1-
import uuid
21
from typing import Any
3-
4-
from fastapi import APIRouter, HTTPException
5-
from sqlmodel import func, select
6-
72
from app.api.deps import CurrentUser, SessionDep
8-
from app.models import Document, DocumentCreate, DocumentPublic, DocumentsPublic, DocumentUpdate, Message
3+
from app.models import Document, DocumentCreate, DocumentPublic
4+
from fastapi import APIRouter, BackgroundTasks, File, UploadFile, HTTPException
5+
from app.s3 import upload_file_to_s3, generate_s3_url
96

107
router = APIRouter(prefix="/documents", tags=["documents"])
11-
12-
13-
@router.get("/", response_model=DocumentsPublic)
14-
def read_documents(
15-
session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100
16-
) -> Any:
17-
"""
18-
Retrieve documents.
19-
"""
20-
21-
if current_user.is_superuser:
22-
count_statement = select(func.count()).select_from(Document)
23-
count = session.exec(count_statement).one()
24-
statement = select(Document).offset(skip).limit(limit)
25-
documents = session.exec(statement).all()
26-
else:
27-
count_statement = (
28-
select(func.count())
29-
.select_from(Document)
30-
.where(Document.owner_id == current_user.id)
31-
)
32-
count = session.exec(count_statement).one()
33-
statement = (
34-
select(Document)
35-
.where(Document.owner_id == current_user.id)
36-
.offset(skip)
37-
.limit(limit)
38-
)
39-
documents = session.exec(statement).all()
40-
41-
return DocumentsPublic(data=documents, count=count)
42-
43-
44-
@router.get("/{id}", response_model=DocumentPublic)
45-
def read_document(session: SessionDep, current_user: CurrentUser, id: uuid.UUID) -> Any:
46-
"""
47-
Get document by ID.
48-
"""
49-
document = session.get(Document, id)
50-
if not document:
51-
raise HTTPException(status_code=404, detail="Document not found")
52-
if not current_user.is_superuser and (document.owner_id != current_user.id):
53-
raise HTTPException(status_code=400, detail="Not enough permissions")
54-
return document
55-
56-
8+
579
@router.post("/", response_model=DocumentPublic)
5810
def create_document(
59-
*, session: SessionDep, current_user: CurrentUser, document_in: DocumentCreate
60-
) -> Any:
61-
"""
62-
Create new document.
63-
"""
11+
*, session: SessionDep, current_user: CurrentUser,
12+
background_tasks: BackgroundTasks,
13+
file: UploadFile = File(...),
14+
):
15+
key = None
16+
try:
17+
user_id = current_user.id
18+
key = upload_file_to_s3(file, current_user.id)
19+
except Exception as e:
20+
raise HTTPException(500, f"Failed to upload file: {key}. Error: {str(e)}")
21+
22+
try:
23+
url = generate_s3_url(key)
24+
except Exception as e:
25+
raise HTTPException(500, f"Could not generate URL for file key: {key}")
26+
27+
document_in = DocumentCreate(
28+
filename=file.filename,
29+
content_type=file.content_type,
30+
size=file.size,
31+
s3_url=url,
32+
)
33+
6434
document = Document.model_validate(document_in, update={"owner_id": current_user.id})
65-
session.add(document)
66-
session.commit()
67-
session.refresh(document)
68-
return document
6935

7036

71-
@router.put("/{id}", response_model=DocumentPublic)
72-
def update_document(
73-
*,
74-
session: SessionDep,
75-
current_user: CurrentUser,
76-
id: uuid.UUID,
77-
document_in: DocumentUpdate,
78-
) -> Any:
79-
"""
80-
Update an document.
81-
"""
82-
document = session.get(Document, id)
83-
if not document:
84-
raise HTTPException(status_code=404, detail="Document not found")
85-
if not current_user.is_superuser and (document.owner_id != current_user.id):
86-
raise HTTPException(status_code=400, detail="Not enough permissions")
87-
update_dict = document_in.model_dump(exclude_unset=True)
88-
document.sqlmodel_update(update_dict)
8937
session.add(document)
9038
session.commit()
9139
session.refresh(document)
92-
return document
9340

41+
# 3. Kick off background job
42+
print("Document created, starting background task...")
43+
# background_tasks.add_task(generate_questions, document.id)
9444

95-
@router.delete("/{id}")
96-
def delete_document(
97-
session: SessionDep, current_user: CurrentUser, id: uuid.UUID
98-
) -> Message:
99-
"""
100-
Delete an document.
101-
"""
102-
document = session.get(Document, id)
103-
if not document:
104-
raise HTTPException(status_code=404, detail="Document not found")
105-
if not current_user.is_superuser and (document.owner_id != current_user.id):
106-
raise HTTPException(status_code=400, detail="Not enough permissions")
107-
session.delete(document)
108-
session.commit()
109-
return Message(message="Document deleted successfully")
45+
return document

backend/app/core/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ def all_cors_origins(self) -> list[str]:
5757
POSTGRES_PASSWORD: str = ""
5858
POSTGRES_DB: str = ""
5959

60+
AWS_ACCESS_KEY_ID: str = ""
61+
AWS_SECRET_ACCESS_KEY: str = ""
62+
AWS_REGION: str = ""
63+
S3_BUCKET_NAME: str = ""
64+
6065
@computed_field # type: ignore[prop-decorator]
6166
@property
6267
def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:

backend/app/s3.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77

88
def upload_file_to_s3(file: UploadFile, user_id: str) -> str:
99
extension = file.filename.split(".")[-1]
10+
bucket = settings.S3_BUCKET_NAME
1011
key = f"documents/{user_id}/{uuid.uuid4()}.{extension}"
1112

12-
s3.upload_fileobj(file.file, settings.S3_BUCKET_NAME, key)
13+
try:
14+
s3.upload_fileobj(file.file, bucket, key)
15+
except Exception as e:
16+
raise Exception(f"Failed to upload file to S3: {str(e)}")
1317

1418
return key
1519

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import uuid
2+
3+
from fastapi.testclient import TestClient
4+
from sqlmodel import Session
5+
from unittest.mock import patch
6+
from app.core.config import settings
7+
from app.tests.utils.document import create_random_document
8+
import io
9+
10+
11+
def skip_test_create_document(
12+
client: TestClient, superuser_token_headers: dict[str, str]
13+
) -> None:
14+
'''Test creating a document with a file upload with the real S3 service.'''
15+
file_content = b"%PDF-1.4 test file content"
16+
17+
response = client.post(
18+
f"{settings.API_V1_STR}/documents/",
19+
headers=superuser_token_headers,
20+
files={
21+
"file": ("example.pdf", io.BytesIO(file_content), "application/pdf")
22+
},
23+
)
24+
25+
assert response.status_code == 200
26+
content = response.json()
27+
assert "id" in content, "actual response: " + str(content)
28+
# assert content["title"] == metadata["title"]
29+
# assert content["description"] == metadata["description"]
30+
# assert "id" in content
31+
# assert "owner_id" in content
32+
33+
def test_create_document(
34+
client: TestClient, superuser_token_headers: dict[str, str]
35+
) -> None:
36+
'''Test creating a document with a file upload using mocked S3.'''
37+
file_content = b"%PDF-1.4 test file content"
38+
39+
with patch("app.api.routes.documents.upload_file_to_s3", return_value="document-slug"):
40+
response = client.post(
41+
f"{settings.API_V1_STR}/documents/",
42+
headers=superuser_token_headers,
43+
files={
44+
"file": ("example.pdf", io.BytesIO(file_content), "application/pdf")
45+
},
46+
)
47+
48+
assert response.status_code == 200, f"Unexpected response: {response.content}"
49+
content = response.json()
50+
assert "id" in content, "actual response: " + str(content)
51+
assert "document-slug" in content["s3_url"], "S3 URL should match mocked value"
52+
assert content["filename"] == "example.pdf", "Filename should match uploaded file"

backend/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"sqlmodel<1.0.0,>=0.0.21",
1919
# Pin bcrypt until passlib supports the latest
2020
"bcrypt==4.3.0",
21+
"boto3",
2122
"pydantic-settings<3.0.0,>=2.2.1",
2223
"sentry-sdk[fastapi]<2.0.0,>=1.40.6",
2324
"pyjwt<3.0.0,>=2.8.0",

backend/requirements.txt

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
alembic==1.15.2
2+
annotated-types==0.7.0
3+
anyio==4.6.0
4+
-e file:///Users/monicasmith/Documents/personal-projects/study-assistant/backend
5+
bcrypt==4.3.0
6+
boto3==1.39.9
7+
botocore==1.39.9
8+
cachetools==5.5.0
9+
certifi==2024.8.30
10+
cfgv==3.4.0
11+
chardet==5.2.0
12+
charset-normalizer==3.3.2
13+
click==8.1.7
14+
coverage==7.6.1
15+
cssselect==1.2.0
16+
cssutils==2.11.1
17+
distlib==0.3.8
18+
dnspython==2.6.1
19+
email-validator==2.2.0
20+
emails==0.6
21+
fastapi==0.115.0
22+
fastapi-cli==0.0.5
23+
filelock==3.16.1
24+
h11==0.14.0
25+
httpcore==1.0.5
26+
httptools==0.6.1
27+
httpx==0.28.1
28+
identify==2.6.1
29+
idna==3.10
30+
iniconfig==2.0.0
31+
jinja2==3.1.6
32+
jmespath==1.0.1
33+
lxml==5.3.0
34+
mako==1.3.5
35+
markdown-it-py==3.0.0
36+
markupsafe==2.1.5
37+
mdurl==0.1.2
38+
more-itertools==10.5.0
39+
mypy==1.11.2
40+
mypy-extensions==1.0.0
41+
nodeenv==1.9.1
42+
packaging==24.1
43+
passlib==1.7.4
44+
platformdirs==4.3.6
45+
pluggy==1.5.0
46+
pre-commit==3.8.0
47+
premailer==3.10.0
48+
psycopg==3.2.2
49+
psycopg-binary==3.2.2
50+
pydantic==2.9.2
51+
pydantic-core==2.23.4
52+
pydantic-settings==2.9.1
53+
pygments==2.18.0
54+
pyjwt==2.10.1
55+
pytest==7.4.4
56+
python-dateutil==2.9.0.post0
57+
python-dotenv==1.0.1
58+
python-multipart==0.0.20
59+
pyyaml==6.0.2
60+
requests==2.32.3
61+
rich==13.8.1
62+
ruff==0.6.7
63+
s3transfer==0.13.1
64+
sentry-sdk==1.45.1
65+
shellingham==1.5.4
66+
six==1.16.0
67+
sniffio==1.3.1
68+
sqlalchemy==2.0.35
69+
sqlmodel==0.0.24
70+
starlette==0.38.6
71+
tenacity==8.5.0
72+
typer==0.12.5
73+
types-passlib==1.7.7.20240819
74+
typing-extensions==4.12.2
75+
typing-inspection==0.4.0
76+
urllib3==2.2.3
77+
uvicorn==0.30.6
78+
uvloop==0.20.0
79+
virtualenv==20.26.5
80+
watchfiles==0.24.0
81+
websockets==13.1

0 commit comments

Comments
 (0)