Skip to content

Commit 96fce0c

Browse files
committed
Used FastAPI's 'APIKeyCookie' object for cookie authentication instead of homebrew method; split session access validation function for instrument server and frontend into separate functions; instrument validation function was incorrectly calling oauth2 scheme for frontend instead of backend; fixed logic for 'create_access_token' and 'generate_token' for handling authentication using either 'password' or 'cookie'; 'simple_token_validation()' should be using instrument server validation function instead
1 parent e8f57e4 commit 96fce0c

File tree

1 file changed

+127
-160
lines changed

1 file changed

+127
-160
lines changed

src/murfey/server/api/auth.py

Lines changed: 127 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,22 @@
33
import secrets
44
import time
55
from logging import getLogger
6-
from typing import Annotated, Dict
6+
from typing import Dict
77
from uuid import uuid4
88

99
import aiohttp
1010
import requests
11-
from fastapi import APIRouter, Depends, HTTPException, Request, status
12-
from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm
11+
from fastapi import APIRouter, Depends, HTTPException, status
12+
from fastapi.security import (
13+
APIKeyCookie,
14+
OAuth2PasswordBearer,
15+
OAuth2PasswordRequestForm,
16+
)
1317
from jose import JWTError, jwt
1418
from passlib.context import CryptContext
1519
from pydantic import BaseModel
1620
from sqlmodel import Session, create_engine, select
21+
from typing_extensions import Annotated
1722

1823
from murfey.server.murfey_db import murfey_db, url
1924
from murfey.util.api import url_path_for
@@ -31,38 +36,6 @@
3136
)
3237

3338

