Skip to content

Commit 27b6b4d

Browse files
committed
fix: notebook auth fix
1 parent 5e8eacd commit 27b6b4d

File tree

1 file changed

+46
-43
lines changed

1 file changed

+46
-43
lines changed

guides/upscaling_example.ipynb

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": 1,
14+
"execution_count": 29,
1515
"id": "d99f5fbc",
1616
"metadata": {},
1717
"outputs": [],
@@ -26,6 +26,7 @@
2626
"import httpx\n",
2727
"import io\n",
2828
"import base64\n",
29+
"import time\n",
2930
"from ipyleaflet import ImageOverlay\n",
3031
"from PIL import Image\n",
3132
"from ipyleaflet import Map, GeoJSON, TileLayer\n",
@@ -137,51 +138,53 @@
137138
},
138139
{
139140
"cell_type": "code",
140-
"execution_count": 7,
141+
"execution_count": 27,
141142
"id": "b5ee27f5-9e69-4557-ba83-ec7cb74aa874",
142143
"metadata": {},
143-
"outputs": [
144-
{
145-
"name": "stdout",
146-
"output_type": "stream",
147-
"text": [
148-
"Open this URL in your browser: https://auth.dev.apex.esa.int/realms/apex/protocol/openid-connect/auth?response_type=code&client_id=apex-dispatcher-api-dev&redirect_uri=http%3A%2F%2Flocalhost%3A8000%2Fcallback&scope=openid+profile+email&state=psCPZIUvLiKHtxsppPThySEoatQq2c\n"
149-
]
150-
},
151-
{
152-
"name": "stdin",
153-
"output_type": "stream",
154-
"text": [
155-
"Paste the redirect URL here http://localhost:8000/callback?state=psCPZIUvLiKHtxsppPThySEoatQq2c&session_state=c567bad0-cc01-42d0-8c23-99d5f59aebdd&iss=https%3A%2F%2Fauth.dev.apex.esa.int%2Frealms%2Fapex&code=f38c1e0a-ea0e-41d7-b2f7-2e31bbaa4cb2.c567bad0-cc01-42d0-8c23-99d5f59aebdd.c2e791df-00a5-4981-b8af-b014848a2b73\n"
156-
]
157-
}
158-
],
144+
"outputs": [],
159145
"source": [
146+
"# Endpoints\n",
160147
"authorization_endpoint = f\"https://{KEYCLOAK_HOST}/realms/apex/protocol/openid-connect/auth\"\n",
161148
"token_endpoint = f\"https://{KEYCLOAK_HOST}/realms/apex/protocol/openid-connect/token\"\n",
162149
"\n",
163-
"# Create OAuth2 session with PKCE\n",
164-
"session = OAuth2Session(client_id=CLIENT_ID, redirect_uri=\"http://localhost:8000/callback\", scope=\"openid profile email\")\n",
150+
"# Global token store\n",
151+
"_token_data = None\n",
165152
"\n",
166-
"# Get the authorization URL\n",
167-
"uri, state = session.create_authorization_url(authorization_endpoint)\n",
168-
"print(\"Open this URL in your browser:\", uri)\n",
153+
"def get_access_token():\n",
154+
" \"\"\"\n",
155+
" Returns a valid access token. Refreshes it automatically if expired.\n",
156+
" \"\"\"\n",
157+
" global _token_data\n",
169158
"\n",
170-
"# Extract the token\n",
171-
"redirect_url = input(\"Paste the redirect URL here\")\n",
172-
"parsed = urlparse(redirect_url)\n",
173-
"query_params = parse_qs(parsed.query)\n",
174-
"code = query_params.get(\"code\")[0] \n",
159+
" # If we have a token and it hasn't expired yet, return it\n",
160+
" if _token_data and _token_data.get(\"expires_at\", 0) > time.time() + 10:\n",
161+
" return _token_data[\"access_token\"]\n",
175162
"\n",
176-
"# Fetch access token\n",
177-
"token = session.fetch_token(\n",
178-
" token_endpoint,\n",
179-
" code=code,\n",
180-
" client_secret=None, # only if your client is confidential\n",
181-
" include_client_id=True\n",
182-
")\n",
163+
" # If token exists but is expired and has a refresh_token, refresh it\n",
164+
" if _token_data and \"refresh_token\" in _token_data:\n",
165+
" session = OAuth2Session(CLIENT_ID, token=_token_data)\n",
166+
" _token_data = session.refresh_token(token_endpoint)\n",
167+
" return _token_data[\"access_token\"]\n",
168+
"\n",
169+
" # Otherwise, start a new OAuth2 flow\n",
170+
" session = OAuth2Session(\n",
171+
" client_id=CLIENT_ID,\n",
172+
" redirect_uri=\"http://localhost:8000/callback\"\n",
173+
" )\n",
174+
" uri, state = session.create_authorization_url(authorization_endpoint)\n",
175+
" print(\"Open this URL in your browser:\", uri)\n",
176+
" redirect_url = input(\"Paste the redirect URL here: \")\n",
177+
" parsed = urlparse(redirect_url)\n",
178+
" code = parse_qs(parsed.query).get(\"code\")[0]\n",
179+
"\n",
180+
" _token_data = session.fetch_token(\n",
181+
" token_endpoint,\n",
182+
" code=code,\n",
183+
" client_secret=None, # only if your client is confidential\n",
184+
" include_client_id=True\n",
185+
" )\n",
183186
"\n",
184-
"access_token = token[\"access_token\"]"
187+
" return _token_data[\"access_token\"]"
185188
]
186189
},
187190
{
@@ -283,7 +286,7 @@
283286
"upscaling_task = requests.post(\n",
284287
" f\"http://{dispatch_api}/upscale_tasks\", \n",
285288
" headers={\n",
286-
" \"Authorization\": f\"Bearer {access_token}\" \n",
289+
" \"Authorization\": f\"Bearer {get_access_token()}\" \n",
287290
" },\n",
288291
" json={\n",
289292
" \"title\": \"Wind Turbine Detection\",\n",
@@ -360,14 +363,14 @@
360363
},
361364
{
362365
"cell_type": "code",
363-
"execution_count": 14,
366+
"execution_count": 30,
364367
"id": "ac428293-7cd4-49a8-9bfa-4e0dc8f4d2cc",
365368
"metadata": {},
366369
"outputs": [
367370
{
368371
"data": {
369372
"application/vnd.jupyter.widget-view+json": {
370-
"model_id": "c8d5a6e04c7f463aa4f4eba2992f1108",
373+
"model_id": "78b9621c2829433ab6a1c7a766cd6ada",
371374
"version_major": 2,
372375
"version_minor": 0
373376
},
@@ -389,8 +392,8 @@
389392
"\u001b[0;31mConnectionResetError\u001b[0m: [Errno 54] Connection reset by peer",
390393
"\nThe above exception was the direct cause of the following exception:\n",
391394
"\u001b[0;31mConnectionClosedError\u001b[0m Traceback (most recent call last)",
392-
"Cell \u001b[0;32mIn[14], line 85\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;66;03m# Run the websocket listener in the notebook\u001b[39;00m\n\u001b[0;32m---> 85\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m listen_for_updates()\n",
393-
"Cell \u001b[0;32mIn[14], line 53\u001b[0m, in \u001b[0;36mlisten_for_updates\u001b[0;34m()\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mwith\u001b[39;00m websockets\u001b[38;5;241m.\u001b[39mconnect(ws_url) \u001b[38;5;28;01mas\u001b[39;00m websocket:\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m---> 53\u001b[0m message \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m websocket\u001b[38;5;241m.\u001b[39mrecv()\n\u001b[1;32m 54\u001b[0m message \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(message)\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m message\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
395+
"Cell \u001b[0;32mIn[30], line 85\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;66;03m# Run the websocket listener in the notebook\u001b[39;00m\n\u001b[0;32m---> 85\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m listen_for_updates()\n",
396+
"Cell \u001b[0;32mIn[30], line 53\u001b[0m, in \u001b[0;36mlisten_for_updates\u001b[0;34m()\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mwith\u001b[39;00m websockets\u001b[38;5;241m.\u001b[39mconnect(ws_url) \u001b[38;5;28;01mas\u001b[39;00m websocket:\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[0;32m---> 53\u001b[0m message \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m websocket\u001b[38;5;241m.\u001b[39mrecv()\n\u001b[1;32m 54\u001b[0m message \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mloads(message)\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m message\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
394397
"File \u001b[0;32m~/.pyenv/versions/3.10.12/lib/python3.10/site-packages/websockets/asyncio/connection.py:322\u001b[0m, in \u001b[0;36mConnection.recv\u001b[0;34m(self, decode)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;66;03m# fallthrough\u001b[39;00m\n\u001b[1;32m 319\u001b[0m \n\u001b[1;32m 320\u001b[0m \u001b[38;5;66;03m# Wait for the protocol state to be CLOSED before accessing close_exc.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;28;01mawait\u001b[39;00m asyncio\u001b[38;5;241m.\u001b[39mshield(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconnection_lost_waiter)\n\u001b[0;32m--> 322\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprotocol\u001b[38;5;241m.\u001b[39mclose_exc \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01mself\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mrecv_exc\u001b[39;00m\n",
395398
"\u001b[0;31mConnectionClosedError\u001b[0m: received 1012 (service restart); then sent 1012 (service restart)"
396399
]
@@ -434,7 +437,7 @@
434437
"async def show_results(job_id):\n",
435438
" async with httpx.AsyncClient() as client:\n",
436439
" result = await client.get(f\"http://{dispatch_api}/unit_jobs/{job_id}/results\", headers={\n",
437-
" \"Authorization\": f\"Bearer {access_token}\"\n",
440+
" \"Authorization\": f\"Bearer {get_access_token()}\"\n",
438441
" })\n",
439442
" response = result.json()\n",
440443
" if output_format.lower() == \"geojson\":\n",
@@ -446,7 +449,7 @@
446449
" return response\n",
447450
"\n",
448451
"async def listen_for_updates():\n",
449-
" ws_url = f\"ws://{dispatch_api}/ws/upscale_tasks/{upscaling_task_id}?interval=15&token={access_token}\"\n",
452+
" ws_url = f\"ws://{dispatch_api}/ws/upscale_tasks/{upscaling_task_id}?interval=15&token={get_access_token()}\"\n",
450453
" async with websockets.connect(ws_url) as websocket:\n",
451454
" while True:\n",
452455
" message = await websocket.recv()\n",

0 commit comments

Comments
 (0)