diff --git a/src/ground_station_v2/api.py b/src/ground_station_v2/api.py index ffede36..54186ca 100644 --- a/src/ground_station_v2/api.py +++ b/src/ground_station_v2/api.py @@ -1,4 +1,5 @@ -from fastapi import FastAPI, Header, WebSocket, WebSocketException, WebSocketDisconnect, HTTPException +from fastapi import FastAPI, Query, WebSocket, WebSocketException, WebSocketDisconnect, HTTPException +from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager from pathlib import Path from dataclasses import dataclass @@ -38,11 +39,11 @@ class ClientState: def get_live_clients(): return {client_id: state.websocket for client_id, state in connected_clients.items() if state.mode == "live"} -async def get_client_state(x_client_id: str = Header(alias="X-Client-ID")) -> ClientState: +async def get_client_state(client_id: str) -> ClientState: """FastAPI dependency to validate and retrieve client state""" - if x_client_id not in connected_clients: + if client_id not in connected_clients: raise HTTPException(status_code=401, detail="Client not connected") - return connected_clients[x_client_id] + return connected_clients[client_id] async def start_client_replay(client_id: str, state: ClientState, replay_path: str, speed: float) -> ClientReplayState: """Setup and start a client replay session""" @@ -71,28 +72,34 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], +) -# all the requests below expect the X-Client-ID header to be set with the client_id from the request above +# all the requests below expect the client_id query parameter to be set with the client_id from the request above # for now it's not checked whether it was generated by the get_client_id endpoint or not, this will be changed -# TODO: include the header id check in middleware +# TODO: include the client_id check in middleware @app.post("/replay_play") async def replay_play( replay_path: str, speed: float = 1.0, - x_client_id: str = Header(alias="X-Client-ID") + client_id: str = Query(alias="client_id") ): - state = await get_client_state(x_client_id) + state = await get_client_state(client_id) try: if state.replay: await state.replay.cleanup() state.mode = "replay" - state.replay = await start_client_replay(x_client_id, state, replay_path, speed) + state.replay = await start_client_replay(client_id, state, replay_path, speed) - logger.info(f"Replay started for client {x_client_id}: {replay_path} at {speed}x") + logger.info(f"Replay started for client {client_id}: {replay_path} at {speed}x") return {"status": "ok", "replay_path": replay_path, "speed": speed} except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -102,8 +109,8 @@ async def replay_play( @app.post("/resume_live") -async def resume_live(x_client_id: str = Header(alias="X-Client-ID")): - state = await get_client_state(x_client_id) +async def resume_live(client_id: str = Query(alias="client_id")): + state = await get_client_state(client_id) if state.replay: await state.replay.cleanup() @@ -111,7 +118,7 @@ async def resume_live(x_client_id: str = Header(alias="X-Client-ID")): state.mode = "live" state.replay = None - logger.info(f"Client {x_client_id} resumed live mode") + logger.info(f"Client {client_id} resumed live mode") return {"status": "ok"} @@ -119,52 +126,52 @@ async def resume_live(x_client_id: str = Header(alias="X-Client-ID")): async def replay_pause( paused: bool = True, speed: float = 1.0, - x_client_id: str = Header(alias="X-Client-ID") + client_id: str = Query(alias="client_id") ): - state = await get_client_state(x_client_id) + state = await get_client_state(client_id) if not state.replay or not state.replay.instance.is_playing(): raise HTTPException(status_code=400, detail="No replay is currently active") if paused: state.replay.instance.pause() - logger.info(f"Replay paused for client {x_client_id}") + logger.info(f"Replay paused for client {client_id}") else: state.replay.instance.resume(speed) - logger.info(f"Replay resumed for client {x_client_id} at {speed}x") + logger.info(f"Replay resumed for client {client_id} at {speed}x") return {"status": "ok", "paused": paused} @app.post("/replay_stop") -async def replay_stop(x_client_id: str = Header(alias="X-Client-ID")): - state = await get_client_state(x_client_id) +async def replay_stop(client_id: str = Query(alias="client_id")): + state = await get_client_state(client_id) if state.replay and state.replay.instance.is_playing(): state.replay.instance.stop() await state.replay.cleanup() - logger.info(f"Replay stopped for client {x_client_id}") + logger.info(f"Replay stopped for client {client_id}") return {"status": "ok"} @app.post("/record_start") -async def record_start(x_client_id: str = Header(alias="X-Client-ID")): +async def record_start(client_id: str = Query(alias="client_id")): recorder.start() - logger.info(f"Record start for client {x_client_id}") + logger.info(f"Record start for client {client_id}") return {"status": "ok"} @app.post("/record_stop") -async def record_stop(x_client_id: str = Header(alias="X-Client-ID")): +async def record_stop(client_id: str = Query(alias="client_id")): recorder.stop() - logger.info(f"Record stop for client {x_client_id}") + logger.info(f"Record stop for client {client_id}") return {"status": "ok"} @app.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket, x_client_id: str = Header(alias="X-Client-ID")): - if x_client_id in connected_clients: - logger.warning(f"Client already connected: {x_client_id}") +async def websocket_endpoint(websocket: WebSocket, client_id: str = Query(alias="client_id")): + if client_id in connected_clients: + logger.warning(f"Client already connected: {client_id}") raise WebSocketException(code=1008, reason="Client already connected") await websocket.accept() @@ -173,8 +180,8 @@ async def websocket_endpoint(websocket: WebSocket, x_client_id: str = Header(ali websocket=websocket, mode="live" ) - connected_clients[x_client_id] = state - logger.info(f"Client connected: {x_client_id}") + connected_clients[client_id] = state + logger.info(f"Client connected: {client_id}") try: while True: @@ -182,8 +189,8 @@ async def websocket_endpoint(websocket: WebSocket, x_client_id: str = Header(ali except WebSocketDisconnect: if state.replay: await state.replay.cleanup() - connected_clients.pop(x_client_id, None) - logger.info(f"Client disconnected: {x_client_id}") + connected_clients.pop(client_id, None) + logger.info(f"Client disconnected: {client_id}") def run_server(host: str = "0.0.0.0", port: int = 8000, from_recording: Path | None = None):