33import secrets
44import time
55from logging import getLogger
6- from typing import Annotated , Dict
6+ from typing import Dict
77from uuid import uuid4
88
99import aiohttp
1010import requests
11- from fastapi import APIRouter , Depends , HTTPException , Request , status
12- from fastapi .security import HTTPBearer , OAuth2PasswordBearer , OAuth2PasswordRequestForm
11+ from fastapi import APIRouter , Depends , HTTPException , status
12+ from fastapi .security import (
13+ APIKeyCookie ,
14+ OAuth2PasswordBearer ,
15+ OAuth2PasswordRequestForm ,
16+ )
1317from jose import JWTError , jwt
1418from passlib .context import CryptContext
1519from pydantic import BaseModel
1620from sqlmodel import Session , create_engine , select
21+ from typing_extensions import Annotated
1722
1823from murfey .server .murfey_db import murfey_db , url
1924from murfey .util .api import url_path_for
3136)
3237
3338
34- class CookieScheme (HTTPBearer ):
35- def __init__ (
36- self ,
37- * ,
38- description : str | None = None ,
39- auto_error : bool = True ,
40- cookie_key : str = "cookie_auth" ,
41- ):
42- """
43- Args:
44- cookie_key: Cookie key to look for in requests
45- """
46- super ().__init__ (
47- description = description ,
48- auto_error = auto_error ,
49- )
50-
51- self .cookie_key = cookie_key
52-
53- async def __call__ (self , request : Request ):
54- token = request .cookies .get (self .cookie_key )
55- if token is None :
56- if self .auto_error :
57- raise HTTPException (
58- status_code = status .HTTP_401_UNAUTHORIZED ,
59- detail = "Not authenticated" ,
60- )
61- else :
62- return None
63- return token
64-
65-
6639# Set up variables used for authentication
6740security_config = get_security_config ()
6841auth_url = security_config .auth_url
@@ -71,7 +44,7 @@ async def __call__(self, request: Request):
7144if security_config .auth_type == "password" :
7245 oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "token" )
7346else :
74- oauth2_scheme = CookieScheme ( cookie_key = security_config .cookie_key )
47+ oauth2_scheme = APIKeyCookie ( name = security_config .cookie_key )
7548if security_config .instrument_auth_type == "token" :
7649 instrument_oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "token" )
7750else :
@@ -138,19 +111,19 @@ async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]):
138111 validation_outcome = await response .json ()
139112 if not (success and validation_outcome .get ("valid" )):
140113 raise JWTError
141- # Auth URL MUST be provided if authenticating using cookies
114+ # If authenticating using cookies; an auth URL MUST be provided
142115 else :
143116 if security_config .auth_type == "cookie" :
144117 raise JWTError
145-
146118 # Validate using password
147119 if security_config .auth_type == "password" :
148120 decoded_data = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
149- # Frontend validation
121+ # Check that the user is present and is valid
150122 if decoded_data .get ("user" ):
151123 if not check_user (decoded_data ["user" ]):
152124 raise JWTError
153-
125+ else :
126+ raise JWTError
154127 except JWTError :
155128 raise HTTPException (
156129 status_code = status .HTTP_401_UNAUTHORIZED ,
@@ -221,7 +194,7 @@ async def validate_instrument_token(
221194
222195"""
223196=======================================================================================
224- VALIDATING SESSION IDS
197+ SESSION ID VALIDATION
225198=======================================================================================
226199
227200Annotated ints are defined here that trigger validation of the session IDs in incoming
@@ -233,104 +206,87 @@ async def validate_instrument_token(
233206"""
234207
235208
236- async def validate_visit (visit_name : str , token : str , instrument_access : bool ) -> bool :
237- """
238- Validates the incoming token depending on whether it's from the instrument or the frontend
239- """
240-
241- # If this token is from the instrument server, validate this way
242- if instrument_access :
243- if security_config .instrument_auth_url :
244- async with aiohttp .ClientSession () as session :
245- headers = (
246- {}
247- if not security_config .instrument_auth_type
248- else {"Authorization" : f"Bearer { token } " }
249- )
250- async with session .get (
251- f"{ security_config .instrument_auth_url } /validate_visit_access/{ visit_name } " ,
252- headers = headers ,
253- ) as response :
254- success = response .status == 200
255- validation_outcome = await response .json ()
256- if not (success and validation_outcome .get ("valid" )):
257- return False
258- # Otherwise, use this validation method
259- else :
260- if security_config .auth_url :
261- headers = (
262- {}
263- if security_config .auth_type == "cookie"
264- else {"Authorization" : f"Bearer { token } " }
265- )
266- cookies = (
267- {security_config .cookie_key : token }
268- if security_config .auth_type == "cookie"
269- else {}
270- )
271- async with aiohttp .ClientSession (cookies = cookies ) as session :
272- async with session .get (
273- f"{ auth_url } /validate_visit_access/{ visit_name } " ,
274- headers = headers ,
275- ) as response :
276- success = response .status == 200
277- validation_outcome = await response .json ()
278- if not (success and validation_outcome .get ("valid" )):
279- return False
280- return True
209+ def get_visit_name (session_id : int ) -> str :
210+ with Session (engine ) as murfey_db :
211+ return (
212+ murfey_db .exec (select (MurfeySession ).where (MurfeySession .id == session_id ))
213+ .one ()
214+ .visit
215+ )
281216
282217
283- async def validate_session_access (
218+ async def validate_frontend_session_access (
284219 session_id : int ,
285220 token : Annotated [str , Depends (oauth2_scheme )],
286- instrument_access : bool ,
287221) -> int :
288222 """
289- Validates whether the request is authorised to access information about this session
223+ Validates whether a frontend request can access information about this session
290224 """
291- with Session (engine ) as murfey_db :
292- visit_name = (
293- murfey_db .exec (select (MurfeySession ).where (MurfeySession .id == session_id ))
294- .one ()
295- .visit
225+ visit_name = get_visit_name (session_id )
226+
227+ if auth_url :
228+ headers = (
229+ {}
230+ if security_config .auth_type == "cookie"
231+ else {"Authorization" : f"Bearer { token } " }
296232 )
297- validated = await validate_visit (visit_name , token , instrument_access )
298- if not validated :
299- raise HTTPException (
300- status_code = status .HTTP_401_UNAUTHORIZED ,
301- detail = "You do not have access to this visit" ,
302- headers = {"WWW-Authenticate" : "Bearer" },
233+ cookies = (
234+ {security_config .cookie_key : token }
235+ if security_config .auth_type == "cookie"
236+ else {}
303237 )
238+ async with aiohttp .ClientSession (cookies = cookies ) as session :
239+ async with session .get (
240+ f"{ auth_url } /validate_visit_access/{ visit_name } " ,
241+ headers = headers ,
242+ ) as response :
243+ success = response .status == 200
244+ validation_outcome : dict = await response .json ()
245+ if not (success and validation_outcome .get ("valid" )):
246+ logger .warning ("Unauthorised visit access request from frontend" )
247+ raise HTTPException (
248+ status_code = status .HTTP_401_UNAUTHORIZED ,
249+ detail = "You do not have access to this visit" ,
250+ headers = {"WWW-Authenticate" : "Bearer" },
251+ )
304252 return session_id
305253
306254
307- def validate_session_access_wrapper (instrument_access : bool ):
255+ async def validate_instrument_session_access (
256+ session_id : int ,
257+ token : Annotated [str , Depends (instrument_oauth2_scheme )],
258+ ) -> int :
308259 """
309- Factory that returns an async wrapper arond 'validate_session_access' with the
310- 'instrument_access' field preconfigured.
311-
312- This is used to generate FastAPI-compatible dependencies for validating session
313- access, based on the context of the request.
314-
315- Unlike 'functools.partial', this approach preserves introspection compatibility
316- required by FastAPI for dependency resolution and OpenAPI generation.
260+ Validates whether an instrument request can access information about this session
317261 """
262+ visit_name = get_visit_name (session_id )
318263
319- async def _validate (session_id : int , token : Annotated [str , Depends (oauth2_scheme )]):
320- return await validate_session_access (
321- session_id , token , instrument_access = instrument_access
322- )
323-
324- return _validate
264+ if security_config .instrument_auth_url :
265+ async with aiohttp .ClientSession () as session :
266+ headers = (
267+ {}
268+ if not security_config .instrument_auth_type
269+ else {"Authorization" : f"Bearer { token } " }
270+ )
271+ async with session .get (
272+ f"{ security_config .instrument_auth_url } /validate_visit_access/{ visit_name } " ,
273+ headers = headers ,
274+ ) as response :
275+ success = response .status == 200
276+ validation_outcome = await response .json ()
277+ if not (success and validation_outcome .get ("valid" )):
278+ logger .warning ("Unauthorised visit access request from instrument" )
279+ raise HTTPException (
280+ status_code = status .HTTP_401_UNAUTHORIZED ,
281+ detail = "You do not have access to this visit" ,
282+ headers = {"WWW-Authenticate" : "Bearer" },
283+ )
284+ return session_id
325285
326286
327287# Set validation conditions for the session ID based on where the request is from
328- MurfeySessionIDFrontend = Annotated [
329- int , Depends (validate_session_access_wrapper (instrument_access = False ))
330- ]
331- MurfeySessionIDInstrument = Annotated [
332- int , Depends (validate_session_access_wrapper (instrument_access = True ))
333- ]
288+ MurfeySessionIDFrontend = Annotated [int , Depends (validate_frontend_session_access )]
289+ MurfeySessionIDInstrument = Annotated [int , Depends (validate_instrument_session_access )]
334290
335291
336292"""
@@ -354,23 +310,27 @@ def validate_user(username: str, password: str) -> bool:
354310
355311
356312def create_access_token (data : dict , token : str = "" ) -> str :
357- if auth_url and data .get ("session" ):
358- session_id = data ["session" ]
359- if not isinstance (session_id , int ) and session_id > 0 :
360- # check the session ID is alphanumeric for security
361- raise ValueError ("Session ID was invalid (not alphanumeric)" )
362- minted_token_response = requests .get (
363- f"{ auth_url } { url_path_for ('auth.router' , 'mint_session_token' , session_id = session_id )} " ,
364- headers = {"Authorization" : f"Bearer { token } " },
365- )
366- if minted_token_response .status_code != 200 :
367- raise RuntimeError (
368- f"Request received status code { minted_token_response .status_code } when trying to create session token"
313+
314+ # If authenticating with password, auth URL needs a 'mint_session_token' endpoint
315+ if security_config .auth_type == "password" :
316+ if auth_url and data .get ("session" ):
317+ session_id = data ["session" ]
318+ if not isinstance (session_id , int ) and session_id > 0 :
319+ # Check the session ID is alphanumeric for security
320+ raise ValueError ("Session ID was invalid (not alphanumeric)" )
321+ minted_token_response = requests .get (
322+ f"{ auth_url } { url_path_for ('auth.router' , 'mint_session_token' , session_id = session_id )} " ,
323+ headers = {"Authorization" : f"Bearer { token } " },
369324 )
370- return minted_token_response .json ()["access_token" ]
325+ if minted_token_response .status_code != 200 :
326+ raise RuntimeError (
327+ f"Request received status code { minted_token_response .status_code } when trying to create session token"
328+ )
329+ return minted_token_response .json ()["access_token" ]
371330
372331 to_encode = data .copy ()
373332
333+ # Make token for instrument
374334 encoded_jwt = jwt .encode (to_encode , SECRET_KEY , algorithm = ALGORITHM )
375335 return encoded_jwt
376336
@@ -384,31 +344,36 @@ class Token(BaseModel):
384344async def generate_token (
385345 form_data : Annotated [OAuth2PasswordRequestForm , Depends ()],
386346) -> Token :
387- if auth_url :
388- data = aiohttp .FormData ()
389- data .add_field ("username" , form_data .username )
390- data .add_field ("password" , form_data .password )
391- async with aiohttp .ClientSession () as session :
392- async with session .post (
393- f"{ auth_url } { url_path_for ('auth.router' , 'generate_token' )} " ,
394- data = data ,
395- ) as response :
396- validated = response .status == 200
397- token = await response .json ()
398- access_token = token .get ("access_token" )
399- else :
400- validated = validate_user (form_data .username , form_data .password )
401- if not validated :
402- raise HTTPException (
403- status_code = status .HTTP_401_UNAUTHORIZED ,
404- detail = "Incorrect username or password" ,
405- headers = {"WWW-Authenticate" : "Bearer" },
406- )
407- if not auth_url :
408- access_token = create_access_token (
409- data = {"user" : form_data .username },
410- )
411- return Token (access_token = access_token , token_type = "bearer" )
347+ # Only generate a token if it's a password
348+ if security_config .auth_type == "password" :
349+ if auth_url :
350+ data = aiohttp .FormData ()
351+ data .add_field ("username" , form_data .username )
352+ data .add_field ("password" , form_data .password )
353+ async with aiohttp .ClientSession () as session :
354+ async with session .post (
355+ f"{ auth_url } { url_path_for ('auth.router' , 'generate_token' )} " ,
356+ data = data ,
357+ ) as response :
358+ validated = response .status == 200
359+ token = await response .json ()
360+ access_token = token .get ("access_token" )
361+ else :
362+ validated = validate_user (form_data .username , form_data .password )
363+ if not validated :
364+ raise HTTPException (
365+ status_code = status .HTTP_401_UNAUTHORIZED ,
366+ detail = "Incorrect username or password" ,
367+ headers = {"WWW-Authenticate" : "Bearer" },
368+ )
369+ if not auth_url :
370+ access_token = create_access_token (
371+ data = {"user" : form_data .username },
372+ )
373+ return Token (access_token = access_token , token_type = "bearer" )
374+
375+ # Return empty token otherwise
376+ return Token (access_token = "" , token_type = "bearer" )
412377
413378
414379@router .get ("/sessions/{session_id}/token" )
@@ -431,5 +396,7 @@ async def mint_session_token(session_id: MurfeySessionIDFrontend, db=murfey_db):
431396
432397
433398@router .get ("/validate_token" )
434- async def simple_token_validation (token : Annotated [str , Depends (validate_token )]):
399+ async def simple_token_validation (
400+ token : Annotated [str , Depends (validate_instrument_token )]
401+ ):
435402 return {"valid" : True }
0 commit comments