34-
class CookieScheme(HTTPBearer):
35-
def __init__(
36-
self,
37-
*,
38-
description: str | None = None,
39-
auto_error: bool = True,
40-
cookie_key: str = "cookie_auth",
41-
):
42-
"""
43-
Args:
44-
cookie_key: Cookie key to look for in requests
45-
"""
46-
super().__init__(
47-
description=description,
48-
auto_error=auto_error,
49-
)
50-
51-
self.cookie_key = cookie_key
52-
53-
async def __call__(self, request: Request):
54-
token = request.cookies.get(self.cookie_key)
55-
if token is None:
56-
if self.auto_error:
57-
raise HTTPException(
58-
status_code=status.HTTP_401_UNAUTHORIZED,
59-
detail="Not authenticated",
60-
)
61-
else:
62-
return None
63-
return token
64-
65-
6639
# Set up variables used for authentication
6740
security_config = get_security_config()
6841
auth_url = security_config.auth_url
@@ -71,7 +44,7 @@ async def __call__(self, request: Request):
7144
if security_config.auth_type == "password":
7245
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
7346
else:
74-
oauth2_scheme = CookieScheme(cookie_key=security_config.cookie_key)
47+
oauth2_scheme = APIKeyCookie(name=security_config.cookie_key)
7548
if security_config.instrument_auth_type == "token":
7649
instrument_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
7750
else:
@@ -138,19 +111,19 @@ async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]):
138111
validation_outcome = await response.json()
139112
if not (success and validation_outcome.get("valid")):
140113
raise JWTError
141-
# Auth URL MUST be provided if authenticating using cookies
114+
# If authenticating using cookies; an auth URL MUST be provided
142115
else:
143116
if security_config.auth_type == "cookie":
144117
raise JWTError
145-
146118
# Validate using password
147119
if security_config.auth_type == "password":
148120
decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
149-
# Frontend validation
121+
# Check that the user is present and is valid
150122
if decoded_data.get("user"):
151123
if not check_user(decoded_data["user"]):
152124
raise JWTError
153-
125+
else:
126+
raise JWTError
154127
except JWTError:
155128
raise HTTPException(
156129
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -221,7 +194,7 @@ async def validate_instrument_token(
221194

222195
"""
223196
=======================================================================================
224-
VALIDATING SESSION IDS
197+
SESSION ID VALIDATION
225198
=======================================================================================
226199
227200
Annotated ints are defined here that trigger validation of the session IDs in incoming
@@ -233,104 +206,87 @@ async def validate_instrument_token(
233206
"""
234207

235208

236-
async def validate_visit(visit_name: str, token: str, instrument_access: bool) -> bool:
237-
"""
238-
Validates the incoming token depending on whether it's from the instrument or the frontend
239-
"""
240-
241-
# If this token is from the instrument server, validate this way
242-
if instrument_access:
243-
if security_config.instrument_auth_url:
244-
async with aiohttp.ClientSession() as session:
245-
headers = (
246-
{}
247-
if not security_config.instrument_auth_type
248-
else {"Authorization": f"Bearer {token}"}
249-
)
250-
async with session.get(
251-
f"{security_config.instrument_auth_url}/validate_visit_access/{visit_name}",
252-
headers=headers,
253-
) as response:
254-
success = response.status == 200
255-
validation_outcome = await response.json()
256-
if not (success and validation_outcome.get("valid")):
257-
return False
258-
# Otherwise, use this validation method
259-
else:
260-
if security_config.auth_url:
261-
headers = (
262-
{}
263-
if security_config.auth_type == "cookie"
264-
else {"Authorization": f"Bearer {token}"}
265-
)
266-
cookies = (
267-
{security_config.cookie_key: token}
268-
if security_config.auth_type == "cookie"
269-
else {}
270-
)
271-
async with aiohttp.ClientSession(cookies=cookies) as session:
272-
async with session.get(
273-
f"{auth_url}/validate_visit_access/{visit_name}",
274-
headers=headers,
275-
) as response:
276-
success = response.status == 200
277-
validation_outcome = await response.json()
278-
if not (success and validation_outcome.get("valid")):
279-
return False
280-
return True
209+
def get_visit_name(session_id: int) -> str:
210+
with Session(engine) as murfey_db:
211+
return (
212+
murfey_db.exec(select(MurfeySession).where(MurfeySession.id == session_id))
213+
.one()
214+
.visit
215+
)
281216

282217

283-
async def validate_session_access(
218+
async def validate_frontend_session_access(
284219
session_id: int,
285220
token: Annotated[str, Depends(oauth2_scheme)],
286-
instrument_access: bool,
287221
) -> int:
288222
"""
289-
Validates whether the request is authorised to access information about this session
223+
Validates whether a frontend request can access information about this session
290224
"""
291-
with Session(engine) as murfey_db:
292-
visit_name = (
293-
murfey_db.exec(select(MurfeySession).where(MurfeySession.id == session_id))
294-
.one()
295-
.visit
225+
visit_name = get_visit_name(session_id)
226+
227+
if auth_url:
228+
headers = (
229+
{}
230+
if security_config.auth_type == "cookie"
231+
else {"Authorization": f"Bearer {token}"}
296232
)
297-
validated = await validate_visit(visit_name, token, instrument_access)
298-
if not validated:
299-
raise HTTPException(
300-
status_code=status.HTTP_401_UNAUTHORIZED,
301-
detail="You do not have access to this visit",
302-
headers={"WWW-Authenticate": "Bearer"},
233+
cookies = (
234+
{security_config.cookie_key: token}
235+
if security_config.auth_type == "cookie"
236+
else {}
303237
)
238+
async with aiohttp.ClientSession(cookies=cookies) as session:
239+
async with session.get(
240+
f"{auth_url}/validate_visit_access/{visit_name}",
241+
headers=headers,
242+
) as response:
243+
success = response.status == 200
244+
validation_outcome: dict = await response.json()
245+
if not (success and validation_outcome.get("valid")):
246+
logger.warning("Unauthorised visit access request from frontend")
247+
raise HTTPException(
248+
status_code=status.HTTP_401_UNAUTHORIZED,
249+
detail="You do not have access to this visit",
250+
headers={"WWW-Authenticate": "Bearer"},
251+
)
304252
return session_id
305253

306254

307-
def validate_session_access_wrapper(instrument_access: bool):
255+
async def validate_instrument_session_access(
256+
session_id: int,
257+
token: Annotated[str, Depends(instrument_oauth2_scheme)],
258+
) -> int:
308259
"""
309-
Factory that returns an async wrapper arond 'validate_session_access' with the
310-
'instrument_access' field preconfigured.
311-
312-
This is used to generate FastAPI-compatible dependencies for validating session
313-
access, based on the context of the request.
314-
315-
Unlike 'functools.partial', this approach preserves introspection compatibility
316-
required by FastAPI for dependency resolution and OpenAPI generation.
260+
Validates whether an instrument request can access information about this session
317261
"""
262+
visit_name = get_visit_name(session_id)
318263

319-
async def _validate(session_id: int, token: Annotated[str, Depends(oauth2_scheme)]):
320-
return await validate_session_access(
321-
session_id, token, instrument_access=instrument_access
322-
)
323-
324-
return _validate
264+
if security_config.instrument_auth_url:
265+
async with aiohttp.ClientSession() as session:
266+
headers = (
267+
{}
268+
if not security_config.instrument_auth_type
269+
else {"Authorization": f"Bearer {token}"}
270+
)
271+
async with session.get(
272+
f"{security_config.instrument_auth_url}/validate_visit_access/{visit_name}",
273+
headers=headers,
274+
) as response:
275+
success = response.status == 200
276+
validation_outcome = await response.json()
277+
if not (success and validation_outcome.get("valid")):
278+
logger.warning("Unauthorised visit access request from instrument")
279+
raise HTTPException(
280+
status_code=status.HTTP_401_UNAUTHORIZED,
281+
detail="You do not have access to this visit",
282+
headers={"WWW-Authenticate": "Bearer"},
283+
)
284+
return session_id
325285

326286

327287
# Set validation conditions for the session ID based on where the request is from
328-
MurfeySessionIDFrontend = Annotated[
329-
int, Depends(validate_session_access_wrapper(instrument_access=False))
330-
]
331-
MurfeySessionIDInstrument = Annotated[
332-
int, Depends(validate_session_access_wrapper(instrument_access=True))
333-
]
288+
MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)]
289+
MurfeySessionIDInstrument = Annotated[int, Depends(validate_instrument_session_access)]
334290

335291

336292
"""
@@ -354,23 +310,27 @@ def validate_user(username: str, password: str) -> bool:
354310

355311

356312
def create_access_token(data: dict, token: str = "") -> str:
357-
if auth_url and data.get("session"):
358-
session_id = data["session"]
359-
if not isinstance(session_id, int) and session_id > 0:
360-
# check the session ID is alphanumeric for security
361-
raise ValueError("Session ID was invalid (not alphanumeric)")
362-
minted_token_response = requests.get(
363-
f"{auth_url}{url_path_for('auth.router', 'mint_session_token', session_id=session_id)}",
364-
headers={"Authorization": f"Bearer {token}"},
365-
)
366-
if minted_token_response.status_code != 200:
367-
raise RuntimeError(
368-
f"Request received status code {minted_token_response.status_code} when trying to create session token"
313+
314+
# If authenticating with password, auth URL needs a 'mint_session_token' endpoint
315+
if security_config.auth_type == "password":
316+
if auth_url and data.get("session"):
317+
session_id = data["session"]
318+
if not isinstance(session_id, int) and session_id > 0:
319+
# Check the session ID is alphanumeric for security
320+
raise ValueError("Session ID was invalid (not alphanumeric)")
321+
minted_token_response = requests.get(
322+
f"{auth_url}{url_path_for('auth.router', 'mint_session_token', session_id=session_id)}",
323+
headers={"Authorization": f"Bearer {token}"},
369324
)
370-
return minted_token_response.json()["access_token"]
325+
if minted_token_response.status_code != 200:
326+
raise RuntimeError(
327+
f"Request received status code {minted_token_response.status_code} when trying to create session token"
328+
)
329+
return minted_token_response.json()["access_token"]
371330

372331
to_encode = data.copy()
373332

333+
# Make token for instrument
374334
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
375335
return encoded_jwt
376336

@@ -384,31 +344,36 @@ class Token(BaseModel):
384344
async def generate_token(
385345
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
386346
) -> Token:
387-
if auth_url:
388-
data = aiohttp.FormData()
389-
data.add_field("username", form_data.username)
390-
data.add_field("password", form_data.password)
391-
async with aiohttp.ClientSession() as session:
392-
async with session.post(
393-
f"{auth_url}{url_path_for('auth.router', 'generate_token')}",
394-
data=data,
395-
) as response:
396-
validated = response.status == 200
397-
token = await response.json()
398-
access_token = token.get("access_token")
399-
else:
400-
validated = validate_user(form_data.username, form_data.password)
401-
if not validated:
402-
raise HTTPException(
403-
status_code=status.HTTP_401_UNAUTHORIZED,
404-
detail="Incorrect username or password",
405-
headers={"WWW-Authenticate": "Bearer"},
406-
)
407-
if not auth_url:
408-
access_token = create_access_token(
409-
data={"user": form_data.username},
410-
)
411-
return Token(access_token=access_token, token_type="bearer")
347+
# Only generate a token if it's a password
348+
if security_config.auth_type == "password":
349+
if auth_url:
350+
data = aiohttp.FormData()
351+
data.add_field("username", form_data.username)
352+
data.add_field("password", form_data.password)
353+
async with aiohttp.ClientSession() as session:
354+
async with session.post(
355+
f"{auth_url}{url_path_for('auth.router', 'generate_token')}",
356+
data=data,
357+
) as response:
358+
validated = response.status == 200
359+
token = await response.json()
360+
access_token = token.get("access_token")
361+
else:
362+
validated = validate_user(form_data.username, form_data.password)
363+
if not validated:
364+
raise HTTPException(
365+
status_code=status.HTTP_401_UNAUTHORIZED,
366+
detail="Incorrect username or password",
367+
headers={"WWW-Authenticate": "Bearer"},
368+
)
369+
if not auth_url:
370+
access_token = create_access_token(
371+
data={"user": form_data.username},
372+
)
373+
return Token(access_token=access_token, token_type="bearer")
374+
375+
# Return empty token otherwise
376+
return Token(access_token="", token_type="bearer")
412377

413378

414379
@router.get("/sessions/{session_id}/token")
@@ -431,5 +396,7 @@ async def mint_session_token(session_id: MurfeySessionIDFrontend, db=murfey_db):
431396

432397

433398
@router.get("/validate_token")
434-
async def simple_token_validation(token: Annotated[str, Depends(validate_token)]):
399+
async def simple_token_validation(
400+
token: Annotated[str, Depends(validate_instrument_token)]
401+
):
435402
return {"valid": True}

0 commit comments

Comments
 (0)