Skip to content

Commit 0adc0e1

Browse files
committed
chore: api routers
1 parent 8da02c3 commit 0adc0e1

File tree

4 files changed

+126
-0
lines changed

4 files changed

+126
-0
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""API routers package."""
2+
3+
from fastapi import APIRouter
4+
5+
router = APIRouter()
6+
7+
from . import auth as auth_router # noqa: E402,F401
8+
from . import backtest as backtest_router # noqa: E402,F401
9+
from . import assets as assets_router # noqa: E402,F401
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Assets router to expose available symbols / sample data."""
2+
from fastapi import APIRouter
3+
from quant_research_starter.data.sample_loader import SampleDataLoader
4+
5+
router = APIRouter(prefix="/api/assets", tags=["assets"])
6+
7+
8+
@router.get("/")
9+
async def list_assets():
10+
loader = SampleDataLoader()
11+
df = loader.load_sample_prices()
12+
symbols = []
13+
for sym in df.columns:
14+
symbols.append({"symbol": sym, "price": float(df[sym].iloc[-1])})
15+
return symbols
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""Authentication routes: register and token endpoints."""
2+
from fastapi import APIRouter, Depends, HTTPException, status
3+
from sqlalchemy.ext.asyncio import AsyncSession
4+
from fastapi.security import OAuth2PasswordRequestForm
5+
6+
from .. import schemas, db, models, auth
7+
8+
router = APIRouter(prefix="/api/auth", tags=["auth"])
9+
10+
11+
@router.post("/register", response_model=schemas.UserRead)
12+
async def register_user(user_in: schemas.UserCreate, session: AsyncSession = Depends(db.get_session)):
13+
q = await session.execute(models.User.__table__.select().where(models.User.username == user_in.username))
14+
if q.first():
15+
raise HTTPException(status_code=400, detail="Username already registered")
16+
hashed = auth.get_password_hash(user_in.password)
17+
user = models.User(username=user_in.username, hashed_password=hashed)
18+
session.add(user)
19+
await session.commit()
20+
await session.refresh(user)
21+
return schemas.UserRead(id=user.id, username=user.username, is_active=user.is_active, role=user.role)
22+
23+
24+
@router.post("/token", response_model=schemas.Token)
25+
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends(), session: AsyncSession = Depends(db.get_session)):
26+
q = await session.execute(models.User.__table__.select().where(models.User.username == form_data.username))
27+
row = q.first()
28+
if not row:
29+
raise HTTPException(status_code=400, detail="Incorrect username or password")
30+
user = row[0]
31+
if not auth.verify_password(form_data.password, user.hashed_password):
32+
raise HTTPException(status_code=400, detail="Incorrect username or password")
33+
34+
access_token = auth.create_access_token(data={"sub": user.username})
35+
return {"access_token": access_token, "token_type": "bearer"}
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Backtest endpoints: enqueue backtest jobs and fetch results."""
2+
from __future__ import annotations
3+
import uuid
4+
import os
5+
from fastapi import APIRouter, Depends, BackgroundTasks, UploadFile, File, HTTPException, WebSocket
6+
from sqlalchemy.ext.asyncio import AsyncSession
7+
from typing import Optional
8+
9+
from .. import schemas, db, models, auth
10+
from ..tasks.celery_app import celery_app
11+
from ..utils.ws_manager import manager
12+
13+
router = APIRouter(prefix="/api/backtest", tags=["backtest"])
14+
15+
16+
@router.post("/", response_model=schemas.BacktestStatus)
17+
async def submit_backtest(req: schemas.BacktestRequest, current_user=Depends(auth.require_active_user), session: AsyncSession = Depends(db.get_session)):
18+
# Create job
19+
job_id = uuid.uuid4().hex
20+
job = models.BacktestJob(id=job_id, user_id=current_user.id, status="queued", params=req.dict())
21+
session.add(job)
22+
await session.commit()
23+
24+
# Enqueue celery task
25+
celery_app.send_task("quant_research_starter.api.tasks.tasks.run_backtest", args=[job_id, req.dict()])
26+
27+
return {"job_id": job_id, "status": "queued"}
28+
29+
30+
@router.get("/{job_id}/results")
31+
async def get_results(job_id: str, current_user=Depends(auth.require_active_user), session: AsyncSession = Depends(db.get_session)):
32+
q = await session.execute(models.BacktestJob.__table__.select().where(models.BacktestJob.id == job_id))
33+
row = q.first()
34+
if not row:
35+
raise HTTPException(status_code=404, detail="Job not found")
36+
job = row[0]
37+
if job.user_id != current_user.id and current_user.role != "admin":
38+
raise HTTPException(status_code=403, detail="Not authorized to view this job")
39+
40+
if job.result_path and os.path.exists(job.result_path):
41+
import json
42+
43+
with open(job.result_path, "r") as f:
44+
return json.load(f)
45+
return {"status": job.status}
46+
47+
48+
@router.websocket("/ws/{job_id}")
49+
async def websocket_backtest(websocket: WebSocket, job_id: str):
50+
"""WebSocket endpoint that registers the client and relays messages from Redis pub/sub.
51+
52+
The Redis listener broadcasts messages to the ConnectionManager which then sends
53+
them to connected WebSocket clients.
54+
"""
55+
await manager.connect(job_id, websocket)
56+
try:
57+
while True:
58+
# keep the connection alive; client may send ping messages
59+
msg = await websocket.receive_text()
60+
# ignore incoming messages; server pushes updates
61+
await websocket.send_text("ok")
62+
except Exception:
63+
manager.disconnect(job_id, websocket)
64+
try:
65+
await websocket.close()
66+
except Exception:
67+
pass

0 commit comments

Comments
 (0)