Skip to content

Commit 51bd421

Browse files
authored
Validate user instrument access for the frontend API endpoints (#608)
1 parent 041056e commit 51bd421

File tree

3 files changed

+51
-23
lines changed

3 files changed

+51
-23
lines changed

src/murfey/server/api/auth.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -219,15 +219,7 @@ def get_visit_name(session_id: int) -> str:
219219
)
220220

221221

222-
async def validate_frontend_session_access(
223-
session_id: int,
224-
token: Annotated[str, Depends(oauth2_scheme)],
225-
) -> int:
226-
"""
227-
Validates whether a frontend request can access information about this session
228-
"""
229-
visit_name = get_visit_name(session_id)
230-
222+
async def submit_to_auth_endpoint(url_subpath: str, token: str) -> None:
231223
if auth_url:
232224
headers = (
233225
{}
@@ -241,7 +233,7 @@ async def validate_frontend_session_access(
241233
)
242234
async with aiohttp.ClientSession(cookies=cookies) as session:
243235
async with session.get(
244-
f"{auth_url}/validate_visit_access/{visit_name}",
236+
f"{auth_url}/{url_subpath}",
245237
headers=headers,
246238
) as response:
247239
success = response.status == 200
@@ -253,10 +245,21 @@ async def validate_frontend_session_access(
253245
detail="You do not have access to this visit",
254246
headers={"WWW-Authenticate": "Bearer"},
255247
)
248+
249+
250+
async def validate_frontend_session_access(
251+
session_id: int,
252+
token: Annotated[str, Depends(oauth2_scheme)],
253+
) -> int:
254+
"""
255+
Validates whether a frontend request can access information about this session
256+
"""
257+
visit_name = get_visit_name(session_id)
258+
await submit_to_auth_endpoint(f"validate_visit_access/{visit_name}", token)
256259
return session_id
257260

258261

259-
async def validate_instrument_session_access(
262+
async def validate_instrument_server_session_access(
260263
session_id: int,
261264
token: Annotated[str, Depends(instrument_oauth2_scheme)],
262265
) -> int:
@@ -288,9 +291,26 @@ async def validate_instrument_session_access(
288291
return session_id
289292

290293

294+
async def validate_user_instrument_access(
295+
instrument_name: str,
296+
token: Annotated[str, Depends(oauth2_scheme)],
297+
) -> str:
298+
"""
299+
Validates whether a frontend request can access information about this instrument
300+
"""
301+
await submit_to_auth_endpoint(
302+
f"validate_instrument_access/{instrument_name}", token
303+
)
304+
return instrument_name
305+
306+
291307
# Set validation conditions for the session ID based on where the request is from
292308
MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)]
293-
MurfeySessionIDInstrument = Annotated[int, Depends(validate_instrument_session_access)]
309+
MurfeySessionIDInstrument = Annotated[
310+
int, Depends(validate_instrument_server_session_access)
311+
]
312+
313+
MurfeyInstrumentNameFrontend = Annotated[str, Depends(validate_user_instrument_access)]
294314

295315

296316
"""

src/murfey/server/api/instrument.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlmodel import select
1313
from werkzeug.utils import secure_filename
1414

15+
from murfey.server.api.auth import MurfeyInstrumentNameFrontend as MurfeyInstrumentName
1516
from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID
1617
from murfey.server.api.auth import (
1718
create_access_token,
@@ -42,7 +43,7 @@
4243
"/instruments/{instrument_name}/sessions/{session_id}/activate_instrument_server"
4344
)
4445
async def activate_instrument_server_for_session(
45-
instrument_name: str,
46+
instrument_name: MurfeyInstrumentName,
4647
session_id: int,
4748
token_in: Annotated[str, Depends(oauth2_scheme)],
4849
db=murfey_db,
@@ -80,7 +81,9 @@ async def activate_instrument_server_for_session(
8081

8182

8283
@router.get("/instruments/{instrument_name}/sessions/{session_id}/active")
83-
async def check_if_session_is_active(instrument_name: str, session_id: int):
84+
async def check_if_session_is_active(
85+
instrument_name: MurfeyInstrumentName, session_id: int
86+
):
8487
if instrument_server_tokens.get(session_id) is None:
8588
return {"active": False}
8689
async with lock:
@@ -214,7 +217,7 @@ async def pass_proc_params_to_instrument_server(
214217

215218

216219
@router.get("/instruments/{instrument_name}/instrument_server")
217-
async def check_instrument_server(instrument_name: str):
220+
async def check_instrument_server(instrument_name: MurfeyInstrumentName):
218221
data = None
219222
machine_config = get_machine_config(instrument_name=instrument_name)[
220223
instrument_name
@@ -232,7 +235,7 @@ async def check_instrument_server(instrument_name: str):
232235
"/instruments/{instrument_name}/sessions/{session_id}/possible_gain_references"
233236
)
234237
async def get_possible_gain_references(
235-
instrument_name: str, session_id: MurfeySessionID
238+
instrument_name: MurfeyInstrumentName, session_id: MurfeySessionID
236239
) -> List[File]:
237240
data = []
238241
machine_config = get_machine_config(instrument_name=instrument_name)[
@@ -491,7 +494,7 @@ class RSyncerInfo(BaseModel):
491494

492495
@router.get("/instruments/{instrument_name}/sessions/{session_id}/rsyncer_info")
493496
async def get_rsyncer_info(
494-
instrument_name: str, session_id: MurfeySessionID, db=murfey_db
497+
instrument_name: MurfeyInstrumentName, session_id: MurfeySessionID, db=murfey_db
495498
) -> List[RSyncerInfo]:
496499
rsyncer_list = []
497500
analyser_list = []

src/murfey/server/api/session_info.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import murfey.server.api.websocket as ws
1313
from murfey.server import _transport_object
1414
from murfey.server.api import templates
15+
from murfey.server.api.auth import MurfeyInstrumentNameFrontend as MurfeyInstrumentName
1516
from murfey.server.api.auth import MurfeySessionIDFrontend as MurfeySessionID
1617
from murfey.server.api.auth import validate_token
1718
from murfey.server.api.shared import get_foil_hole as _get_foil_hole
@@ -74,20 +75,24 @@ def connections_check():
7475

7576

7677
@router.get("/instruments/{instrument_name}/machine")
77-
def machine_info_by_instrument(instrument_name: str) -> Optional[MachineConfig]:
78+
def machine_info_by_instrument(
79+
instrument_name: MurfeyInstrumentName,
80+
) -> Optional[MachineConfig]:
7881
return get_machine_config_for_instrument(instrument_name)
7982

8083

8184
@router.get("/instruments/{instrument_name}/visits_raw", response_model=List[Visit])
82-
def get_current_visits(instrument_name: str, db=ispyb_db):
85+
def get_current_visits(instrument_name: MurfeyInstrumentName, db=ispyb_db):
8386
logger.debug(
8487
f"Received request to look up ongoing visits for {sanitise(instrument_name)}"
8588
)
8689
return get_all_ongoing_visits(instrument_name, db)
8790

8891

8992
@router.get("/instruments/{instrument_name}/visits/")
90-
def all_visit_info(instrument_name: str, request: Request, db=ispyb_db):
93+
def all_visit_info(
94+
instrument_name: MurfeyInstrumentName, request: Request, db=ispyb_db
95+
):
9196
visits = get_all_ongoing_visits(instrument_name, db)
9297

9398
if visits:
@@ -159,7 +164,7 @@ class VisitEndTime(BaseModel):
159164

160165
@router.post("/instruments/{instrument_name}/visits/{visit}/session/{name}")
161166
def create_session(
162-
instrument_name: str,
167+
instrument_name: MurfeyInstrumentName,
163168
visit: str,
164169
name: str,
165170
visit_end_time: VisitEndTime,
@@ -195,7 +200,7 @@ def remove_session(session_id: MurfeySessionID, db=murfey_db):
195200

196201
@router.get("/instruments/{instrument_name}/visits/{visit_name}/sessions")
197202
def get_sessions_with_visit(
198-
instrument_name: str, visit_name: str, db=murfey_db
203+
instrument_name: MurfeyInstrumentName, visit_name: str, db=murfey_db
199204
) -> List[Session]:
200205
sessions = db.exec(
201206
select(Session)
@@ -207,7 +212,7 @@ def get_sessions_with_visit(
207212

208213
@router.get("/instruments/{instrument_name}/sessions")
209214
async def get_sessions_by_instrument_name(
210-
instrument_name: str, db=murfey_db
215+
instrument_name: MurfeyInstrumentName, db=murfey_db
211216
) -> List[Session]:
212217
sessions = db.exec(
213218
select(Session).where(Session.instrument_name == instrument_name)

0 commit comments

Comments
 (0)