Skip to content

Commit 2cb0a3f

Browse files
committed
refactor: updated upscale_task creation to include background tasks
1 parent 3989df1 commit 2cb0a3f

File tree

7 files changed

+162
-93
lines changed

7 files changed

+162
-93
lines changed

app/database/models/processing_job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def save_job_to_db(
5656
db_session.add(job)
5757
db_session.commit()
5858
db_session.refresh(job) # Refresh to get the ID after commit
59-
logger.debug("Processing job saved with ID: {job.id}")
59+
logger.debug(f"Processing job saved with ID: {job.id}")
6060
return job
6161

6262

app/routers/upscale_tasks.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
from typing import Annotated
44
from fastapi import (
5+
BackgroundTasks,
56
Body,
67
APIRouter,
78
Depends,
@@ -25,7 +26,11 @@
2526
UpscalingTaskSummary,
2627
)
2728
from app.schemas.websockets import WSTaskStatusMessage
28-
from app.services.upscaling import create_upscaling_task, get_upscaling_task_by_user_id
29+
from app.services.upscaling import (
30+
create_upscaling_processing_jobs,
31+
create_upscaling_task,
32+
get_upscaling_task_by_user_id,
33+
)
2934

3035
# from app.auth import get_current_user
3136

@@ -98,12 +103,21 @@ async def create_upscale_task(
98103
},
99104
),
100105
],
106+
background_tasks: BackgroundTasks,
101107
db: Session = Depends(get_db),
102108
user: str = "foobar",
103109
) -> UpscalingTaskSummary:
104110
"""Create a new upscaling job with the provided data."""
105111
try:
106-
return create_upscaling_task(db, user, payload)
112+
task = create_upscaling_task(db, user, payload)
113+
background_tasks.add_task(
114+
create_upscaling_processing_jobs,
115+
database=db,
116+
user=user,
117+
request=payload,
118+
upscaling_task_id=task.id,
119+
)
120+
return task
107121
except Exception as e:
108122
logger.exception(f"Error creating upscale task for user {user}: {e}")
109123
raise HTTPException(

app/services/upscaling.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
}
3030

3131

