Skip to content

Commit 27a681a

Browse files
committed
Major refactor of 'murfey.server.api.auth', rearranging functions by purpose and splitting the authentication of instrument and frontend tokens into separate functions; created new annotated ints for type hinting in endpoints receiving requests from frontend and instrument server; updates the other server routers to use the newly created annotated ints
1 parent 606a957 commit 27a681a

File tree

8 files changed

+206
-105
lines changed

8 files changed

+206
-105
lines changed

src/murfey/server/api/auth.py

Lines changed: 194 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import aiohttp
1010
import requests
11-
from backports.entry_points_selectable import entry_points
1211
from fastapi import APIRouter, Depends, HTTPException, Request, status
1312
from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm
1413
from jose import JWTError, jwt
@@ -78,20 +77,6 @@ async def __call__(self, request: Request):
7877

7978
instrument_server_tokens: Dict[float, dict] = {}
8079

81-
82-
"""
83-
HELPER FUNCTIONS
84-
"""
85-
86-
87-
def verify_password(plain_password: str, hashed_password: str) -> bool:
88-
return pwd_context.verify(plain_password, hashed_password)
89-
90-
91-
def hash_password(password: str) -> str:
92-
return pwd_context.hash(password)
93-
94-
9580
# Set up database engine
9681
try:
9782
_url = url(security_config)
@@ -100,22 +85,19 @@ def hash_password(password: str) -> str:
10085
engine = None
10186

10287

103-
def validate_user(username: str, password: str) -> bool:
104-
try:
105-
with Session(engine) as murfey_db:
106-
user = murfey_db.exec(select(User).where(User.username == username)).one()
107-
except Exception:
108-
return False
109-
return verify_password(password, user.hashed_password)
88+
def hash_password(password: str) -> str:
89+
return pwd_context.hash(password)
11090

11191

112-
def validate_visit(visit_name: str, token: str) -> bool:
113-
if validators := entry_points().select(
114-
group="murfey.auth.session_validation",
115-
name=security_config.auth_type,
116-
):
117-
return validators[0].load()(visit_name, token)
118-
return True
92+
"""
93+
=======================================================================================
94+
TOKEN VALIDATION FUNCTIONS
95+
=======================================================================================
96+
97+
Functions and helpers used to validate incoming requests from both the client and
98+
the frontend. 'validate_token()' and 'validate_instrument_token()' are imported
99+
int the other FastAPI modules and attached as dependencies to the routers.
100+
"""
119101

120102

121103
def check_user(username: str) -> bool:
@@ -127,75 +109,75 @@ def check_user(username: str) -> bool:
127109
return username in [u.username for u in users]
128110

129111

130-
def validate_instrument_server_session_token(session_id: int, visit: str):
131-
with Session(engine) as murfey_db:
132-
session_data = murfey_db.exec(
133-
select(MurfeySession).where(MurfeySession.id == session_id)
134-
).all()
135-
if len(session_data) != 1:
136-
return False
137-
return visit == session_data[0].visit
138-
139-
140112
async def validate_token(token: Annotated[str, Depends(oauth2_scheme)]):
113+
"""
114+
Used by the backend routers to validate requests coming in from frontend.
115+
"""
141116
try:
142-
try:
143-
if security_config.auth_type == "password":
144-
await validate_password_token(token)
145-
except JWTError:
146-
await validate_instrument_token(token)
147-
decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
148-
# first check if the token has expired
149-
if expiry_time := decoded_data.get("expiry_time"):
150-
if expiry_time < time.time():
151-
raise JWTError
152-
if decoded_data.get("user"):
153-
if not check_user(decoded_data["user"]):
154-
raise JWTError
155-
elif decoded_data.get("session") is not None:
156-
if not validate_instrument_server_session_token(
157-
decoded_data["session"], decoded_data["visit"]
158-
):
117+
# Validate using auth URL if provided; will error if invalid
118+
if auth_url:
119+
headers = (
120+
{}
121+
if security_config.auth_type == "cookie"
122+
else {"Authorization": f"Bearer {token}"}
123+
)
124+
cookies = (
125+
{security_config.cookie_key: token}
126+
if security_config.auth_type == "cookie"
127+
else {}
128+
)
129+
async with aiohttp.ClientSession(cookies=cookies) as session:
130+
async with session.get(
131+
f"{auth_url}/validate_token",
132+
headers=headers,
133+
) as response:
134+
success = response.status == 200
135+
validation_outcome = await response.json()
136+
if not (success and validation_outcome.get("valid")):
159137
raise JWTError
138+
# Auth URL MUST be provided if authenticating using cookies
160139
else:
161-
raise JWTError
140+
if security_config.auth_type == "cookie":
141+
raise JWTError
142+
143+
# Validate using password
144+
if security_config.auth_type == "password":
145+
decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
146+
# Frontend validation
147+
if decoded_data.get("user"):
148+
if not check_user(decoded_data["user"]):
149+
raise JWTError
150+
162151
except JWTError:
163152
raise HTTPException(
164153
status_code=status.HTTP_401_UNAUTHORIZED,
165-
detail="Could not validate credentials",
154+
detail="Could not validate credentials from frontend",
166155
headers={"WWW-Authenticate": "Bearer"},
167156
)
168157
return None
169158

