|
1 | | -import os |
2 | | - |
3 | 1 | import typer |
4 | 2 | import uvicorn |
5 | | -from fastapi import FastAPI, Request |
6 | | -from fastapi.responses import HTMLResponse |
7 | | -from fastapi.templating import Jinja2Templates |
8 | 3 | from utils import get_token |
9 | 4 |
|
10 | 5 | app = typer.Typer() |
11 | | -api = FastAPI( |
12 | | - title="Dashboard API", |
13 | | - description="API for the dashboard application", |
14 | | - version="1.0.0", |
15 | | -) |
16 | | - |
17 | | -# Set up templates directory |
18 | | -templates = Jinja2Templates( |
19 | | - directory=os.path.join(os.path.dirname(__file__), "templates") |
20 | | -) |
21 | | - |
22 | | -# Training metrics storage |
23 | | -training_metrics = { |
24 | | - "iteration": 0, |
25 | | - "loss": 0.0, |
26 | | - "perplexity": 0.0, |
27 | | - "accuracy": 0.0, |
28 | | - "learning_rate": 0.0, |
29 | | - "gradient_norm": 0.0, |
30 | | - "bits_memorized": 0.0, |
31 | | - "bits_per_second": 0.0, |
32 | | - "gpu_utilization": [0.0] * 1024, # For 1024 GPUs |
33 | | - "skills": { |
34 | | - "Translation": 20, |
35 | | - "Summarization": 15, |
36 | | - "Reasoning": 10, |
37 | | - "Coding": 5, |
38 | | - "Comprehension": 25, |
39 | | - }, |
40 | | -} |
41 | | - |
42 | | - |
43 | | -@api.get("/", response_class=HTMLResponse) |
44 | | -async def root(request: Request): |
45 | | - """Serve the dashboard HTML""" |
46 | | - return templates.TemplateResponse("dashboard.html", {"request": request}) |
47 | | - |
48 | | - |
49 | | -@api.post("/update_metrics") |
50 | | -async def update_metrics(metrics: dict): |
51 | | - """Receive updated metrics from PyTorch training""" |
52 | | - # Update our stored metrics |
53 | | - training_metrics.update(metrics) |
54 | | - return {"status": "success"} |
55 | | - |
56 | | - |
57 | | -@api.get("/metrics") |
58 | | -async def get_metrics(): |
59 | | - """Return current training metrics""" |
60 | | - return training_metrics |
61 | 6 |
|
62 | 7 |
|
63 | 8 | @app.command() |
64 | 9 | def run(): |
65 | 10 | """Start the FastAPI server""" |
66 | 11 | uvicorn.run( |
67 | | - "dashboard.cli:api", |
| 12 | + "dashboard.api:api", |
68 | 13 | host=get_token("DASHBOARD_HOST"), |
69 | 14 | port=int(get_token("DASHBOARD_PORT")), |
70 | 15 | reload=True, |
|
0 commit comments