Skip to content

Commit d141f6b

Browse files
committed
feat: add type filtering on job status endpoints
1 parent 8692e58 commit d141f6b

File tree

4 files changed

+115
-11
lines changed

4 files changed

+115
-11
lines changed

app/routers/jobs_status.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
import asyncio
22
import json
3+
from typing import List
34

4-
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
5+
from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect
56
from sqlalchemy.orm import Session
67
from loguru import logger
78

89
from app.database.db import SessionLocal, get_db
9-
from app.schemas.jobs_status import JobsStatusResponse
10+
from app.schemas.jobs_status import JobsFilter, JobsStatusResponse
1011
from app.schemas.websockets import WSStatusMessage
1112
from app.services.processing import get_processing_jobs_by_user_id
1213
from app.services.upscaling import get_upscaling_tasks_by_user_id
1314
from app.auth import get_current_user_id, websocket_authenticate
1415

1516
router = APIRouter()
1617

18+
DEFAULT_FILTERS = [JobsFilter.upscaling, JobsFilter.processing]
19+
1720

1821
@router.get(
1922
"/jobs_status",
@@ -23,14 +26,28 @@
2326
async def get_jobs_status(
2427
db: Session = Depends(get_db),
2528
user: str = Depends(get_current_user_id),
29+
filter: List[JobsFilter] = Query(
30+
DEFAULT_FILTERS,
31+
description="Filter jobs: upscaling, processing. Can be provided multiple times.",
32+
),
2633
) -> JobsStatusResponse:
2734
"""
2835
Return combined list of upscaling tasks and processing jobs for the authenticated user.
2936
"""
3037
logger.debug(f"Fetching jobs list for user {user}")
38+
upscaling_tasks = (
39+
get_upscaling_tasks_by_user_id(db, user)
40+
if JobsFilter.upscaling in filter
41+
else []
42+
)
43+
processing_jobs = (
44+
get_processing_jobs_by_user_id(db, user)
45+
if JobsFilter.processing in filter
46+
else []
47+
)
3148
return JobsStatusResponse(
32-
upscaling_tasks=get_upscaling_tasks_by_user_id(db, user),
33-
processing_jobs=get_processing_jobs_by_user_id(db, user),
49+
upscaling_tasks=upscaling_tasks,
50+
processing_jobs=processing_jobs,
3451
)
3552

3653

@@ -40,6 +57,7 @@ async def get_jobs_status(
4057
async def ws_jobs_status(
4158
websocket: WebSocket,
4259
interval: int = 10,
60+
filter: List[JobsFilter] = Query(DEFAULT_FILTERS),
4361
):
4462
"""
4563
Return combined list of upscaling tasks and processing jobs for the authenticated user.
@@ -62,7 +80,7 @@ async def ws_jobs_status(
6280
message="Starting retrieval of status",
6381
).model_dump()
6482
)
65-
status = await get_jobs_status(db, user)
83+
status = await get_jobs_status(db, user, filter=filter)
6684
await websocket.send_json(
6785
WSStatusMessage(
6886
type="status",

app/schemas/jobs_status.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from typing import List
23
from pydantic import BaseModel, Field
34

@@ -12,3 +13,8 @@ class JobsStatusResponse(BaseModel):
1213
processing_jobs: List[ProcessingJobSummary] = Field(
1314
..., description="List of processing jobs that are available for the user"
1415
)
16+
17+
18+
class JobsFilter(str, Enum):
19+
upscaling = "upscaling"
20+
processing = "processing"

openapi.json

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,31 @@
156156
"summary": "Get a list of all upscaling tasks & processing jobs for the authenticated user",
157157
"description": "Return combined list of upscaling tasks and processing jobs for the authenticated user.",
158158
"operationId": "get_jobs_status_jobs_status_get",
159+
"security": [
160+
{
161+
"OAuth2AuthorizationCodeBearer": []
162+
}
163+
],
164+
"parameters": [
165+
{
166+
"name": "filter",
167+
"in": "query",
168+
"required": false,
169+
"schema": {
170+
"type": "array",
171+
"items": {
172+
"$ref": "#/components/schemas/JobsFilter"
173+
},
174+
"description": "Filter jobs: upscaling, processing. Can be provided multiple times.",
175+
"default": [
176+
"upscaling",
177+
"processing"
178+
],
179+
"title": "Filter"
180+
},
181+
"description": "Filter jobs: upscaling, processing. Can be provided multiple times."
182+
}
183+
],
159184
"responses": {
160185
"200": {
161186
"description": "Successful Response",
@@ -166,13 +191,18 @@
166191
}
167192
}
168193
}
194+
},
195+
"422": {
196+
"description": "Validation Error",
197+
"content": {
198+
"application/json": {
199+
"schema": {
200+
"$ref": "#/components/schemas/HTTPValidationError"
201+
}
202+
}
203+
}
169204
}
170-
},
171-
"security": [
172-
{
173-
"OAuth2AuthorizationCodeBearer": []
174-
}
175-
]
205+
}
176206
}
177207
},
178208
"/unit_jobs": {
@@ -972,6 +1002,14 @@
9721002
"type": "object",
9731003
"title": "HTTPValidationError"
9741004
},
1005+
"JobsFilter": {
1006+
"type": "string",
1007+
"enum": [
1008+
"upscaling",
1009+
"processing"
1010+
],
1011+
"title": "JobsFilter"
1012+
},
9751013
"JobsStatusResponse": {
9761014
"properties": {
9771015
"upscaling_tasks": {

tests/routers/test_job_status.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,48 @@ def test_unit_jobs_get_200(
2828
).model_dump_json(indent=1)
2929

3030

31+
@patch("app.routers.jobs_status.get_processing_jobs_by_user_id")
32+
@patch("app.routers.jobs_status.get_upscaling_tasks_by_user_id")
33+
def test_unit_jobs_get_only_processing_200(
34+
mock_get_upscaling_tasks,
35+
mock_get_processing_jobs,
36+
client,
37+
fake_processing_job_summary,
38+
fake_upscaling_task_summary,
39+
):
40+
41+
mock_get_processing_jobs.return_value = [fake_processing_job_summary]
42+
mock_get_upscaling_tasks.return_value = [fake_upscaling_task_summary]
43+
44+
r = client.get("/jobs_status?filter=processing")
45+
assert r.status_code == 200
46+
assert json.dumps(r.json(), indent=1) == JobsStatusResponse(
47+
upscaling_tasks=[],
48+
processing_jobs=[fake_processing_job_summary],
49+
).model_dump_json(indent=1)
50+
51+
52+
@patch("app.routers.jobs_status.get_processing_jobs_by_user_id")
53+
@patch("app.routers.jobs_status.get_upscaling_tasks_by_user_id")
54+
def test_unit_jobs_get_only_upscaling_200(
55+
mock_get_upscaling_tasks,
56+
mock_get_processing_jobs,
57+
client,
58+
fake_processing_job_summary,
59+
fake_upscaling_task_summary,
60+
):
61+
62+
mock_get_processing_jobs.return_value = [fake_processing_job_summary]
63+
mock_get_upscaling_tasks.return_value = [fake_upscaling_task_summary]
64+
65+
r = client.get("/jobs_status?filter=upscaling")
66+
assert r.status_code == 200
67+
assert json.dumps(r.json(), indent=1) == JobsStatusResponse(
68+
upscaling_tasks=[fake_upscaling_task_summary],
69+
processing_jobs=[],
70+
).model_dump_json(indent=1)
71+
72+
3173
@pytest.mark.asyncio
3274
@patch("app.auth.get_current_user_id", new_callable=AsyncMock)
3375
@patch("app.routers.jobs_status.get_jobs_status", new_callable=AsyncMock)

0 commit comments

Comments
 (0)