Skip to content

Commit eacd416

Browse files
authored
feat: updated websockets for consistent communication (#13)
* feat: added initial support for ws
* refactor: streamlining of websocket communication
* fix: linting issues
1 parent 6125e1e commit eacd416

File tree

8 files changed

+227
-49
lines changed

8 files changed

+227
-49
lines changed

app/routers/jobs_status.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import asyncio
2+
import json
23

34
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
45
from sqlalchemy.orm import Session
56
from loguru import logger
67

7-
from app.database.db import get_db
8+
from app.database.db import SessionLocal, get_db
89
from app.schemas.jobs_status import JobsStatusResponse
10+
from app.schemas.websockets import WSStatusMessage
911
from app.services.processing import get_processing_jobs_by_user_id
1012
from app.services.upscaling import get_upscaling_tasks_by_user_id
1113

@@ -44,13 +46,27 @@ async def ws_jobs_status(
4446
await websocket.accept()
4547
logger.debug(f"WebSocket connected for user {user}")
4648

47-
db = next(get_db())
49+
await websocket.send_json(
50+
WSStatusMessage(type="init", message="Starting status stream").model_dump()
51+
)
52+
4853
try:
4954
while True:
50-
status = await get_jobs_status(db, user)
51-
await websocket.send_json(status.model_dump())
52-
53-
await asyncio.sleep(interval)
55+
with SessionLocal() as db:
56+
await websocket.send_json(
57+
WSStatusMessage(
58+
type="loading",
59+
message="Starting retrieval of status",
60+
).model_dump()
61+
)
62+
status = await get_jobs_status(db, user)
63+
await websocket.send_json(
64+
WSStatusMessage(
65+
type="status",
66+
data=json.loads(status.model_dump_json()),
67+
).model_dump()
68+
)
69+
await asyncio.sleep(interval)
5470

5571
except WebSocketDisconnect:
5672
logger.info(f"WebSocket disconnected for user {user}")

app/routers/upscale_tasks.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1+
import asyncio
2+
import json
13
from typing import Annotated
2-
from fastapi import Body, APIRouter, Depends, HTTPException, status
4+
from fastapi import (
5+
Body,
6+
APIRouter,
7+
Depends,
8+
HTTPException,
9+
WebSocket,
10+
WebSocketDisconnect,
11+
status,
12+
)
313
from loguru import logger
414
from sqlalchemy.orm import Session
515

6-
from app.database.db import get_db
16+
from app.database.db import SessionLocal, get_db
717
from app.schemas.enum import ProcessTypeEnum
818
from app.schemas.unit_job import (
919
ServiceDetails,
@@ -14,6 +24,7 @@
1424
UpscalingTaskRequest,
1525
UpscalingTaskSummary,
1626
)
27+
from app.schemas.websockets import WSTaskStatusMessage
1728
from app.services.upscaling import create_upscaling_task, get_upscaling_task_by_user_id
1829

1930
# from app.auth import get_current_user
@@ -117,3 +128,59 @@ async def get_upscale_task(
117128
detail=f"Upscale task {task_id} not found",
118129
)
119130
return job
131+
132+
133+
@router.websocket(
134+
"/ws/upscale_tasks/{task_id}",
135+
)
136+
async def ws_task_status(
137+
websocket: WebSocket,
138+
task_id: int,
139+
user: str = "foobar",
140+
interval: int = 10,
141+
):
142+
await websocket.accept()
143+
logger.info("WebSocket connected", extra={"user": user, "task_id": task_id})
144+
145+
try:
146+
await websocket.send_json(
147+
WSTaskStatusMessage(
148+
type="init", task_id=task_id, message="Starting status stream"
149+
).model_dump()
150+
)
151+
while True:
152+
with SessionLocal() as db:
153+
await websocket.send_json(
154+
WSTaskStatusMessage(
155+
type="loading",
156+
task_id=task_id,
157+
message="Starting retrieval of status",
158+
).model_dump()
159+
)
160+
status = await get_upscale_task(task_id, db, user)
161+
if not status:
162+
await websocket.send_json(
163+
WSTaskStatusMessage(
164+
type="error",
165+
task_id=task_id,
166+
message="Task not found",
167+
).model_dump()
168+
)
169+
await websocket.close(
170+
code=1011, reason=f"Upscale task {task_id} not found"
171+
)
172+
break
173+
await websocket.send_json(
174+
WSTaskStatusMessage(
175+
type="status",
176+
task_id=task_id,
177+
data=json.loads(status.model_dump_json()),
178+
).model_dump()
179+
)
180+
await asyncio.sleep(interval)
181+
182+
except WebSocketDisconnect:
183+
logger.info(f"WebSocket disconnected for user {user}")
184+
except Exception as e:
185+
logger.exception(f"Error in upscaling task status websocket: {e}")
186+
await websocket.close(code=1011, reason=f"Error in job status websocket: {e}")

app/schemas/websockets.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Any, Literal, Optional
2+
from pydantic import BaseModel
3+
4+
5+
class WSStatusMessage(BaseModel):
6+
type: Literal["init", "status", "loading", "error"]
7+
data: Optional[Any] = None
8+
message: Optional[str] = None
9+
10+
11+
class WSTaskStatusMessage(WSStatusMessage):
12+
task_id: int

app/services/processing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def create_processing_job(
6666

6767

6868
def get_job_status(job: ProcessingJobRecord) -> ProcessingStatusEnum:
69-
logger.info(f"Retrieving job status for job: {job.platform_job_id}")
69+
logger.info(
70+
f"Retrieving job status for job: {job.platform_job_id} (current: {job.status})"
71+
)
7072
platform = get_processing_platform(job.label)
7173
details = ServiceDetails.model_validate_json(job.service)
7274
return platform.get_job_status(job.platform_job_id, details)

guides/upscaling_example.ipynb

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,16 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": 1,
14+
"execution_count": 2,
1515
"id": "d99f5fbc",
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"import requests"
19+
"import requests\n",
20+
"import asyncio\n",
21+
"import json\n",
22+
"import websockets\n",
23+
"from ipyleaflet import Map, GeoJSON"
2024
]
2125
},
2226
{
@@ -29,17 +33,17 @@
2933
},
3034
{
3135
"cell_type": "code",
32-
"execution_count": 2,
36+
"execution_count": 3,
3337
"id": "f95065a9",
3438
"metadata": {},
3539
"outputs": [],
3640
"source": [
37-
"dispatch_api = \"http://localhost:8000/\""
41+
"dispatch_api = \"localhost:8000\""
3842
]
3943
},
4044
{
4145
"cell_type": "code",
42-
"execution_count": 14,
46+
"execution_count": 4,
4347
"id": "251a343f",
4448
"metadata": {},
4549
"outputs": [],
@@ -50,7 +54,7 @@
5054
},
5155
{
5256
"cell_type": "code",
53-
"execution_count": 7,
57+
"execution_count": 5,
5458
"id": "c83aa9d5",
5559
"metadata": {},
5660
"outputs": [],
@@ -96,42 +100,40 @@
96100
},
97101
{
98102
"cell_type": "code",
99-
"execution_count": 8,
103+
"execution_count": 6,
100104
"id": "e0618338",
101105
"metadata": {},
102106
"outputs": [],
103107
"source": [
104-
"tiles = requests.post(f\"{dispatch_api}/tiles\", json={\n",
108+
"tiles = requests.post(f\"http://{dispatch_api}/tiles\", json={\n",
105109
" \"grid\": \"20x20km\",\n",
106110
" \"aoi\": area_of_interest\n",
107111
"}).json()"
108112
]
109113
},
110114
{
111115
"cell_type": "code",
112-
"execution_count": 13,
116+
"execution_count": 7,
113117
"id": "6de8d686",
114118
"metadata": {},
115119
"outputs": [
116120
{
117121
"data": {
118122
"application/vnd.jupyter.widget-view+json": {
119-
"model_id": "ef6f633a4eca4ec686c253deb9302479",
123+
"model_id": "eafd8bf582b0472ba6f0b719594d1041",
120124
"version_major": 2,
121125
"version_minor": 0
122126
},
123127
"text/plain": [
124128
"Map(center=[42.251628548555004, -6.37490034623255], controls=(ZoomControl(options=['position', 'zoom_in_text',…"
125129
]
126130
},
127-
"execution_count": 13,
131+
"execution_count": 7,
128132
"metadata": {},
129133
"output_type": "execute_result"
130134
}
131135
],
132136
"source": [
133-
"from ipyleaflet import Map, GeoJSON\n",
134-
"\n",
135137
"# Create a map centered at the approximate center of the area of interest\n",
136138
"m = Map(center=[42.251628548555004, -6.37490034623255], zoom=8)\n",
137139
" \n",
@@ -174,7 +176,7 @@
174176
}
175177
],
176178
"source": [
177-
"upscaling_task = requests.post(f\"{dispatch_api}/upscale_tasks\", json={\n",
179+
"upscaling_task = requests.post(f\"http://{dispatch_api}/upscale_tasks\", json={\n",
178180
" \"title\": \"Forest Fire Detection\",\n",
179181
" \"label\": \"openeo\",\n",
180182
" \"service\": {\n",
@@ -203,14 +205,26 @@
203205
},
204206
{
205207
"cell_type": "code",
206-
"execution_count": null,
208+
"execution_count": 9,
209+
"id": "ce1e3da9-d0ae-4c57-b42e-143f5dc109fd",
210+
"metadata": {},
211+
"outputs": [],
212+
"source": [
213+
"upscaling_task = {\n",
214+
" \"id\": 4\n",
215+
"}"
216+
]
217+
},
218+
{
219+
"cell_type": "code",
220+
"execution_count": 38,
207221
"id": "ac428293-7cd4-49a8-9bfa-4e0dc8f4d2cc",
208222
"metadata": {},
209223
"outputs": [
210224
{
211225
"data": {
212226
"application/vnd.jupyter.widget-view+json": {
213-
"model_id": "a89dccec118c4dbe829f9b97b34eb872",
227+
"model_id": "42f8c0e923364838ab783c8c7fe48dcb",
214228
"version_major": 2,
215229
"version_minor": 0
216230
},
@@ -220,13 +234,26 @@
220234
},
221235
"metadata": {},
222236
"output_type": "display_data"
237+
},
238+
{
239+
"ename": "ConnectionClosedError",
240+
"evalue": "received 1012 (service restart); then sent 1012 (service restart)",
241+
"output_type": "error",
242+
"traceback": [
243+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
244+
"\u001b[0;31mConnectionResetError\u001b[0m Traceback (most recent call last)",
245+
"File \u001b[0;32m~/.pyenv/versions/3.10.12/lib/python3.10/asyncio/selector_events.py:862\u001b[0m, in \u001b[0;36m_SelectorSocketTransport._read_ready__data_received\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 861\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 862\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrecv\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_size\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 863\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mBlockingIOError\u001b[39;00m, \u001b[38;5;167;01mInterruptedError\u001b[39;00m):\n",
246+
"\u001b[0;31mConnectionResetError\u001b[0m: [Errno 54] Connection reset by peer",
247+
"\nThe above exception was the direct cause of the following exception:\n",
248+
"\u001b[0;31mConnectionClosedError\u001b[0m Traceback (most recent call last)",
249+
"Cell \u001b[0;32mIn[38], line 52\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;66;03m# Run the websocket listener in the notebook\u001b[39;00m\n\u001b[0;32m---> 52\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m listen_for_updates()\n",
250+
"Cell \u001b[0;32mIn[38], line 31\u001b[0m, in \u001b[0;36mlisten_for_updates\u001b[0;34m()\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mwith\u001b[39;00m websockets\u001b[38;5;241m.\u001b[39mconnect(ws_url) \u001b[38;5;28;01mas\u001b[39;00m websocket:\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m---> 31\u001b[0m message \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m websocket\u001b[38;5;241m.\u001b[39mrecv()\n\u001b[1;32m 32\u001b[0m status \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(message)\n\u001b[1;32m 33\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m status\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
251+
"File \u001b[0;32m~/.pyenv/versions/3.10.12/lib/python3.10/site-packages/websockets/asyncio/connection.py:322\u001b[0m, in \u001b[0;36mConnection.recv\u001b[0;34m(self, decode)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;66;03m# fallthrough\u001b[39;00m\n\u001b[1;32m 319\u001b[0m \n\u001b[1;32m 320\u001b[0m \u001b[38;5;66;03m# Wait for the protocol state to be CLOSED before accessing close_exc.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mshield(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconnection_lost_waiter)\n\u001b[0;32m--> 322\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprotocol\u001b[38;5;241m.\u001b[39mclose_exc \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mself\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrecv_exc\u001b[39;00m\n",
252+
"\u001b[0;31mConnectionClosedError\u001b[0m: received 1012 (service restart); then sent 1012 (service restart)"
253+
]
223254
}
224255
],
225256
"source": [
226-
"from ipyleaflet import Map, GeoJSON\n",
227-
"import requests\n",
228-
"import time\n",
229-
"\n",
230257
"m = Map(center=[42.251628548555004, -6.37490034623255], zoom=8)\n",
231258
"geo_json = GeoJSON(\n",
232259
" data={\n",
@@ -253,25 +280,41 @@
253280
" \"fillOpacity\": 0.5\n",
254281
" }\n",
255282
"\n",
256-
"while upscaling_task[\"status\"] not in [\"finished\", \"canceled\", \"failed\"]:\n",
257-
" upscaling_task = requests.get(f\"{dispatch_api}/upscale_tasks/{upscaling_task['id']}\").json()\n",
258-
" if upscaling_task.get(\"jobs\"):\n",
259-
" features = []\n",
260-
" for job in upscaling_task[\"jobs\"]:\n",
261-
" features.append({\n",
262-
" \"type\": \"Feature\",\n",
263-
" \"geometry\": job[\"parameters\"][\"spatial_extent\"],\n",
264-
" \"properties\": {\n",
265-
" \"status\": job[\"status\"],\n",
283+
"async def listen_for_updates():\n",
284+
" ws_url = f\"ws://{dispatch_api}/ws/upscale_tasks/{upscaling_task['id']}?interval=15\"\n",
285+
" async with websockets.connect(ws_url) as websocket:\n",
286+
" while True:\n",
287+
" message = await websocket.recv()\n",
288+
" status = json.loads(message)\n",
289+
" if status.get(\"data\"):\n",
290+
" features = []\n",
291+
" for job in status[\"data\"][\"jobs\"]:\n",
292+
" features.append({\n",
293+
" \"type\": \"Feature\",\n",
294+
" \"geometry\": job[\"parameters\"][\"spatial_extent\"],\n",
295+
" \"properties\": {\n",
296+
" \"status\": job[\"status\"],\n",
297+
" }\n",
298+
" })\n",
299+
" geo_json.data = {\n",
300+
" \"type\": \"FeatureCollection\",\n",
301+
" \"features\": features\n",
266302
" }\n",
267-
" })\n",
268-
" geo_json.data = {\n",
269-
" \"type\": \"FeatureCollection\",\n",
270-
" \"features\": features\n",
271-
" }\n",
272-
" geo_json.style_callback = job_style\n",
273-
" time.sleep(15)"
303+
" geo_json.style_callback = job_style\n",
304+
" if status.get(\"status\") in [\"finished\", \"canceled\", \"failed\"]:\n",
305+
" break\n",
306+
"\n",
307+
"# Run the websocket listener in the notebook\n",
308+
"await listen_for_updates()"
274309
]
310+
},
311+
{
312+
"cell_type": "code",
313+
"execution_count": null,
314+
"id": "56bbbdce-45bc-4c05-a6f6-66ab6dac4808",
315+
"metadata": {},
316+
"outputs": [],
317+
"source": []
275318
}
276319
],
277320
"metadata": {

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,5 @@ python-dotenv
2020
requests
2121
SQLAlchemy
2222
types-requests
23-
types-shapely
23+
types-shapely
24+
uvicorn[standard]

0 commit comments

Comments (0)