32-
def _create_upscaling_processing_jobs(
32+
def create_upscaling_processing_jobs(
3333
database: Session, user: str, request: UpscalingTaskRequest, upscaling_task_id: int
3434
) -> List[ProcessingJobSummary]:
3535
jobs: List[ProcessingJobSummary] = []
@@ -72,12 +72,6 @@ def create_upscaling_task(
7272
service=request.service.model_dump_json(),
7373
)
7474
record = save_upscaling_task_to_db(database, record)
75-
76-
logger.info(f"Creating upscaling job for {user} with request: {request}")
77-
_create_upscaling_processing_jobs(
78-
database=database, user=user, request=request, upscaling_task_id=record.id
79-
)
80-
8175
return UpscalingTaskSummary(
8276
id=record.id, title=record.title, label=record.label, status=record.status
8377
)

guides/upscaling_example.ipynb

Lines changed: 92 additions & 70 deletions
Large diffs are not rendered by default.

tests/routers/test_unit_jobs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,15 @@ def test_unit_jobs_get_job_results_404(mock_get_processing_job_results, client):
8181
r = client.get("/unit_jobs/1/results")
8282
assert r.status_code == 404
8383
assert "result for processing job 1 not found" in r.json().get("detail", "").lower()
84+
85+
86+
@patch("app.routers.unit_jobs.get_processing_job_results")
87+
def test_unit_jobs_get_job_results_500(mock_get_processing_job_results, client):
88+
89+
mock_get_processing_job_results.side_effect = RuntimeError(
90+
"Database connection lost"
91+
)
92+
93+
r = client.get("/unit_jobs/1/results")
94+
assert r.status_code == 500
95+
assert "database connection lost" in r.json().get("detail", "").lower()

tests/routers/test_upscale_tasks.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import pytest
66

77

8+
@patch("app.routers.upscale_tasks.create_upscaling_processing_jobs")
89
@patch("app.routers.upscale_tasks.create_upscaling_task")
910
def test_upscaling_task_create_201(
1011
mock_create_upscaling_task,
12+
mock_create_processing_jobs,
1113
client,
1214
fake_upscaling_task_request,
1315
fake_upscaling_task_summary,
@@ -18,6 +20,7 @@ def test_upscaling_task_create_201(
1820
r = client.post("/upscale_tasks", json=fake_upscaling_task_request.model_dump())
1921
assert r.status_code == 201
2022
assert r.json() == fake_upscaling_task_summary.model_dump()
23+
assert mock_create_processing_jobs.called_once()
2124

2225

2326
@patch("app.routers.upscale_tasks.create_upscaling_task")
@@ -63,9 +66,7 @@ def test_upscaling_task_get_task_404(mock_get_upscale_task, client):
6366

6467

6568
@pytest.mark.asyncio
66-
@patch(
67-
"app.routers.upscale_tasks.get_upscale_task", new_callable=AsyncMock
68-
)
69+
@patch("app.routers.upscale_tasks.get_upscale_task", new_callable=AsyncMock)
6970
async def test_ws_jobs_status(mock_get_task_status, client, fake_upscaling_task):
7071
mock_get_task_status.return_value = fake_upscaling_task
7172

@@ -77,9 +78,7 @@ async def test_ws_jobs_status(mock_get_task_status, client, fake_upscaling_task)
7778

7879

7980
@pytest.mark.asyncio
80-
@patch(
81-
"app.routers.upscale_tasks.get_upscale_task", new_callable=AsyncMock
82-
)
81+
@patch("app.routers.upscale_tasks.get_upscale_task", new_callable=AsyncMock)
8382
async def test_ws_jobs_status_closes_on_error(mock_get_task_status, client):
8483
mock_get_task_status.side_effect = RuntimeError("Database connection lost")
8584

@@ -90,3 +89,18 @@ async def test_ws_jobs_status_closes_on_error(mock_get_task_status, client):
9089
websocket.receive_json()
9190

9291
assert exc_info.value.code == 1011
92+
93+
94+
@pytest.mark.asyncio
95+
@patch("app.routers.upscale_tasks.get_upscale_task", new_callable=AsyncMock)
96+
async def test_ws_jobs_status_not_found(
97+
mock_get_task_status, client, fake_upscaling_task
98+
):
99+
mock_get_task_status.return_value = None
100+
101+
with client.websocket_connect("/ws/upscale_tasks/1?interval=1") as websocket:
102+
websocket.receive_json()
103+
websocket.receive_json()
104+
data = websocket.receive_json()
105+
assert data["type"] == "error"
106+
assert data["message"].lower() == "task not found"

tests/services/test_upscaling.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from app.services.upscaling import (
1111
_get_upscale_status,
1212
_refresh_record_status,
13+
create_upscaling_processing_jobs,
1314
create_upscaling_task,
1415
get_upscaling_task_by_user_id,
1516
get_upscaling_tasks_by_user_id,
@@ -38,21 +39,34 @@ def make_upscaling_record(status: ProcessingJobSummary) -> UpscalingTaskRecord:
3839
)
3940

4041

41-
@patch("app.services.upscaling.create_processing_job")
4242
@patch("app.services.upscaling.save_upscaling_task_to_db")
43-
def test_create_upscaling_task_creates_jobs(
43+
def test_create_upscaling_task_creates_task(
4444
mock_save_upscaling_task,
45-
mock_create_processing_job,
4645
fake_upscaling_task_request,
4746
fake_upscaling_task_record,
4847
fake_upscaling_task_summary,
49-
fake_processing_job_summary,
5048
fake_db_session,
5149
):
5250
user = "foobar"
5351
mock_save_upscaling_task.return_value = fake_upscaling_task_record
54-
mock_create_processing_job.return_value = fake_processing_job_summary
5552
result = create_upscaling_task(fake_db_session, user, fake_upscaling_task_request)
53+
mock_save_upscaling_task.assert_called_once()
54+
assert result == fake_upscaling_task_summary
55+
56+
57+
@patch("app.services.upscaling.create_processing_job")
58+
def test_create_upscaling_task_creates_jobs(
59+
mock_create_processing_job,
60+
fake_upscaling_task_request,
61+
fake_upscaling_task_record,
62+
fake_processing_job_summary,
63+
fake_db_session,
64+
):
65+
user = "foobar"
66+
mock_create_processing_job.return_value = fake_processing_job_summary
67+
result = create_upscaling_processing_jobs(
68+
fake_db_session, user, fake_upscaling_task_request, 1
69+
)
5670

5771
expected_calls = [
5872
call(
@@ -73,8 +87,7 @@ def test_create_upscaling_task_creates_jobs(
7387
]
7488

7589
mock_create_processing_job.assert_has_calls(expected_calls)
76-
mock_save_upscaling_task.assert_called_once()
77-
assert result == fake_upscaling_task_summary
90+
assert len(result) == len(fake_upscaling_task_request.dimension.values)
7891

7992

8093
def test_returns_running_if_any_running():

0 commit comments

Comments
 (0)