88
99import aiohttp
1010import requests
11- from backports .entry_points_selectable import entry_points
1211from fastapi import APIRouter , Depends , HTTPException , Request , status
1312from fastapi .security import HTTPBearer , OAuth2PasswordBearer , OAuth2PasswordRequestForm
1413from jose import JWTError , jwt
@@ -78,20 +77,6 @@ async def __call__(self, request: Request):
7877
7978instrument_server_tokens : Dict [float , dict ] = {}
8079
81-
82- """
83- HELPER FUNCTIONS
84- """
85-
86-
87- def verify_password (plain_password : str , hashed_password : str ) -> bool :
88- return pwd_context .verify (plain_password , hashed_password )
89-
90-
91- def hash_password (password : str ) -> str :
92- return pwd_context .hash (password )
93-
94-
9580# Set up database engine
9681try :
9782 _url = url (security_config )
@@ -100,22 +85,19 @@ def hash_password(password: str) -> str:
10085 engine = None
10186
10287
103- def validate_user (username : str , password : str ) -> bool :
104- try :
105- with Session (engine ) as murfey_db :
106- user = murfey_db .exec (select (User ).where (User .username == username )).one ()
107- except Exception :
108- return False
109- return verify_password (password , user .hashed_password )
88+ def hash_password (password : str ) -> str :
89+ return pwd_context .hash (password )
11090
11191
112- def validate_visit (visit_name : str , token : str ) -> bool :
113- if validators := entry_points ().select (
114- group = "murfey.auth.session_validation" ,
115- name = security_config .auth_type ,
116- ):
117- return validators [0 ].load ()(visit_name , token )
118- return True
92+ """
93+ =======================================================================================
94+ TOKEN VALIDATION FUNCTIONS
95+ =======================================================================================
96+
97+ Functions and helpers used to validate incoming requests from both the client and
98+ the frontend. 'validate_token()' and 'validate_instrument_token()' are imported
99+ int the other FastAPI modules and attached as dependencies to the routers.
100+ """
119101
120102
121103def check_user (username : str ) -> bool :
@@ -127,75 +109,75 @@ def check_user(username: str) -> bool:
127109 return username in [u .username for u in users ]
128110
129111
130- def validate_instrument_server_session_token (session_id : int , visit : str ):
131- with Session (engine ) as murfey_db :
132- session_data = murfey_db .exec (
133- select (MurfeySession ).where (MurfeySession .id == session_id )
134- ).all ()
135- if len (session_data ) != 1 :
136- return False
137- return visit == session_data [0 ].visit
138-
139-
140112async def validate_token (token : Annotated [str , Depends (oauth2_scheme )]):
113+ """
114+ Used by the backend routers to validate requests coming in from frontend.
115+ """
141116 try :
142- try :
143- if security_config .auth_type == "password" :
144- await validate_password_token (token )
145- except JWTError :
146- await validate_instrument_token (token )
147- decoded_data = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
148- # first check if the token has expired
149- if expiry_time := decoded_data .get ("expiry_time" ):
150- if expiry_time < time .time ():
151- raise JWTError
152- if decoded_data .get ("user" ):
153- if not check_user (decoded_data ["user" ]):
154- raise JWTError
155- elif decoded_data .get ("session" ) is not None :
156- if not validate_instrument_server_session_token (
157- decoded_data ["session" ], decoded_data ["visit" ]
158- ):
117+ # Validate using auth URL if provided; will error if invalid
118+ if auth_url :
119+ headers = (
120+ {}
121+ if security_config .auth_type == "cookie"
122+ else {"Authorization" : f"Bearer { token } " }
123+ )
124+ cookies = (
125+ {security_config .cookie_key : token }
126+ if security_config .auth_type == "cookie"
127+ else {}
128+ )
129+ async with aiohttp .ClientSession (cookies = cookies ) as session :
130+ async with session .get (
131+ f"{ auth_url } /validate_token" ,
132+ headers = headers ,
133+ ) as response :
134+ success = response .status == 200
135+ validation_outcome = await response .json ()
136+ if not (success and validation_outcome .get ("valid" )):
159137 raise JWTError
138+ # Auth URL MUST be provided if authenticating using cookies
160139 else :
161- raise JWTError
140+ if security_config .auth_type == "cookie" :
141+ raise JWTError
142+
143+ # Validate using password
144+ if security_config .auth_type == "password" :
145+ decoded_data = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
146+ # Frontend validation
147+ if decoded_data .get ("user" ):
148+ if not check_user (decoded_data ["user" ]):
149+ raise JWTError
150+
162151 except JWTError :
163152 raise HTTPException (
164153 status_code = status .HTTP_401_UNAUTHORIZED ,
165- detail = "Could not validate credentials" ,
154+ detail = "Could not validate credentials from frontend " ,
166155 headers = {"WWW-Authenticate" : "Bearer" },
167156 )
168157 return None
169158
170159
171- async def validate_password_token (token : Annotated [str , Depends (oauth2_scheme )]):
172- if auth_url :
173- headers = (
174- {}
175- if security_config .auth_type == "cookie"
176- else {"Authorization" : f"Bearer { token } " }
177- )
178- cookies = (
179- {security_config .cookie_key : token }
180- if security_config .auth_type == "cookie"
181- else {}
182- )
183- async with aiohttp .ClientSession (cookies = cookies ) as session :
184- async with session .get (
185- f"{ auth_url } /validate_token" ,
186- headers = headers ,
187- ) as response :
188- success = response .status == 200
189- validation_outcome = await response .json ()
190- if not (success and validation_outcome .get ("valid" )):
191- raise JWTError
192- return None
160+ def validate_session_against_visit (session_id : int , visit : str ):
161+ """
162+ Checks that the session ID is associated with the claimed visit.
163+ """
164+ with Session (engine ) as murfey_db :
165+ session_data = murfey_db .exec (
166+ select (MurfeySession ).where (MurfeySession .id == session_id )
167+ ).all ()
168+ if len (session_data ) != 1 :
169+ return False
170+ return visit == session_data [0 ].visit
193171
194172
195173async def validate_instrument_token (
196174 token : Annotated [str , Depends (instrument_oauth2_scheme )]
197175):
176+ """
177+ Used by the backend routers to check the incoming instrument server token.
178+ """
198179 try :
180+ # Validate using auth URL if provided
199181 if security_config .instrument_auth_url :
200182 async with aiohttp .ClientSession () as session :
201183 headers = (
@@ -212,33 +194,105 @@ async def validate_instrument_token(
212194 if not (success and validation_outcome .get ("valid" )):
213195 raise JWTError
214196 else :
215- if validators := entry_points ().select (
216- group = "murfey.auth.token_validation" ,
217- name = security_config .auth_type ,
218- ):
219- validators [0 ].load ()(token )
197+ # First, check if the token has expired
198+ decoded_data = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
199+ if expiry_time := decoded_data .get ("expiry_time" ):
200+ if expiry_time < time .time ():
201+ raise JWTError
202+ elif decoded_data .get ("session" ) is not None :
203+ # Check that the decoded session corresponds to the visit
204+ if not validate_session_against_visit (
205+ decoded_data ["session" ], decoded_data ["visit" ]
206+ ):
207+ raise JWTError
220208 else :
221209 raise JWTError
222210 except JWTError :
223211 raise HTTPException (
224212 status_code = status .HTTP_401_UNAUTHORIZED ,
225- detail = "Could not validate credentials" ,
213+ detail = "Could not validate credentials from instrument " ,
226214 headers = {"WWW-Authenticate" : "Bearer" },
227215 )
228216 return None
229217
230218
219+ """
220+ =======================================================================================
221+ VALIDATING SESSION IDS
222+ =======================================================================================
223+
224+ Annotated ints are defined here that trigger validation of the session IDs in incoming
225+ requests, verifying that the session is allowed to access the particular visit.
226+
227+ The 'MurfeySessionID...' types are imported and used in the type hints of the endpoint
228+ functions in the other FastAPI routers, depending on whether requests from the frontend
229+ or the instrument are expected.
230+ """
231+
232+
233+ async def validate_visit (visit_name : str , token : str , instrument_access : bool ) -> bool :
234+ """
235+ Validates the incoming token depending on whether it's from the instrument or the frontend
236+ """
237+
238+ # If this token is from the instrument server, validate this way
239+ if instrument_access :
240+ if security_config .instrument_auth_url :
241+ async with aiohttp .ClientSession () as session :
242+ headers = (
243+ {}
244+ if not security_config .instrument_auth_type
245+ else {"Authorization" : f"Bearer { token } " }
246+ )
247+ async with session .get (
248+ f"{ security_config .instrument_auth_url } /validate_visit_access/{ visit_name } " ,
249+ headers = headers ,
250+ ) as response :
251+ success = response .status == 200
252+ validation_outcome = await response .json ()
253+ if not (success and validation_outcome .get ("valid" )):
254+ return False
255+ # Otherwise, use this validation method
256+ else :
257+ if security_config .auth_url :
258+ headers = (
259+ {}
260+ if security_config .auth_type == "cookie"
261+ else {"Authorization" : f"Bearer { token } " }
262+ )
263+ cookies = (
264+ {security_config .cookie_key : token }
265+ if security_config .auth_type == "cookie"
266+ else {}
267+ )
268+ async with aiohttp .ClientSession (cookies = cookies ) as session :
269+ async with session .get (
270+ f"{ auth_url } /validate_visit_access/{ visit_name } " ,
271+ headers = headers ,
272+ ) as response :
273+ success = response .status == 200
274+ validation_outcome = await response .json ()
275+ if not (success and validation_outcome .get ("valid" )):
276+ return False
277+ return True
278+
279+
231280async def validate_session_access (
232- session_id : int , token : Annotated [str , Depends (oauth2_scheme )]
281+ session_id : int ,
282+ token : Annotated [str , Depends (oauth2_scheme )],
283+ instrument_access : bool ,
233284) -> int :
234- await validate_token (token )
285+ """
286+ Validates whether the request is authorised to access information about this session
287+ """
235288 with Session (engine ) as murfey_db :
236289 visit_name = (
237290 murfey_db .exec (select (MurfeySession ).where (MurfeySession .id == session_id ))
238291 .one ()
239292 .visit
240293 )
241- if not validate_visit (visit_name , token ):
294+ validated = await validate_visit (visit_name , token , instrument_access )
295+ if not validated :
242296 raise HTTPException (
243297 status_code = status .HTTP_401_UNAUTHORIZED ,
244298 detail = "You do not have access to this visit" ,
@@ -247,9 +301,53 @@ async def validate_session_access(
247301 return session_id
248302
249303
250- class Token (BaseModel ):
251- access_token : str
252- token_type : str
304+ def validate_session_access_wrapper (instrument_access : bool ):
305+ """
306+ Factory that returns an async wrapper arond 'validate_session_access' with the
307+ 'instrument_access' field preconfigured.
308+
309+ This is used to generate FastAPI-compatible dependencies for validating session
310+ access, based on the context of the request.
311+
312+ Unlike 'functools.partial', this approach preserves introspection compatibility
313+ required by FastAPI for dependency resolution and OpenAPI generation.
314+ """
315+
316+ async def _validate (session_id : int , token : Annotated [str , Depends (oauth2_scheme )]):
317+ return await validate_session_access (
318+ session_id , token , instrument_access = instrument_access
319+ )
320+
321+ return _validate
322+
323+
324+ # Set validation conditions for the session ID based on where the request is from
325+ MurfeySessionIDFrontend = Annotated [
326+ int , Depends (validate_session_access_wrapper (instrument_access = False ))
327+ ]
328+ MurfeySessionIDInstrument = Annotated [
329+ int , Depends (validate_session_access_wrapper (instrument_access = True ))
330+ ]
331+
332+
333+ """
334+ =======================================================================================
335+ API ENDPOINTS AND HELPER FUNCTIONS/CLASSES
336+ =======================================================================================
337+ """
338+
339+
340+ def verify_password (plain_password : str , hashed_password : str ) -> bool :
341+ return pwd_context .verify (plain_password , hashed_password )
342+
343+
344+ def validate_user (username : str , password : str ) -> bool :
345+ try :
346+ with Session (engine ) as murfey_db :
347+ user = murfey_db .exec (select (User ).where (User .username == username )).one ()
348+ except Exception :
349+ return False
350+ return verify_password (password , user .hashed_password )
253351
254352
255353def create_access_token (data : dict , token : str = "" ) -> str :
@@ -274,11 +372,9 @@ def create_access_token(data: dict, token: str = "") -> str:
274372 return encoded_jwt
275373
276374
277- MurfeySessionID = Annotated [int , Depends (validate_session_access )]
278-
279- """
280- API ENDPOINTS
281- """
375+ class Token (BaseModel ):
376+ access_token : str
377+ token_type : str
282378
283379
284380@router .post ("/token" )
@@ -313,7 +409,7 @@ async def generate_token(
313409
314410
315411@router .get ("/sessions/{session_id}/token" )
316- async def mint_session_token (session_id : MurfeySessionID , db = murfey_db ):
412+ async def mint_session_token (session_id : MurfeySessionIDFrontend , db = murfey_db ):
317413 visit = (
318414 db .exec (select (MurfeySession ).where (MurfeySession .id == session_id )).one ().visit
319415 )
0 commit comments