Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 38 additions & 31 deletions src/ground_station_v2/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import FastAPI, Header, WebSocket, WebSocketException, WebSocketDisconnect, HTTPException
from fastapi import FastAPI, Query, WebSocket, WebSocketException, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from pathlib import Path
from dataclasses import dataclass
Expand Down Expand Up @@ -38,11 +39,11 @@ class ClientState:
def get_live_clients():
return {client_id: state.websocket for client_id, state in connected_clients.items() if state.mode == "live"}

async def get_client_state(x_client_id: str = Header(alias="X-Client-ID")) -> ClientState:
async def get_client_state(client_id: str) -> ClientState:
"""FastAPI dependency to validate and retrieve client state"""
if x_client_id not in connected_clients:
if client_id not in connected_clients:
raise HTTPException(status_code=401, detail="Client not connected")
return connected_clients[x_client_id]
return connected_clients[client_id]

async def start_client_replay(client_id: str, state: ClientState, replay_path: str, speed: float) -> ClientReplayState:
"""Setup and start a client replay session"""
Expand Down Expand Up @@ -71,28 +72,34 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)

# all the requests below expect the X-Client-ID header to be set with the client_id from the request above
# all the requests below expect the client_id query parameter to be set with the client_id from the request above
# for now it's not checked whether it was generated by the get_client_id endpoint or not, this will be changed

# TODO: include the header id check in middleware
# TODO: include the client_id check in middleware

@app.post("/replay_play")
async def replay_play(
replay_path: str,
speed: float = 1.0,
x_client_id: str = Header(alias="X-Client-ID")
client_id: str = Query(alias="client_id")
):
state = await get_client_state(x_client_id)
state = await get_client_state(client_id)

try:
if state.replay:
await state.replay.cleanup()

state.mode = "replay"
state.replay = await start_client_replay(x_client_id, state, replay_path, speed)
state.replay = await start_client_replay(client_id, state, replay_path, speed)

logger.info(f"Replay started for client {x_client_id}: {replay_path} at {speed}x")
logger.info(f"Replay started for client {client_id}: {replay_path} at {speed}x")
return {"status": "ok", "replay_path": replay_path, "speed": speed}
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
Expand All @@ -102,69 +109,69 @@ async def replay_play(


@app.post("/resume_live")
async def resume_live(x_client_id: str = Header(alias="X-Client-ID")):
state = await get_client_state(x_client_id)
async def resume_live(client_id: str = Query(alias="client_id")):
state = await get_client_state(client_id)

if state.replay:
await state.replay.cleanup()

state.mode = "live"
state.replay = None

logger.info(f"Client {x_client_id} resumed live mode")
logger.info(f"Client {client_id} resumed live mode")
return {"status": "ok"}


@app.post("/replay_pause")
async def replay_pause(
paused: bool = True,
speed: float = 1.0,
x_client_id: str = Header(alias="X-Client-ID")
client_id: str = Query(alias="client_id")
):
state = await get_client_state(x_client_id)
state = await get_client_state(client_id)
if not state.replay or not state.replay.instance.is_playing():
raise HTTPException(status_code=400, detail="No replay is currently active")

if paused:
state.replay.instance.pause()
logger.info(f"Replay paused for client {x_client_id}")
logger.info(f"Replay paused for client {client_id}")
else:
state.replay.instance.resume(speed)
logger.info(f"Replay resumed for client {x_client_id} at {speed}x")
logger.info(f"Replay resumed for client {client_id} at {speed}x")

return {"status": "ok", "paused": paused}


@app.post("/replay_stop")
async def replay_stop(x_client_id: str = Header(alias="X-Client-ID")):
state = await get_client_state(x_client_id)
async def replay_stop(client_id: str = Query(alias="client_id")):
state = await get_client_state(client_id)

if state.replay and state.replay.instance.is_playing():
state.replay.instance.stop()
await state.replay.cleanup()
logger.info(f"Replay stopped for client {x_client_id}")
logger.info(f"Replay stopped for client {client_id}")

return {"status": "ok"}


@app.post("/record_start")
async def record_start(x_client_id: str = Header(alias="X-Client-ID")):
async def record_start(client_id: str = Query(alias="client_id")):
recorder.start()
logger.info(f"Record start for client {x_client_id}")
logger.info(f"Record start for client {client_id}")
return {"status": "ok"}


@app.post("/record_stop")
async def record_stop(x_client_id: str = Header(alias="X-Client-ID")):
async def record_stop(client_id: str = Query(alias="client_id")):
recorder.stop()
logger.info(f"Record stop for client {x_client_id}")
logger.info(f"Record stop for client {client_id}")
return {"status": "ok"}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, x_client_id: str = Header(alias="X-Client-ID")):
if x_client_id in connected_clients:
logger.warning(f"Client already connected: {x_client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str = Query(alias="client_id")):
if client_id in connected_clients:
logger.warning(f"Client already connected: {client_id}")
raise WebSocketException(code=1008, reason="Client already connected")

await websocket.accept()
Expand All @@ -173,17 +180,17 @@ async def websocket_endpoint(websocket: WebSocket, x_client_id: str = Header(ali
websocket=websocket,
mode="live"
)
connected_clients[x_client_id] = state
logger.info(f"Client connected: {x_client_id}")
connected_clients[client_id] = state
logger.info(f"Client connected: {client_id}")

try:
while True:
await websocket.receive_text()
except WebSocketDisconnect:
if state.replay:
await state.replay.cleanup()
connected_clients.pop(x_client_id, None)
logger.info(f"Client disconnected: {x_client_id}")
connected_clients.pop(client_id, None)
logger.info(f"Client disconnected: {client_id}")


def run_server(host: str = "0.0.0.0", port: int = 8000, from_recording: Path | None = None):
Expand Down
Loading