170159

171-
async def validate_password_token(token: Annotated[str, Depends(oauth2_scheme)]):
172-
if auth_url:
173-
headers = (
174-
{}
175-
if security_config.auth_type == "cookie"
176-
else {"Authorization": f"Bearer {token}"}
177-
)
178-
cookies = (
179-
{security_config.cookie_key: token}
180-
if security_config.auth_type == "cookie"
181-
else {}
182-
)
183-
async with aiohttp.ClientSession(cookies=cookies) as session:
184-
async with session.get(
185-
f"{auth_url}/validate_token",
186-
headers=headers,
187-
) as response:
188-
success = response.status == 200
189-
validation_outcome = await response.json()
190-
if not (success and validation_outcome.get("valid")):
191-
raise JWTError
192-
return None
160+
def validate_session_against_visit(session_id: int, visit: str):
161+
"""
162+
Checks that the session ID is associated with the claimed visit.
163+
"""
164+
with Session(engine) as murfey_db:
165+
session_data = murfey_db.exec(
166+
select(MurfeySession).where(MurfeySession.id == session_id)
167+
).all()
168+
if len(session_data) != 1:
169+
return False
170+
return visit == session_data[0].visit
193171

194172

195173
async def validate_instrument_token(
196174
token: Annotated[str, Depends(instrument_oauth2_scheme)]
197175
):
176+
"""
177+
Used by the backend routers to check the incoming instrument server token.
178+
"""
198179
try:
180+
# Validate using auth URL if provided
199181
if security_config.instrument_auth_url:
200182
async with aiohttp.ClientSession() as session:
201183
headers = (
@@ -212,33 +194,105 @@ async def validate_instrument_token(
212194
if not (success and validation_outcome.get("valid")):
213195
raise JWTError
214196
else:
215-
if validators := entry_points().select(
216-
group="murfey.auth.token_validation",
217-
name=security_config.auth_type,
218-
):
219-
validators[0].load()(token)
197+
# First, check if the token has expired
198+
decoded_data = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
199+
if expiry_time := decoded_data.get("expiry_time"):
200+
if expiry_time < time.time():
201+
raise JWTError
202+
elif decoded_data.get("session") is not None:
203+
# Check that the decoded session corresponds to the visit
204+
if not validate_session_against_visit(
205+
decoded_data["session"], decoded_data["visit"]
206+
):
207+
raise JWTError
220208
else:
221209
raise JWTError
222210
except JWTError:
223211
raise HTTPException(
224212
status_code=status.HTTP_401_UNAUTHORIZED,
225-
detail="Could not validate credentials",
213+
detail="Could not validate credentials from instrument",
226214
headers={"WWW-Authenticate": "Bearer"},
227215
)
228216
return None
229217

230218

219+
"""
220+
=======================================================================================
221+
VALIDATING SESSION IDS
222+
=======================================================================================
223+
224+
Annotated ints are defined here that trigger validation of the session IDs in incoming
225+
requests, verifying that the session is allowed to access the particular visit.
226+
227+
The 'MurfeySessionID...' types are imported and used in the type hints of the endpoint
228+
functions in the other FastAPI routers, depending on whether requests from the frontend
229+
or the instrument are expected.
230+
"""
231+
232+
233+
async def validate_visit(visit_name: str, token: str, instrument_access: bool) -> bool:
234+
"""
235+
Validates the incoming token depending on whether it's from the instrument or the frontend
236+
"""
237+
238+
# If this token is from the instrument server, validate this way
239+
if instrument_access:
240+
if security_config.instrument_auth_url:
241+
async with aiohttp.ClientSession() as session:
242+
headers = (
243+
{}
244+
if not security_config.instrument_auth_type
245+
else {"Authorization": f"Bearer {token}"}
246+
)
247+
async with session.get(
248+
f"{security_config.instrument_auth_url}/validate_visit_access/{visit_name}",
249+
headers=headers,
250+
) as response:
251+
success = response.status == 200
252+
validation_outcome = await response.json()
253+
if not (success and validation_outcome.get("valid")):
254+
return False
255+
# Otherwise, use this validation method
256+
else:
257+
if security_config.auth_url:
258+
headers = (
259+
{}
260+
if security_config.auth_type == "cookie"
261+
else {"Authorization": f"Bearer {token}"}
262+
)
263+
cookies = (
264+
{security_config.cookie_key: token}
265+
if security_config.auth_type == "cookie"
266+
else {}
267+
)
268+
async with aiohttp.ClientSession(cookies=cookies) as session:
269+
async with session.get(
270+
f"{auth_url}/validate_visit_access/{visit_name}",
271+
headers=headers,
272+
) as response:
273+
success = response.status == 200
274+
validation_outcome = await response.json()
275+
if not (success and validation_outcome.get("valid")):
276+
return False
277+
return True
278+
279+
231280
async def validate_session_access(
232-
session_id: int, token: Annotated[str, Depends(oauth2_scheme)]
281+
session_id: int,
282+
token: Annotated[str, Depends(oauth2_scheme)],
283+
instrument_access: bool,
233284
) -> int:
234-
await validate_token(token)
285+
"""
286+
Validates whether the request is authorised to access information about this session
287+
"""
235288
with Session(engine) as murfey_db:
236289
visit_name = (
237290
murfey_db.exec(select(MurfeySession).where(MurfeySession.id == session_id))
238291
.one()
239292
.visit
240293
)
241-
if not validate_visit(visit_name, token):
294+
validated = await validate_visit(visit_name, token, instrument_access)
295+
if not validated:
242296
raise HTTPException(
243297
status_code=status.HTTP_401_UNAUTHORIZED,
244298
detail="You do not have access to this visit",
@@ -247,9 +301,53 @@ async def validate_session_access(
247301
return session_id
248302

249303

250-
class Token(BaseModel):
251-
access_token: str
252-
token_type: str
304+
def validate_session_access_wrapper(instrument_access: bool):
305+
"""
306+
Factory that returns an async wrapper arond 'validate_session_access' with the
307+
'instrument_access' field preconfigured.
308+
309+
This is used to generate FastAPI-compatible dependencies for validating session
310+
access, based on the context of the request.
311+
312+
Unlike 'functools.partial', this approach preserves introspection compatibility
313+
required by FastAPI for dependency resolution and OpenAPI generation.
314+
"""
315+
316+
async def _validate(session_id: int, token: Annotated[str, Depends(oauth2_scheme)]):
317+
return await validate_session_access(
318+
session_id, token, instrument_access=instrument_access
319+
)
320+
321+
return _validate
322+
323+
324+
# Set validation conditions for the session ID based on where the request is from
325+
MurfeySessionIDFrontend = Annotated[
326+
int, Depends(validate_session_access_wrapper(instrument_access=False))
327+
]
328+
MurfeySessionIDInstrument = Annotated[
329+
int, Depends(validate_session_access_wrapper(instrument_access=True))
330+
]
331+
332+
333+
"""
334+
=======================================================================================
335+
API ENDPOINTS AND HELPER FUNCTIONS/CLASSES
336+
=======================================================================================
337+
"""
338+
339+
340+
def verify_password(plain_password: str, hashed_password: str) -> bool:
341+
return pwd_context.verify(plain_password, hashed_password)
342+
343+
344+
def validate_user(username: str, password: str) -> bool:
345+
try:
346+
with Session(engine) as murfey_db:
347+
user = murfey_db.exec(select(User).where(User.username == username)).one()
348+
except Exception:
349+
return False
350+
return verify_password(password, user.hashed_password)
253351

254352

255353
def create_access_token(data: dict, token: str = "") -> str:
@@ -274,11 +372,9 @@ def create_access_token(data: dict, token: str = "") -> str:
274372
return encoded_jwt
275373

276374

277-
MurfeySessionID = Annotated[int, Depends(validate_session_access)]
278-
279-
"""
280-
API ENDPOINTS
281-
"""
375+
class Token(BaseModel):
376+
access_token: str
377+
token_type: str
282378

283379

284380
@router.post("/token")
@@ -313,7 +409,7 @@ async def generate_token(
313409

314410

315411
@router.get("/sessions/{session_id}/token")
316-
async def mint_session_token(session_id: MurfeySessionID, db=murfey_db):
412+
async def mint_session_token(session_id: MurfeySessionIDFrontend, db=murfey_db):
317413
visit = (
318414
db.exec(select(MurfeySession).where(MurfeySession.id == session_id)).one().visit
319415
)

src/murfey/server/api/file_manip.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from sqlmodel import select
99
from werkzeug.utils import secure_filename
1010

11-
from murfey.server.api.auth import MurfeySessionID, validate_instrument_token
11+
from murfey.server.api.auth import MurfeySessionIDInstrument as MurfeySessionID
12+
from murfey.server.api.auth import validate_instrument_token
1213
from murfey.server.gain import Camera, prepare_eer_gain, prepare_gain
1314
from murfey.server.murfey_db import murfey_db
1415
from murfey.util import sanitise, secure_path

0 commit comments

Comments
 (0)