diff --git a/skyrl-train/skyrl_train/utils/tracking.py b/skyrl-train/skyrl_train/utils/tracking.py
index 5a6f45d8e..7734f853f 100644
--- a/skyrl-train/skyrl_train/utils/tracking.py
+++ b/skyrl-train/skyrl_train/utils/tracking.py
@@ -27,7 +27,7 @@
 
 # TODO(tgriggs): Test all backends.
 class Tracking:
-    supported_backends = ["wandb", "mlflow", "swanlab", "tensorboard", "console"]
+    supported_backends = ["wandb", "mlflow", "swanlab", "tensorboard", "console", "tviz"]
 
     def __init__(self, project_name, experiment_name, backends: Union[str, List[str]] = "console", config=None):
         if isinstance(backends, str):
@@ -73,6 +73,11 @@ def __init__(self, project_name, experiment_name, backends: Union[str, List[str]
             self.console_logger = ConsoleLogger()
             self.logger["console"] = self.console_logger
 
+        if "tviz" in backends:
+            from skyrl_train.utils.tviz_tracker import TvizTracker
+
+            self.logger["tviz"] = TvizTracker(experiment_name=experiment_name, config=config)
+
     def log(self, data, step, commit=False):
         for logger_name, logger_instance in self.logger.items():
             if logger_name == "wandb":
@@ -94,6 +99,8 @@ def __del__(self):
                 self.logger["tensorboard"].finish()
             if "mlflow" in self.logger:
                 self.logger["mlflow"].finish()
+            if "tviz" in self.logger:
+                self.logger["tviz"].finish()
         except Exception as e:
             logger.warning(f"Attempted to finish tracking but got error {e}")
 
diff --git a/skyrl-train/skyrl_train/utils/tviz_tracker.py b/skyrl-train/skyrl_train/utils/tviz_tracker.py
new file mode 100644
index 000000000..7cac248f3
--- /dev/null
+++ b/skyrl-train/skyrl_train/utils/tviz_tracker.py
@@ -0,0 +1,269 @@
+"""TViz tracking backend for local visualization of SkyRL training runs.
+
+Run metadata and per-step metrics are written to the ``training_runs`` and
+``training_steps`` tables of the tinker SQLite database.
+"""
+
+import json
+import math
+import numbers
+import os
+import sqlite3
+import uuid
+from pathlib import Path
+from typing import Any
+
+
+# Aliases for logged metric names, mapped to ``training_steps`` columns.
+# Entry order is the priority order: when several aliases of one column are
+# logged in the same call, the alias listed first supplies the value.
+METRIC_MAPPING = {
+    "reward/avg_raw_reward": "reward_mean",
+    "reward/avg_reward": "reward_mean",
+    "reward/mean_positive_reward": "reward_mean",
+    "reward_mean": "reward_mean",
+    "policy/loss": "loss",
+    "critic/loss": "loss",
+    "loss": "loss",
+    "policy/kl_divergence": "kl_divergence",
+    "kl_divergence": "kl_divergence",
+    "kl": "kl_divergence",
+    "policy/entropy": "entropy",
+    "entropy": "entropy",
+    "trainer/learning_rate": "learning_rate",
+    "learning_rate": "learning_rate",
+    "lr": "learning_rate",
+    "generate/avg_num_tokens": "ac_tokens_per_turn",
+    "generate/avg_tokens_non_zero_rewards": "ac_tokens_per_turn",
+    "ac_tokens_per_turn": "ac_tokens_per_turn",
+    "timing/total": "time_total",
+    "timing/generate": "sampling_time_mean",
+    "time_total": "time_total",
+}
+
+# Metric columns of ``training_steps``; any other metric is stored in ``extras``.
+STEP_COLUMNS = (
+    "reward_mean",
+    "reward_std",
+    "loss",
+    "kl_divergence",
+    "entropy",
+    "learning_rate",
+    "ac_tokens_per_turn",
+    "ob_tokens_per_turn",
+    "total_ac_tokens",
+    "total_turns",
+    "sampling_time_mean",
+    "time_total",
+    "frac_mixed",
+    "frac_all_good",
+    "frac_all_bad",
+)
+
+# Schema for training visualization tables (added to tinker.db)
+SCHEMA = """
+CREATE TABLE IF NOT EXISTS training_runs (
+    id TEXT PRIMARY KEY,
+    name TEXT NOT NULL,
+    type TEXT DEFAULT 'rl',
+    modality TEXT DEFAULT 'text',
+    config TEXT,
+    started_at TEXT DEFAULT CURRENT_TIMESTAMP,
+    ended_at TEXT
+);
+
+CREATE TABLE IF NOT EXISTS training_steps (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    run_id TEXT NOT NULL,
+    step INTEGER NOT NULL,
+    reward_mean REAL,
+    reward_std REAL,
+    loss REAL,
+    kl_divergence REAL,
+    entropy REAL,
+    learning_rate REAL,
+    ac_tokens_per_turn REAL,
+    ob_tokens_per_turn REAL,
+    total_ac_tokens INTEGER,
+    total_turns INTEGER,
+    sampling_time_mean REAL,
+    time_total REAL,
+    frac_mixed REAL,
+    frac_all_good REAL,
+    frac_all_bad REAL,
+    extras TEXT,
+    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
+    FOREIGN KEY (run_id) REFERENCES training_runs(id),
+    UNIQUE(run_id, step)
+);
+
+CREATE INDEX IF NOT EXISTS idx_training_steps_run_id ON training_steps(run_id);
+"""
+
+
+def get_tinker_db_path() -> Path:
+    """Return the tinker database path.
+
+    A ``TX_DATABASE_URL`` of the form ``sqlite:///<path>`` selects ``<path>``.
+    When the variable is unset, empty, or uses another scheme, the default
+    ``skyrl-tx/tx/tinker/tinker.db``, resolved relative to this file, is used.
+    """
+    prefix = "sqlite:///"
+    db_url = os.environ.get("TX_DATABASE_URL", "")
+    if db_url.startswith(prefix):
+        # Strip the leading scheme only; str.replace would also rewrite any
+        # later occurrence inside the path itself.
+        return Path(db_url.removeprefix(prefix))
+    return Path(__file__).parent.parent.parent.parent / "skyrl-tx" / "tx" / "tinker" / "tinker.db"
+
+
+def _to_number(value: Any) -> int | float | None:
+    """Return ``value`` as a plain, finite Python number, or ``None`` to skip it."""
+    if isinstance(value, bool):
+        return value
+    if isinstance(value, numbers.Integral):
+        return int(value)
+    if isinstance(value, numbers.Real):
+        # NaN and infinities are skipped: json.dumps would write them as
+        # NaN/Infinity tokens, which strict JSON parsers reject.
+        number = float(value)
+        return number if math.isfinite(number) else None
+    return None
+
+
+def _load_extras(text: str | None) -> dict[str, Any]:
+    """Parse a stored ``extras`` JSON object; unreadable content yields ``{}``."""
+    if not text:
+        return {}
+    try:
+        loaded = json.loads(text)
+    except ValueError:
+        return {}
+    return loaded if isinstance(loaded, dict) else {}
+
+
+class TvizTracker:
+    """TViz tracking backend that writes metrics to the tinker database."""
+
+    def __init__(self, experiment_name: str, config: dict[str, Any] | None = None):
+        """Pick the database file and a run id; the database is opened on first log()."""
+        self.db_path = get_tinker_db_path()
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+
+        self.run_id = str(uuid.uuid4())[:8]
+        self.run_name = experiment_name
+        self._config = config
+        self._conn: sqlite3.Connection | None = None
+        self._initialized = False
+
+    def _get_conn(self) -> sqlite3.Connection:
+        """Return the shared connection, opening it on first use."""
+        if self._conn is None:
+            self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
+        return self._conn
+
+    def _serialize_config(self) -> str | None:
+        """Return the run config as JSON text, or ``None`` when there is none."""
+        if self._config is None:
+            return None
+        try:
+            from omegaconf import OmegaConf
+
+            if OmegaConf.is_config(self._config):
+                self._config = OmegaConf.to_container(self._config, resolve=True)
+        except ImportError:
+            pass  # without omegaconf installed the config is used as given
+        # default=str keeps a value json cannot encode (a Path, an enum, ...)
+        # from aborting the training run on the first log() call.
+        return json.dumps(self._config, default=str)
+
+    def _ensure_initialized(self) -> None:
+        """Create the tables and register this run; a no-op after the first call."""
+        if self._initialized:
+            return
+
+        conn = self._get_conn()
+        conn.executescript(SCHEMA)
+        # started_at is set explicitly so the insert does not rely on the
+        # column default, which only the SCHEMA above is known to declare.
+        conn.execute(
+            "INSERT INTO training_runs (id, name, type, modality, config, started_at) "
+            "VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP)",
+            (self.run_id, self.run_name, "rl", "text", self._serialize_config()),
+        )
+        conn.commit()
+        self._initialized = True
+        print(f"TViz dashboard: http://localhost:3003/training-run/{self.run_id}")
+
+    def _map_metrics(self, data: dict[str, Any]) -> dict[str, Any]:
+        """Translate logged metrics to tviz names, keeping numeric values only.
+
+        Aliases are resolved in ``METRIC_MAPPING`` order, so the result does not
+        depend on the key order of ``data``. A metric whose name is not an alias
+        is kept under that name and overrides a value derived from an alias.
+        """
+        mapped: dict[str, Any] = {}
+        for key, tviz_key in METRIC_MAPPING.items():
+            if key not in data or tviz_key in mapped:
+                continue
+            number = _to_number(data[key])
+            if number is not None:
+                mapped[tviz_key] = number
+        for key, value in data.items():
+            if key in METRIC_MAPPING:
+                continue
+            number = _to_number(value)
+            if number is not None:
+                mapped[key] = number
+        return mapped
+
+    def log(self, data: dict[str, Any], step: int) -> None:
+        """Record the metrics in ``data`` for ``step``.
+
+        Logging a step again merges into the stored row: metrics present in the
+        new call are updated and everything else is kept.
+        """
+        self._ensure_initialized()
+        metrics = self._map_metrics(data)
+        values = [metrics.pop(col, None) for col in STEP_COLUMNS]
+        extras = metrics  # whatever is left has no dedicated column
+
+        conn = self._get_conn()
+        with conn:  # one transaction: commit on success, roll back on error
+            row = conn.execute(
+                "SELECT id, extras FROM training_steps WHERE run_id = ? AND step = ? ORDER BY id DESC LIMIT 1",
+                (self.run_id, step),
+            ).fetchone()
+            # The SQL below interpolates STEP_COLUMNS only, never logged names.
+            if row is None:
+                columns = ", ".join(STEP_COLUMNS)
+                marks = ", ".join("?" for _ in STEP_COLUMNS)
+                conn.execute(
+                    f"INSERT INTO training_steps (run_id, step, {columns}, extras, created_at) "
+                    f"VALUES (?, ?, {marks}, ?, CURRENT_TIMESTAMP)",
+                    (self.run_id, step, *values, json.dumps(extras) if extras else None),
+                )
+            else:
+                row_id, stored_extras = row
+                extras = {**_load_extras(stored_extras), **extras}
+                # COALESCE(?, col) keeps the stored value when this call has none.
+                assignments = ", ".join(f"{col} = COALESCE(?, {col})" for col in STEP_COLUMNS)
+                conn.execute(
+                    f"UPDATE training_steps SET {assignments}, extras = ? WHERE id = ?",
+                    (*values, json.dumps(extras) if extras else None, row_id),
+                )
+
+    def finish(self) -> None:
+        """Stamp the run as ended and close the connection; a no-op when none is open."""
+        conn, self._conn = self._conn, None
+        if conn is None:
+            return
+        try:
+            conn.execute(
+                "UPDATE training_runs SET ended_at = CURRENT_TIMESTAMP WHERE id = ?",
+                (self.run_id,),
+            )
+            conn.commit()
+        finally:
+            # Close even when the update fails so the handle is not leaked.
+            conn.close()
diff --git a/skyrl-tx/tx/loaders/chat.py b/skyrl-tx/tx/loaders/chat.py
index e2ada93fb..bf55c98cf 100644
--- a/skyrl-tx/tx/loaders/chat.py
+++ b/skyrl-tx/tx/loaders/chat.py
@@ -17,4 +17,4 @@ def chat(tokenizer: PreTrainedTokenizer, dataset: Dataset, batch_size: int) -> L
             "text": batch["input_ids"][:, :-1],
             "attention_mask": batch["attention_mask"][:, :-1],
             "target": batch["input_ids"][:, 1:],
-        }, {"shape": batch["input_ids"].shape, "tokens": batch["attention_mask"].sum()}
+        }, {"shape": batch["input_ids"].shape, "tokens": batch["attention_mask"].sum()}
diff --git a/skyrl-tx/tx/tinker/db_models.py b/skyrl-tx/tx/tinker/db_models.py
index ad3575d40..b8cb23d7f 100644
--- a/skyrl-tx/tx/tinker/db_models.py
+++ b/skyrl-tx/tx/tinker/db_models.py
@@ -113,3 +113,44 @@ class SamplingSessionDB(SQLModel, table=True):
     base_model: str | None = None
     model_path: str | None = None
     created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
+
+
+class TrainingRunDB(SQLModel, table=True):
+    """Training run for visualization dashboard."""
+
+    __tablename__ = "training_runs"
+
+    id: str = Field(primary_key=True)
+    name: str
+    type: str = Field(default="rl")
+    modality: str = Field(default="text")
+    config: dict | None = Field(default=None, sa_type=JSON)
+    started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
+    ended_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True))
+
+
+class TrainingStepDB(SQLModel, table=True):
+    """Training step metrics for visualization dashboard."""
+
+    __tablename__ = "training_steps"
+
+    id: int | None = Field(default=None, primary_key=True, sa_column_kwargs={"autoincrement": True})
+    run_id: str = Field(foreign_key="training_runs.id", index=True)
+    step: int
+    reward_mean: float | None = None
+    reward_std: float | None = None
+    loss: float | None = None
+    kl_divergence: float | None = None
+    entropy: float | None = None
+    learning_rate: float | None = None
+    ac_tokens_per_turn: float | None = None
+    ob_tokens_per_turn: float | None = None
+    total_ac_tokens: int | None = None
+    total_turns: int | None = None
+    sampling_time_mean: float | None = None
+    time_total: float | None = None
+    frac_mixed: float | None = None
+    frac_all_good: float | None = None
+    frac_all_bad: float | None = None
+    extras: dict | None = Field(default=None, sa_type=JSON)
+    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
diff --git a/skyrl-tx/viz/.gitignore b/skyrl-tx/viz/.gitignore
new file mode 100644
index 000000000..b71013141
--- /dev/null
+++ b/skyrl-tx/viz/.gitignore
@@ -0,0 +1,42 @@
+# dependencies
+node_modules/
+.pnp
+.pnp.js
+
+# build
+.next/
+out/
+dist/
+build/
+
+# misc
+.DS_Store
+*.pem
+
+# debug
+npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env*.local +.env + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + +# lock files (use bun.lock) +package-lock.json +yarn.lock + +# database (local development) +*.db + +# python +__pycache__/ +*.pyc +.venv/ diff --git a/skyrl-tx/viz/README.md b/skyrl-tx/viz/README.md new file mode 100644 index 000000000..8b10fefc2 --- /dev/null +++ b/skyrl-tx/viz/README.md @@ -0,0 +1,23 @@ +# SkyRL Dashboard + +Local visualization dashboard for SkyRL training runs. + +## Setup + +```bash +cd skyrl-tx/viz +bun install +bun dev +``` + +Dashboard runs at http://localhost:3003 + +## Architecture + +- Next.js App Router frontend +- Reads from `tinker.db` (SQLite) +- Tables: `training_runs`, `training_steps`, `sessions`, `models`, `futures`, `checkpoints` + +## Environment Variables + +- `TINKER_DB_PATH` - Path to tinker database (default: `./tx/tinker/tinker.db`) diff --git a/skyrl-tx/viz/app/(dashboard)/layout.tsx b/skyrl-tx/viz/app/(dashboard)/layout.tsx new file mode 100644 index 000000000..1b3c8e69c --- /dev/null +++ b/skyrl-tx/viz/app/(dashboard)/layout.tsx @@ -0,0 +1,11 @@ +import { Sidebar } from "@/components/ui/sidebar"; + +export default function DashboardLayout({ children }: { children: React.ReactNode }) { + return ( +
+ + {/* pt-14 on mobile for fixed header, md:pt-0 for desktop */} +
{children}
+
+ ); +} diff --git a/skyrl-tx/viz/app/(dashboard)/page.tsx b/skyrl-tx/viz/app/(dashboard)/page.tsx new file mode 100644 index 000000000..cb46d5d30 --- /dev/null +++ b/skyrl-tx/viz/app/(dashboard)/page.tsx @@ -0,0 +1,5 @@ +import Dashboard from "@/components/Dashboard"; + +export default function HomePage() { + return ; +} diff --git a/skyrl-tx/viz/app/(dashboard)/training-run/[runId]/page.tsx b/skyrl-tx/viz/app/(dashboard)/training-run/[runId]/page.tsx new file mode 100644 index 000000000..4d7366084 --- /dev/null +++ b/skyrl-tx/viz/app/(dashboard)/training-run/[runId]/page.tsx @@ -0,0 +1,228 @@ +"use client"; + +import { useState, useMemo } from "react"; +import { useParams } from "next/navigation"; +import { Badge } from "@/components/ui/badge"; +import { Card } from "@/components/ui/card"; +import { useRunData } from "@/hooks/useRunData"; +import { + XAxis, + YAxis, + Tooltip, + ResponsiveContainer, + CartesianGrid, + AreaChart, + Area, + LineChart, + Line, +} from "recharts"; + +export default function TrainingRunPage() { + const params = useParams(); + const runId = params.runId as string; + const decodedRunId = decodeURIComponent(runId); + + const { run, steps, isLoading } = useRunData(decodedRunId); + + const config = run?.config ? JSON.parse(run.config) : {}; + const latest = steps[steps.length - 1]; + const isEnded = !!run?.ended_at; + + const chartData = useMemo(() => { + return steps.map((s) => ({ + step: s.step, + reward_mean: s.reward_mean, + loss: s.loss, + kl_divergence: s.kl_divergence, + entropy: s.entropy, + learning_rate: s.learning_rate, + ac_tokens_per_turn: s.ac_tokens_per_turn, + sampling_time_mean: s.sampling_time_mean, + time_total: s.time_total, + })); + }, [steps]); + + if (isLoading) { + return ( +
+
+
+
+
+
+ ); + } + + return ( +
+ {/* Header */} +
+
+

{run?.name || decodedRunId}

+ + {isEnded ? "Ended" : "Live"} + + {run?.type && {run.type}} +
+ + {latest && ( +
+
Step {latest.step}
+
+ reward: {latest.reward_mean?.toFixed(3) ?? "N/A"} +
+
+ )} +
+ + {/* Config */} + {Object.keys(config).length > 0 && ( +
+ {Object.entries(config) + .filter(([k]) => { + const essentialKeys = [ + "model_name", "model", "env_type", "batch_size", "group_size", + "learning_rate", "lr", "max_steps", "max_tokens", "task", "hf_repo" + ]; + return essentialKeys.includes(k); + }) + .map(([k, v]) => ( +
+ {k}: + {String(v)} +
+ ))} +
+ )} + + {/* Charts */} +
+ {steps.some(s => s.reward_mean !== null) && ( + +
Reward Mean
+ + + + + + + + + +
+ )} + + {steps.some(s => s.loss !== null) && ( + +
Loss
+ + + + + + + + + +
+ )} + + {steps.some(s => s.kl_divergence !== null) && ( + +
KL Divergence
+ + + + + + + + + +
+ )} + + {steps.some(s => s.entropy !== null) && ( + +
Entropy
+ + + + + + + + + +
+ )} + + {steps.some(s => s.learning_rate !== null) && ( + +
Learning Rate
+ + + + + v.toExponential(1)} /> + + + + +
+ )} + + {steps.some(s => s.ac_tokens_per_turn !== null) && ( + +
Tokens per Turn
+ + + + + + + + + +
+ )} + + {steps.some(s => s.sampling_time_mean !== null) && ( + +
Sampling Time (s)
+ + + + + + + + + +
+ )} + + {steps.some(s => s.time_total !== null) && ( + +
Step Time (s)
+ + + + + + + + + +
+ )} +
+ + {steps.length === 0 && ( + + No training data available yet + + )} +
+ ); +} diff --git a/skyrl-tx/viz/app/(dashboard)/training-runs/page.tsx b/skyrl-tx/viz/app/(dashboard)/training-runs/page.tsx new file mode 100644 index 000000000..39c4643de --- /dev/null +++ b/skyrl-tx/viz/app/(dashboard)/training-runs/page.tsx @@ -0,0 +1,82 @@ +"use client"; + +import { useState, useEffect } from "react"; +import Link from "next/link"; +import { Badge } from "@/components/ui/badge"; +import { + Table, + TableHeader, + TableBody, + TableRow, + TableHead, + TableCell, +} from "@/components/ui/table"; +import type { Run } from "@/lib/db"; + +export default function TrainingRunsPage() { + const [runs, setRuns] = useState([]); + const [loading, setLoading] = useState(true); + + useEffect(() => { + fetch("/api/runs", { cache: "no-store" }) + .then((r) => r.json()) + .then(setRuns) + .catch(() => {}) + .finally(() => setLoading(false)); + }, []); + + return ( +
+

Training runs

+ + + + + ID + NAME + TYPE + MODALITY + STARTED + + + + {loading ? ( + + + Loading... + + + ) : runs.length === 0 ? ( + + + No training runs yet. + + + ) : ( + runs.map((run) => ( + + + + {run.id} + + + {run.name} + + {run.type} + + + + {run.modality} + + + + {run.started_at ? new Date(run.started_at).toLocaleString() : "N/A"} + + + )) + )} + +
+
+ ); +} diff --git a/skyrl-tx/viz/app/api/checkpoints/route.ts b/skyrl-tx/viz/app/api/checkpoints/route.ts new file mode 100644 index 000000000..4bf6b60da --- /dev/null +++ b/skyrl-tx/viz/app/api/checkpoints/route.ts @@ -0,0 +1,98 @@ +import { getTinkerDb, TinkerCheckpoint } from "@/lib/db"; +import { NextResponse } from "next/server"; + +export const dynamic = "force-dynamic"; +export const revalidate = 0; + +export async function GET(request: Request) { + try { + const { searchParams } = new URL(request.url); + const modelId = searchParams.get("model_id"); + + const db = getTinkerDb(); + + if (!db) { + return NextResponse.json({ + available: false, + checkpoints: [], + }); + } + + let query = ` + SELECT + c.model_id, + c.checkpoint_id, + c.checkpoint_type, + c.status, + c.created_at, + c.completed_at, + c.error_message, + m.base_model + FROM checkpoints c + LEFT JOIN models m ON c.model_id = m.model_id + `; + + const params: string[] = []; + + if (modelId) { + query += " WHERE c.model_id = ?"; + params.push(modelId); + } + + query += " ORDER BY c.created_at DESC LIMIT 100"; + + const checkpoints = db.prepare(query).all(...params) as (TinkerCheckpoint & { + base_model: string | null; + })[]; + + // Get checkpoint stats + const statsQuery = modelId + ? ` + SELECT + checkpoint_type, + status, + COUNT(*) as count + FROM checkpoints + WHERE model_id = ? + GROUP BY checkpoint_type, status + ` + : ` + SELECT + checkpoint_type, + status, + COUNT(*) as count + FROM checkpoints + GROUP BY checkpoint_type, status + `; + + const statsRaw = modelId + ? db.prepare(statsQuery).all(modelId) as { checkpoint_type: string; status: string; count: number }[] + : db.prepare(statsQuery).all() as { checkpoint_type: string; status: string; count: number }[]; + + const stats = { + training: { pending: 0, completed: 0, failed: 0 }, + sampler: { pending: 0, completed: 0, failed: 0 }, + }; + + for (const row of statsRaw) { + const type = row.checkpoint_type === "TRAINING" ? 
"training" : "sampler"; + const status = row.status as "pending" | "completed" | "failed"; + stats[type][status] = row.count; + } + + db.close(); + + return NextResponse.json({ + available: true, + checkpoints, + stats, + }); + } catch (error) { + console.error("Checkpoints API error:", error); + return NextResponse.json({ + available: false, + checkpoints: [], + error: String(error), + }); + } +} diff --git a/skyrl-tx/viz/app/api/models/route.ts b/skyrl-tx/viz/app/api/models/route.ts new file mode 100644 index 000000000..0dc381b60 --- /dev/null +++ b/skyrl-tx/viz/app/api/models/route.ts @@ -0,0 +1,79 @@ +import { getTinkerDb, TinkerModel } from "@/lib/db"; +import { NextResponse } from "next/server"; + +export const dynamic = "force-dynamic"; +export const revalidate = 0; + +export async function GET() { + try { + const db = getTinkerDb(); + if (!db) { + return NextResponse.json({ available: false, models: [] }); + } + + const models = db.prepare(` + SELECT m.model_id, m.base_model, m.lora_config, m.status, m.session_id, m.created_at, + s.tags as session_tags, s.status as session_status + FROM models m + LEFT JOIN sessions s ON m.session_id = s.session_id + ORDER BY m.created_at DESC + `).all() as (TinkerModel & { session_tags: string | null; session_status: string | null })[]; + + const checkpointCounts = db.prepare(` + SELECT model_id, checkpoint_type, COUNT(*) as count, + SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed_count + FROM checkpoints GROUP BY model_id, checkpoint_type + `).all() as { model_id: string; checkpoint_type: string; count: number; completed_count: number }[]; + + const checkpointMap: Record = {}; + for (const row of checkpointCounts) { + if (!checkpointMap[row.model_id]) { + checkpointMap[row.model_id] = { training: 0, sampler: 0 }; + } + if (row.checkpoint_type === "TRAINING") { + checkpointMap[row.model_id].training = row.completed_count; + } else if (row.checkpoint_type === "SAMPLER") { + 
checkpointMap[row.model_id].sampler = row.completed_count; + } + } + + const requestCounts = db.prepare(` + SELECT model_id, COUNT(*) as total_requests, + SUM(CASE WHEN status = 'completed' THEN 1 ELSE 0 END) as completed_requests, + SUM(CASE WHEN request_type = 'FORWARD_BACKWARD' THEN 1 ELSE 0 END) as training_steps + FROM futures WHERE model_id IS NOT NULL GROUP BY model_id + `).all() as { model_id: string; total_requests: number; completed_requests: number; training_steps: number }[]; + + const requestMap: Record = {}; + for (const row of requestCounts) { + requestMap[row.model_id] = { + total: row.total_requests, + completed: row.completed_requests, + trainingSteps: row.training_steps, + }; + } + + db.close(); + + const enrichedModels = models.map((model) => { + let loraConfig = null; + try { loraConfig = JSON.parse(model.lora_config); } catch {} + + let sessionTags: string[] = []; + try { if (model.session_tags) sessionTags = JSON.parse(model.session_tags); } catch {} + + return { + ...model, + lora_config: loraConfig, + session_tags: sessionTags, + checkpoints: checkpointMap[model.model_id] || { training: 0, sampler: 0 }, + requests: requestMap[model.model_id] || { total: 0, completed: 0, trainingSteps: 0 }, + }; + }); + + return NextResponse.json({ available: true, models: enrichedModels }); + } catch (error) { + console.error("Models API error:", error); + return NextResponse.json({ available: false, models: [], error: String(error) }); + } +} diff --git a/skyrl-tx/viz/app/api/queue/route.ts b/skyrl-tx/viz/app/api/queue/route.ts new file mode 100644 index 000000000..c1e29cbb3 --- /dev/null +++ b/skyrl-tx/viz/app/api/queue/route.ts @@ -0,0 +1,84 @@ +import { getTinkerDb, TinkerFuture, RequestType, QueueStats } from "@/lib/db"; +import { NextResponse } from "next/server"; + +export const dynamic = "force-dynamic"; +export const revalidate = 0; + +const REQUEST_TYPES: RequestType[] = [ + "CREATE_MODEL", + "FORWARD_BACKWARD", + "FORWARD", + "OPTIM_STEP", + 
"SAVE_WEIGHTS", + "SAVE_WEIGHTS_FOR_SAMPLER", + "LOAD_WEIGHTS", + "SAMPLE", + "EXTERNAL", +]; + +export async function GET() { + try { + const db = getTinkerDb(); + if (!db) { + return NextResponse.json({ available: false, stats: null, recentRequests: [] }); + } + + const stats = db.prepare(` + SELECT status, COUNT(*) as count FROM futures GROUP BY status + `).all() as { status: string; count: number }[]; + + const statsMap: Record = {}; + for (const row of stats) { + statsMap[row.status] = row.count; + } + + const byTypeRaw = db.prepare(` + SELECT request_type, status, COUNT(*) as count + FROM futures GROUP BY request_type, status + `).all() as { request_type: RequestType; status: string; count: number }[]; + + const byType: QueueStats["byType"] = {} as QueueStats["byType"]; + for (const type of REQUEST_TYPES) { + byType[type] = { pending: 0, completed: 0, failed: 0 }; + } + for (const row of byTypeRaw) { + if (byType[row.request_type]) { + byType[row.request_type][row.status as "pending" | "completed" | "failed"] = row.count; + } + } + + const recentRequests = db.prepare(` + SELECT request_id, request_type, model_id, status, created_at, completed_at + FROM futures ORDER BY request_id DESC LIMIT 50 + `).all() as Omit[]; + + const latencyStats = db.prepare(` + SELECT + AVG((julianday(completed_at) - julianday(created_at)) * 86400) as avg_latency_seconds, + MIN((julianday(completed_at) - julianday(created_at)) * 86400) as min_latency_seconds, + MAX((julianday(completed_at) - julianday(created_at)) * 86400) as max_latency_seconds + FROM futures WHERE status = 'completed' AND completed_at IS NOT NULL + `).get() as { + avg_latency_seconds: number | null; + min_latency_seconds: number | null; + max_latency_seconds: number | null; + }; + + db.close(); + + return NextResponse.json({ + available: true, + stats: { + pending: statsMap["pending"] || 0, + completed: statsMap["completed"] || 0, + failed: statsMap["failed"] || 0, + byType, + } as QueueStats, + latency: 
latencyStats, + recentRequests, + }); + } catch (error) { + console.error("Queue API error:", error); + return NextResponse.json({ available: false, stats: null, recentRequests: [], error: String(error) }); + } +} diff --git a/skyrl-tx/viz/app/api/runs/[runId]/route.ts b/skyrl-tx/viz/app/api/runs/[runId]/route.ts new file mode 100644 index 000000000..3a549e798 --- /dev/null +++ b/skyrl-tx/viz/app/api/runs/[runId]/route.ts @@ -0,0 +1,43 @@ +import { getDb, Run, Step } from "@/lib/db"; +import { NextResponse } from "next/server"; + +export async function GET( + request: Request, + { params }: { params: Promise<{ runId: string }> } +) { + try { + const { runId } = await params; + const db = getDb(); + + const run = db.prepare("SELECT * FROM training_runs WHERE id = ?").get(runId) as Run | undefined; + + if (!run) { + db.close(); + return NextResponse.json({ error: "Run not found" }, { status: 404 }); + } + + const steps = db + .prepare( + ` + SELECT + id, run_id, step, created_at as timestamp, + loss, reward_mean, reward_std, kl_divergence, entropy, learning_rate, + ac_tokens_per_turn, ob_tokens_per_turn, + total_ac_tokens, total_turns, + sampling_time_mean, time_total, + frac_mixed, frac_all_good, frac_all_bad, extras + FROM training_steps + WHERE run_id = ? 
+ ORDER BY step + ` + ) + .all(runId) as Step[]; + + db.close(); + + return NextResponse.json({ run, steps }); + } catch (error) { + console.error("Run API error:", error); + return NextResponse.json({ error: "Database error" }, { status: 500 }); + } +} diff --git a/skyrl-tx/viz/app/api/runs/route.ts b/skyrl-tx/viz/app/api/runs/route.ts new file mode 100644 index 000000000..8958d5d68 --- /dev/null +++ b/skyrl-tx/viz/app/api/runs/route.ts @@ -0,0 +1,17 @@ +import { getDb, Run } from "@/lib/db"; +import { NextResponse } from "next/server"; + +export const dynamic = "force-dynamic"; +export const revalidate = 0; + +export async function GET() { + try { + const db = getDb(); + const runs = db.prepare("SELECT * FROM training_runs ORDER BY started_at DESC").all() as Run[]; + db.close(); + return NextResponse.json(runs); + } catch (error) { + console.error("Runs API error:", error); + return NextResponse.json([]); + } +} diff --git a/skyrl-tx/viz/app/api/stats/route.ts b/skyrl-tx/viz/app/api/stats/route.ts new file mode 100644 index 000000000..8afdde9fd --- /dev/null +++ b/skyrl-tx/viz/app/api/stats/route.ts @@ -0,0 +1,110 @@ +import { getDb } from "@/lib/db"; +import { NextResponse } from "next/server"; + +export const dynamic = "force-dynamic"; +export const revalidate = 0; + +export async function GET() { + try { + const db = getDb(); + + const totalRuns = db.prepare("SELECT COUNT(*) as count FROM training_runs").get() as { count: number }; + const runningRuns = db.prepare("SELECT COUNT(*) as count FROM training_runs WHERE ended_at IS NULL").get() as { count: number }; + + const runs = db.prepare(` + SELECT id, name, config, started_at, ended_at + FROM training_runs + ORDER BY started_at DESC + `).all() as { id: string; name: string; config: string | null; started_at: string; ended_at: string | null }[]; + + const runsWithModel = runs.map(run => { + let model = "Unknown"; + if (run.config) { + try { + const config = JSON.parse(run.config); + model = config.model || 
config.model_name || config.base_model || "Unknown"; + if (model.includes("/")) { + model = model.split("/").pop() || model; + } + } catch {} + } + return { ...run, model }; + }); + + const perfStats = db.prepare(` + SELECT + AVG(ac_tokens_per_turn) as avg_tokens_per_turn, + SUM(total_ac_tokens) as total_action_tokens, + SUM(total_turns) as total_turns, + AVG(sampling_time_mean) as avg_sampling_time, + AVG(time_total) as avg_step_time + FROM training_steps + WHERE ac_tokens_per_turn IS NOT NULL + OR total_ac_tokens IS NOT NULL + `).get() as { + avg_tokens_per_turn: number | null; + total_action_tokens: number | null; + total_turns: number | null; + avg_sampling_time: number | null; + avg_step_time: number | null; + }; + + const perfOverTime = db.prepare(` + SELECT + date(created_at) as day, + AVG(ac_tokens_per_turn) as tokens_per_turn, + AVG(total_ac_tokens) as tokens_per_step, + COUNT(*) as num_steps + FROM training_steps + WHERE total_ac_tokens IS NOT NULL + GROUP BY date(created_at) + ORDER BY day ASC + LIMIT 30 + `).all() as { day: string; tokens_per_turn: number | null; tokens_per_step: number | null; num_steps: number }[]; + + const rewardOverTime = db.prepare(` + SELECT + date(created_at) as day, + AVG(reward_mean) as avg_reward, + MAX(reward_mean) as max_reward + FROM training_steps + WHERE reward_mean IS NOT NULL + GROUP BY date(created_at) + ORDER BY day ASC + LIMIT 30 + `).all() as { day: string; avg_reward: number | null; max_reward: number | null }[]; + + db.close(); + + return NextResponse.json({ + totalRuns: totalRuns.count, + runningRuns: runningRuns.count, + runs: runsWithModel.slice(0, 10), + perfStats: { + avgTokensPerTurn: perfStats.avg_tokens_per_turn, + totalActionTokens: perfStats.total_action_tokens, + totalTurns: perfStats.total_turns, + avgSamplingTime: perfStats.avg_sampling_time, + avgStepTime: perfStats.avg_step_time, + }, + perfOverTime, + rewardOverTime, + }); + } catch (error) { + console.error("Stats API error:", error); + return 
NextResponse.json({ + totalRuns: 0, + runningRuns: 0, + runs: [], + perfStats: { + avgTokensPerTurn: null, + totalActionTokens: null, + totalTurns: null, + avgSamplingTime: null, + avgStepTime: null, + }, + perfOverTime: [], + rewardOverTime: [], + }); + } +} diff --git a/skyrl-tx/viz/app/globals.css b/skyrl-tx/viz/app/globals.css new file mode 100644 index 000000000..3ac467c75 --- /dev/null +++ b/skyrl-tx/viz/app/globals.css @@ -0,0 +1,180 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +@import "leaflet/dist/leaflet.css"; + +@layer base { + :root { + /* Typography */ + --font-sans: ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"; + --font-mono: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; + + /* Base Colors & Typography Variables from User Snippet */ + --font-size-32: 32px; + --font-size-24: 24px; + --font-size-20: 20px; + --font-size-16: 16px; + --font-size-14: 14px; + --font-size-12: 12px; + --line-height-40: 40px; + --line-height-32: 32px; + --line-height-24: 24px; + --line-height-20: 20px; + --line-height-16: 16px; + + /* Color Palette */ + --blue-700: #0880cc; + --blue-600: #0a8ae6; + --blue-500: #0b99ff; + --blue-400: #28a5ff; + --red-500: #f45757; + --red-400: #ff7373; + --steel-700: #4d7699; + --steel-650: #5a83ac; + --steel-600: #54687f; + --steel-550: #6890bb; + --steel-500: #759bca; + --steel-400: #82a7d6; + --neutral-900: #1a1a1a; + --neutral-600: #666; + --neutral-200: #f4f5f7; + --neutral-100: #f6f6f6; + --neutral-50: #f6f7f9; + --neutral-000: #fff; + + /* Sizing & Spacing */ + --size-44: 44px; + --size-32: 32px; + --size-26: 26px; + --spacing-18: 18px; + --spacing-14: 14px; + --spacing-10: 10px; + --radius-10: 10px; + --radius-6: 6px; + --radius-4: 4px; + + /* Alpha Colors */ + --alpha-2: #00000005; + --alpha-6: #0000000f; + --alpha-10: #0000001a; + --alpha-12: #0000001f; + --alpha-16: #00000029; + + /* Button 
Gradients & Shadows */ + --btn-default-bg: linear-gradient(180deg, var(--neutral-000) 5.17%, var(--neutral-100) 23.1%, var(--neutral-200) 76.9%, var(--neutral-50) 94.83%); + --btn-primary-bg: linear-gradient(180deg, oklch(.67 .19 263), oklch(.59 .26 263) 62%, oklch(.64 .22 263)); + --btn-primary-border: oklch(.53 .25 265.05); + --btn-primary-shadow: inset 0 1px 0 oklch(1 0 0 / .2), inset 0 -1px 0 oklch(0 0 0 / .2), 0 1px 2px oklch(0 0 0 / .2); + --btn-secondary-bg: linear-gradient(180deg, var(--steel-400) 5.17%, var(--steel-500) 23.1%, var(--steel-500) 76.9%, var(--steel-400) 94.83%); + --btn-danger-bg: linear-gradient(180deg, var(--red-400) 5.17%, var(--red-500) 23.1%, var(--red-500) 76.9%, var(--red-400) 94.83%); + + --shadow-btn-inset: inset 0 1px 0 rgba(255, 255, 255, 0.25); + --overlay-hover: rgba(0, 0, 0, 0.04); + --overlay-active: rgba(0, 0, 0, 0.08); + + /* Text Colors */ + --text-body: var(--neutral-900); + --text-muted: var(--neutral-600); + --border-default: var(--alpha-10); + + /* Shadcn/Radix OKLCH Theme Variables */ + --radius: .625rem; + --background: 100% 0 0; + --foreground: 14.5% 0 0; + --card: 100% 0 0; + --card-foreground: 14.5% 0 0; + --popover: 100% 0 0; + --popover-foreground: 14.5% 0 0; + --primary: 20.5% 0 0; + --primary-foreground: 98.5% 0 0; + --secondary: 97% 0 0; + --secondary-foreground: 20.5% 0 0; + --muted: 97% 0 0; + --muted-foreground: 55.6% 0 0; + --accent: 97% 0 0; + --accent-foreground: 20.5% 0 0; + --destructive: 57.7% .245 27.325; + --border: 92.2% 0 0; + --input: 92.2% 0 0; + --ring: 70.8% 0 0; + + /* Charts */ + --chart-1: 64.6% .222 41.116; + --chart-2: 60% .118 184.704; + --chart-3: 39.8% .07 227.392; + --chart-4: 82.8% .189 84.429; + --chart-5: 76.9% .188 70.08; + + /* Sidebar */ + --sidebar: 98.5% 0 0; + --sidebar-foreground: 14.5% 0 0; + --sidebar-primary: 20.5% 0 0; + --sidebar-primary-foreground: 98.5% 0 0; + --sidebar-accent: 97% 0 0; + --sidebar-accent-foreground: 20.5% 0 0; + --sidebar-border: 92.2% 0 0; + 
--sidebar-ring: 70.8% 0 0; + + /* Additional Control Variables */ + --control-md: 32px; + --control-sm: 28px; + --control-xs: 24px; + --px-control-md: 12px; + --text-control-md: 14px; + --radius-control: 6px; + --px-control-sm: 12px; + + /* Badge colors (Restored) */ + --badge-blue-bg: #dbeafe; + --badge-blue-text: #1e40af; + --badge-green-bg: #dcfce7; + --badge-green-text: #166534; + --badge-gray-bg: #f3f4f6; + --badge-gray-text: #374151; + --badge-purple-bg: #f3e8ff; + --badge-purple-text: #6b21a8; + } + + .dark { + /* Dark mode overrides if needed, assuming system preference or class */ + --background: 14.5% 0 0; + --foreground: 98.5% 0 0; + --card: 14.5% 0 0; + --card-foreground: 98.5% 0 0; + --popover: 14.5% 0 0; + --popover-foreground: 98.5% 0 0; + --primary: 98.5% 0 0; + --primary-foreground: 20.5% 0 0; + --secondary: 20.5% 0 0; + --secondary-foreground: 98.5% 0 0; + --muted: 20.5% 0 0; + --muted-foreground: 55.6% 0 0; + --accent: 20.5% 0 0; + --accent-foreground: 98.5% 0 0; + --destructive: 57.7% .245 27.325; + --border: 20.5% 0 0; + --input: 20.5% 0 0; + --ring: 70.8% 0 0; + } +} + +@layer base { + * { + @apply border-border; + } + body { + @apply bg-background text-foreground; + } +} + +/* Utilities for token visualization */ +@layer utilities { + .rounded-control { border-radius: var(--radius-control); } + .h-control-sm { height: var(--control-sm); } + .px-control-sm { padding-left: var(--px-control-sm); padding-right: var(--px-control-sm); } + .text-control-md { font-size: var(--text-control-md); } + .shadow-btn-inset { box-shadow: var(--shadow-btn-inset); } + .bg-btn-primary { background: var(--btn-primary-bg); } + .bg-btn-default { background: var(--btn-default-bg); } +} diff --git a/skyrl-tx/viz/app/layout.tsx b/skyrl-tx/viz/app/layout.tsx new file mode 100644 index 000000000..dbfd18bfc --- /dev/null +++ b/skyrl-tx/viz/app/layout.tsx @@ -0,0 +1,22 @@ +import type { Metadata } from "next"; +import { Providers } from "@/components/providers"; +import 
"./globals.css"; + +export const metadata: Metadata = { + title: "SkyRL Dashboard", + description: "Local dashboard for visualizing SkyRL training runs", +}; + +export default function RootLayout({ + children, +}: { + children: React.ReactNode; +}) { + return ( + + + {children} + + + ); +} diff --git a/skyrl-tx/viz/components/Dashboard.tsx b/skyrl-tx/viz/components/Dashboard.tsx new file mode 100644 index 000000000..76ffc7825 --- /dev/null +++ b/skyrl-tx/viz/components/Dashboard.tsx @@ -0,0 +1,536 @@ +"use client"; + +import React, { useEffect, useState } from "react"; +import Link from "next/link"; +import { + Card, + CardHeader, + CardTitle, + CardDescription, + CardContent, +} from "@/components/ui/card"; +import { Badge } from "@/components/ui/badge"; +import { + Table, + TableHeader, + TableBody, + TableRow, + TableHead, + TableCell, +} from "@/components/ui/table"; +import { + Bar, + BarChart, + CartesianGrid, + ResponsiveContainer, + Tooltip, + XAxis, + YAxis, + Cell, + Pie, + PieChart, +} from "recharts"; +import type { QueueStats, RequestType } from "@/lib/db"; + +// Chart colors +const CHART_COLORS = { + red: "#ef4444", + lime: "#84cc16", + amber: "#f59e0b", + sky: "#0ea5e9", + orange: "#f97316", + teal: "#14b8a6", + blue: "#2563eb", + green: "#22c55e", + purple: "#a855f7", + pink: "#ec4899", +}; + +const MODEL_COLORS = [ + CHART_COLORS.blue, + CHART_COLORS.green, + CHART_COLORS.orange, + CHART_COLORS.purple, + CHART_COLORS.pink, + CHART_COLORS.teal, + CHART_COLORS.amber, + CHART_COLORS.red, +]; + +interface DashboardStats { + totalRuns: number; + runningRuns: number; + totalTokens: number; + totalTrajectories: number; + runs: { + id: string; + name: string; + model: string; + started_at: string; + ended_at: string | null; + }[]; + tokensPerDay: { day: string; tokens: number }[]; + modelTokens: Record; + perfStats: { + avgTokensPerTurn: number | null; + totalActionTokens: number | null; + totalTurns: number | null; + avgSamplingTime: number | null; + 
avgStepTime: number | null; + }; + perfOverTime: { day: string; tokens_per_turn: number | null; tokens_per_step: number | null; num_steps: number }[]; +} + +interface QueueData { + available: boolean; + stats: QueueStats | null; + latency: { + avg_latency_seconds: number | null; + min_latency_seconds: number | null; + max_latency_seconds: number | null; + } | null; + recentRequests: { + request_id: number; + request_type: RequestType; + model_id: string | null; + status: string; + created_at: string; + completed_at: string | null; + }[]; +} + +function formatNumber(n: number): string { + if (n >= 1_000_000) return `${(n / 1_000_000).toFixed(2)}M`; + if (n >= 1_000) return `${(n / 1_000).toFixed(1)}K`; + return n.toLocaleString(); +} + +function formatDate(dateStr: string): string { + const d = new Date(dateStr); + return d.toLocaleDateString("en-US", { month: "short", day: "numeric" }); +} + +function formatDuration(seconds: number | null): string { + if (seconds === null || Number.isNaN(seconds)) return "—"; + if (seconds < 1) return `${Math.round(seconds * 1000)} ms`; + if (seconds < 60) return `${seconds.toFixed(2)} s`; + return `${(seconds / 60).toFixed(1)} min`; +} + +function formatMetric(value: number | null, digits = 2): string { + if (value === null || Number.isNaN(value)) return "—"; + return value.toFixed(digits); +} + +function StatsCard({ + label, + value, + subtitle, +}: { + label: string; + value: string; + subtitle?: string; +}) { + return ( + + + + {label} + + {value} + {subtitle && ( +

{subtitle}

+ )} +
+
+ ); +} + +export default function Dashboard() { + const [stats, setStats] = useState(null); + const [queueData, setQueueData] = useState(null); + const [loading, setLoading] = useState(true); + + useEffect(() => { + Promise.all([ + fetch("/api/stats", { cache: "no-store" }).then((r) => r.json()), + fetch("/api/queue", { cache: "no-store" }).then((r) => r.json()).catch(() => null), + ]) + .then(([statsData, qData]) => { + setStats(statsData); + setQueueData(qData); + setLoading(false); + }) + .catch(() => setLoading(false)); + }, []); + + if (loading) { + return ( +
+
Loading...
+
+ ); + } + + if (!stats || stats.totalRuns === 0) { + return ( +
+

SkyRL Dashboard

+ +
No training runs yet
+
+ Start a training run to see metrics here. +
+
+
+ ); + } + + const avgTokensPerDay = stats.tokensPerDay.length > 0 + ? stats.tokensPerDay.reduce((sum, d) => sum + d.tokens, 0) / stats.tokensPerDay.length + : 0; + + // Prepare model breakdown data for pie chart + const modelData = Object.entries(stats.modelTokens) + .map(([name, tokens], i) => ({ + name: name.length > 20 ? name.slice(0, 20) + "..." : name, + fullName: name, + tokens, + color: MODEL_COLORS[i % MODEL_COLORS.length], + })) + .sort((a, b) => b.tokens - a.tokens); + + // Format tokens per day for chart + const tokenChartData = stats.tokensPerDay.map(d => ({ + date: formatDate(d.day), + tokens: d.tokens / 1000, // Convert to K + })); + + const perfChartData = stats.perfOverTime + .filter((d) => d.tokens_per_step !== null) + .map((d) => ({ + date: formatDate(d.day), + tokensPerStep: d.tokens_per_step ?? 0, + })); + + const hasPerfStats = [ + stats.perfStats.avgTokensPerTurn, + stats.perfStats.totalActionTokens, + stats.perfStats.avgSamplingTime, + stats.perfStats.avgStepTime, + ].some((value) => value !== null); + + return ( +
+ {/* Header */} +
+

SkyRL Dashboard

+
+ + {/* Stats Row */} +
+ 0 ? `${stats.runningRuns} currently running` : undefined} + /> + 0 ? `Avg per day: ${formatNumber(avgTokensPerDay)}` : undefined} + /> + + +
+ + {/* Two column: Tokens per day + Model breakdown */} +
+ {/* Tokens per Day Chart */} + {tokenChartData.length > 0 && ( + + + Tokens per day + Daily token generation (thousands) + + +
+ + + + + `${v}K`} + /> + [`${(Number(value) * 1000).toLocaleString()} tokens`, "Tokens"]} + /> + + + +
+
+
+ )} + + {/* Model Breakdown */} + {modelData.length > 0 && ( + + + Tokens by model + Distribution across base models + + +
+ + + + {modelData.map((entry, index) => ( + + ))} + + [formatNumber(Number(value)), "Tokens"]} + /> + + +
+ {modelData.slice(0, 5).map((model, i) => ( +
+ + + {model.name} + + {formatNumber(model.tokens)} +
+ ))} +
+
+
+
+ )} +
+ + {(hasPerfStats || perfChartData.length > 0) && ( +
+
+

Performance

+ Averages across logged steps +
+
+ + + Step metrics + Token and timing averages + + +
+
+
Avg tokens / turn
+
{formatMetric(stats.perfStats.avgTokensPerTurn, 2)}
+
+
+
Total action tokens
+
+ {stats.perfStats.totalActionTokens !== null + ? formatNumber(stats.perfStats.totalActionTokens) + : "—"} +
+
+
+
Total turns
+
+ {stats.perfStats.totalTurns !== null + ? formatNumber(stats.perfStats.totalTurns) + : "—"} +
+
+
+
Avg sampling time
+
{formatDuration(stats.perfStats.avgSamplingTime)}
+
+
+
Avg step time
+
{formatDuration(stats.perfStats.avgStepTime)}
+
+
+
+
+ + {perfChartData.length > 0 && ( + + + Action tokens per step + Daily average + + +
+ + + + + + [formatNumber(Number(value)), "Tokens"]} + /> + + + +
+
+
+ )} +
+
+ )} + + {/* SkyRL-TX Queue Status (only shown when tinker DB is available) */} + {queueData?.available && queueData.stats && ( +
+
+

SkyRL-TX Queue

+ Tinker API +
+
+ + + + +
+ + {/* Request Type Breakdown */} + + + Requests by Type + Queue status per operation type + + +
+ {(["FORWARD_BACKWARD", "OPTIM_STEP", "SAMPLE", "SAVE_WEIGHTS", "CREATE_MODEL"] as const).map((type) => { + const typeStats = queueData.stats?.byType[type]; + if (!typeStats || (typeStats.pending === 0 && typeStats.completed === 0)) return null; + return ( +
+
{type.replace(/_/g, " ")}
+
+ {typeStats.pending > 0 && ( + {typeStats.pending} pending + )} + {typeStats.completed} done +
+
+ ); + })} +
+
+
+
+ )} + + {/* Training Runs Table */} + + +
+
+ Recent Training Runs + View and manage your training runs +
+ + View all → + +
+
+ + + + + Run + Model + Status + Started + + + + {stats.runs.map((run) => ( + + + + {run.name || run.id} + + + + {run.model} + + + + {run.ended_at ? "completed" : "running"} + + + + {formatDate(run.started_at)} + + + ))} + +
+
+
+
+ ); +} diff --git a/skyrl-tx/viz/components/providers.tsx b/skyrl-tx/viz/components/providers.tsx new file mode 100644 index 000000000..43851af70 --- /dev/null +++ b/skyrl-tx/viz/components/providers.tsx @@ -0,0 +1,21 @@ +"use client"; + +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import { useState } from "react"; + +export function Providers({ children }: { children: React.ReactNode }) { + const [queryClient] = useState(() => new QueryClient({ + defaultOptions: { + queries: { + staleTime: 60 * 1000, // 1 minute + refetchOnWindowFocus: false, + }, + }, + })); + + return ( + + {children} + + ); +} diff --git a/skyrl-tx/viz/components/ui/badge.tsx b/skyrl-tx/viz/components/ui/badge.tsx new file mode 100644 index 000000000..54d584db7 --- /dev/null +++ b/skyrl-tx/viz/components/ui/badge.tsx @@ -0,0 +1,43 @@ +import * as React from "react"; +import { cva, type VariantProps } from "class-variance-authority"; +import { cn } from "@/lib/utils"; + +const badgeVariants = cva( + "inline-flex items-center rounded-md border border-transparent px-2 py-0.5 text-xs font-medium transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2", + { + variants: { + variant: { + default: "bg-accent text-accent-foreground", + secondary: "bg-muted text-muted-foreground uppercase tracking-wide", + destructive: "bg-[var(--badge-gray-bg)] text-[var(--badge-gray-text)]", + outline: "border border-border text-foreground", + blue: "bg-[var(--badge-blue-bg)] text-[var(--badge-blue-text)] uppercase tracking-wide", + green: "bg-[var(--badge-green-bg)] text-[var(--badge-green-text)] uppercase tracking-wide", + gray: "bg-muted text-muted-foreground", + purple: "bg-[var(--badge-purple-bg)] text-[var(--badge-purple-text)] uppercase tracking-wide", + sampler: "bg-muted text-muted-foreground uppercase tracking-wide", + training: "bg-muted text-muted-foreground uppercase tracking-wide", + private: "bg-muted text-muted-foreground", + }, + }, + 
defaultVariants: { + variant: "default", + }, + } +); + +export interface BadgeProps + extends React.HTMLAttributes, + VariantProps {} + +function Badge({ className, variant, ...props }: BadgeProps) { + return ( +
+ ); +} + +export { Badge, badgeVariants }; diff --git a/skyrl-tx/viz/components/ui/button.tsx b/skyrl-tx/viz/components/ui/button.tsx new file mode 100644 index 000000000..d021fa873 --- /dev/null +++ b/skyrl-tx/viz/components/ui/button.tsx @@ -0,0 +1,58 @@ +import * as React from "react"; +import { Slot } from "@radix-ui/react-slot"; +import { cva, type VariantProps } from "class-variance-authority"; +import { cn } from "@/lib/utils"; + +const buttonVariants = cva( + "inline-flex items-center justify-center select-none whitespace-nowrap transition-colors disabled:opacity-50 disabled:pointer-events-none focus-visible:outline-[2px] focus-visible:outline-offset-1 focus-visible:outline-[#0B99FF]", + { + variants: { + variant: { + default: + "bg-[image:var(--btn-default-bg)] shadow-[var(--shadow-btn-inset)] hover:bg-[image:var(--btn-default-bg),linear-gradient(var(--overlay-hover),var(--overlay-hover))] hover:bg-blend-overlay active:bg-[image:var(--btn-default-bg),linear-gradient(var(--overlay-active),var(--overlay-active))] active:bg-blend-overlay text-foreground rounded-[var(--radius-control)]", + primary: + "bg-[image:var(--btn-primary-bg)] border border-[var(--btn-primary-border)] shadow-[var(--btn-primary-shadow)] text-white rounded-md hover:brightness-105 active:brightness-95 active:scale-[0.98] transition-all duration-200", + secondary: + "bg-[image:var(--btn-secondary-bg)] shadow-[var(--shadow-btn-inset)] text-white rounded-[var(--radius-control)]", + danger: + "bg-[image:var(--btn-danger-bg)] shadow-[var(--shadow-btn-inset)] text-white rounded-[var(--radius-control)]", + outline: "border border-border bg-background hover:bg-muted rounded-md", + ghost: "hover:bg-muted hover:text-foreground rounded-md", + link: "text-accent underline-offset-4 hover:underline", + }, + size: { + default: "h-[var(--control-md)] px-[var(--px-control-md)] text-[var(--text-control-md)]", + sm: "h-[var(--control-sm)] px-3 text-xs", + xs: "h-[var(--control-xs)] px-2 text-xs", + lg: 
"h-10 px-8 text-sm", + icon: "h-[var(--control-md)] w-[var(--control-md)]", + }, + }, + defaultVariants: { + variant: "default", + size: "default", + }, + } +); + +export interface ButtonProps + extends React.ButtonHTMLAttributes, + VariantProps { + asChild?: boolean; +} + +const Button = React.forwardRef( + ({ className, variant, size, asChild = false, ...props }, ref) => { + const Comp = asChild ? Slot : "button"; + return ( + + ); + } +); +Button.displayName = "Button"; + +export { Button, buttonVariants }; diff --git a/skyrl-tx/viz/components/ui/card.tsx b/skyrl-tx/viz/components/ui/card.tsx new file mode 100644 index 000000000..21abcfa0b --- /dev/null +++ b/skyrl-tx/viz/components/ui/card.tsx @@ -0,0 +1,85 @@ +import * as React from "react"; +import { cn } from "@/lib/utils"; + +const Card = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +Card.displayName = "Card"; + +const CardHeader = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardHeader.displayName = "CardHeader"; + +const CardTitle = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardTitle.displayName = "CardTitle"; + +const CardDescription = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardDescription.displayName = "CardDescription"; + +const CardContent = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardContent.displayName = "CardContent"; + +const CardFooter = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +CardFooter.displayName = "CardFooter"; + +export { Card, CardHeader, CardTitle, CardDescription, CardContent, CardFooter }; diff --git a/skyrl-tx/viz/components/ui/dialog.tsx b/skyrl-tx/viz/components/ui/dialog.tsx new file mode 100644 index 000000000..7a13b2c25 --- /dev/null +++ b/skyrl-tx/viz/components/ui/dialog.tsx @@ -0,0 +1,114 @@ +"use client"; + +import * as React from "react"; +import * as DialogPrimitive from "@radix-ui/react-dialog"; +import { cn } from "@/lib/utils"; + +const Dialog = DialogPrimitive.Root; +const DialogTrigger = DialogPrimitive.Trigger; +const DialogPortal = DialogPrimitive.Portal; +const DialogClose = DialogPrimitive.Close; + +const DialogOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +DialogOverlay.displayName = DialogPrimitive.Overlay.displayName; + +const DialogContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + + {children} + + + + + + Close + + + + +)); +DialogContent.displayName = DialogPrimitive.Content.displayName; + +const DialogHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+); +DialogHeader.displayName = "DialogHeader"; + +const DialogTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +DialogTitle.displayName = DialogPrimitive.Title.displayName; + +const DialogDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)); +DialogDescription.displayName = DialogPrimitive.Description.displayName; + +export { + Dialog, + DialogPortal, + DialogOverlay, + DialogClose, + DialogTrigger, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, +}; diff --git a/skyrl-tx/viz/components/ui/pagination.tsx b/skyrl-tx/viz/components/ui/pagination.tsx new file mode 100644 index 000000000..ad7bdc236 --- /dev/null +++ b/skyrl-tx/viz/components/ui/pagination.tsx @@ -0,0 +1,107 @@ +"use client"; + +import * as React from "react"; +import { cn } from "@/lib/utils"; + +interface PaginationProps { + currentPage: number; + totalPages: number; + totalItems: number; + itemsPerPage: number; + onPageChange: (page: number) => void; + className?: string; +} + +export function Pagination({ + currentPage, + totalPages, + totalItems, + itemsPerPage, + onPageChange, + className, +}: PaginationProps) { + const startItem = (currentPage - 1) * itemsPerPage + 1; + const endItem = Math.min(currentPage * itemsPerPage, totalItems); + + // Generate page numbers to show + const getPageNumbers = () => { + const pages: (number | string)[] = []; + + if (totalPages <= 7) { + for (let i = 1; i <= totalPages; i++) pages.push(i); + } else { + // Always show first page + pages.push(1); + + if (currentPage > 3) { + pages.push("..."); + } + + // Show pages around current + for (let i = Math.max(2, currentPage - 1); i <= Math.min(totalPages - 1, currentPage + 1); i++) { + if (!pages.includes(i)) pages.push(i); + } + + if (currentPage < totalPages - 2) { + pages.push("..."); + } + + // Always show last page + if 
(!pages.includes(totalPages)) pages.push(totalPages); + } + + return pages; + }; + + return ( +
+

+ Showing {startItem}–{endItem} of {totalItems.toLocaleString()} +

+ +
+ + + {getPageNumbers().map((page, idx) => ( + + {page === "..." ? ( + ... + ) : ( + + )} + + ))} + + +
+
+ ); +} diff --git a/skyrl-tx/viz/components/ui/sidebar.tsx b/skyrl-tx/viz/components/ui/sidebar.tsx new file mode 100644 index 000000000..17d90b1d9 --- /dev/null +++ b/skyrl-tx/viz/components/ui/sidebar.tsx @@ -0,0 +1,139 @@ +"use client"; + +import * as React from "react"; +import { useState } from "react"; +import Link from "next/link"; +import { usePathname } from "next/navigation"; +import { cn } from "@/lib/utils"; + +const Icons = { + training: () => ( + + + + + ), + usage: () => ( + + + + + + ), + menu: () => ( + + + + + + ), + close: () => ( + + + + + ), +}; + +interface NavItemConfig { + label: string; + href: string; + icon: keyof typeof Icons; +} + +const navItems: NavItemConfig[] = [ + { label: "Overview", href: "/", icon: "usage" }, + { label: "Training Runs", href: "/training-runs", icon: "training" }, +]; + +export function Sidebar() { + const pathname = usePathname(); + const [isOpen, setIsOpen] = useState(false); + + const getIsActive = (href: string) => { + if (href === "/") return pathname === "/"; + return pathname === href; + }; + + const handleNavClick = () => { + setIsOpen(false); + }; + + const sidebarContent = ( + <> + + SkyRL + + + + + + + ); + + return ( + <> +
+ SkyRL + +
+ + {isOpen && ( +
setIsOpen(false)} + /> + )} + + + + + + ); +} diff --git a/skyrl-tx/viz/components/ui/table.tsx b/skyrl-tx/viz/components/ui/table.tsx new file mode 100644 index 000000000..f3112b216 --- /dev/null +++ b/skyrl-tx/viz/components/ui/table.tsx @@ -0,0 +1,131 @@ +import * as React from "react"; +import { cn } from "@/lib/utils"; + +const Table = React.forwardRef< + HTMLTableElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+ + +)); +Table.displayName = "Table"; + +const TableHeader = React.forwardRef< + HTMLTableSectionElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( + +)); +TableHeader.displayName = "TableHeader"; + +const TableBody = React.forwardRef< + HTMLTableSectionElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( + +)); +TableBody.displayName = "TableBody"; + +const TableFooter = React.forwardRef< + HTMLTableSectionElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( + tr]:last:border-b-0", + className + )} + {...props} + /> +)); +TableFooter.displayName = "TableFooter"; + +const TableRow = React.forwardRef< + HTMLTableRowElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( + +)); +TableRow.displayName = "TableRow"; + +const TableHead = React.forwardRef< + HTMLTableCellElement, + React.ThHTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +TableHead.displayName = "TableHead"; + +const TableCell = React.forwardRef< + HTMLTableCellElement, + React.TdHTMLAttributes +>(({ className, ...props }, ref) => ( + +)); +TableCell.displayName = "TableCell"; + +const TableCaption = React.forwardRef< + HTMLTableCaptionElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)); +TableCaption.displayName = "TableCaption"; + +export { + Table, + TableHeader, + TableBody, + TableFooter, + TableHead, + TableRow, + TableCell, + TableCaption, +}; diff --git a/skyrl-tx/viz/components/ui/toggle-group.tsx b/skyrl-tx/viz/components/ui/toggle-group.tsx new file mode 100644 index 000000000..4a286db75 --- /dev/null +++ b/skyrl-tx/viz/components/ui/toggle-group.tsx @@ -0,0 +1,43 @@ +"use client"; + +import * as React from "react"; +import { cn } from "@/lib/utils"; + +interface ToggleGroupProps { + value: string; + onValueChange: (value: string) => void; + options: { value: string; label: string }[]; + className?: string; +} + +export function ToggleGroup({ + value, + onValueChange, + options, + className, +}: ToggleGroupProps) { + return ( +
+ {options.map((option) => ( + + ))} +
+ ); +} diff --git a/skyrl-tx/viz/hooks/useRunData.ts b/skyrl-tx/viz/hooks/useRunData.ts new file mode 100644 index 000000000..ad6799c02 --- /dev/null +++ b/skyrl-tx/viz/hooks/useRunData.ts @@ -0,0 +1,33 @@ +"use client"; + +import { useState, useEffect } from "react"; +import type { Run, Step } from "@/lib/db"; + +export function useRunData(runId: string) { + const [run, setRun] = useState(null); + const [steps, setSteps] = useState([]); + const [isLoading, setIsLoading] = useState(true); + const [error, setError] = useState(null); + + useEffect(() => { + if (!runId) return; + + setIsLoading(true); + setError(null); + + fetch(`/api/runs/${runId}`) + .then((r) => r.json()) + .then((data) => { + setRun(data.run); + setSteps(data.steps || []); + }) + .catch((e) => { + setError(e.message); + }) + .finally(() => { + setIsLoading(false); + }); + }, [runId]); + + return { run, steps, isLoading, error }; +} diff --git a/skyrl-tx/viz/next.config.mjs b/skyrl-tx/viz/next.config.mjs new file mode 100644 index 000000000..d5456a15d --- /dev/null +++ b/skyrl-tx/viz/next.config.mjs @@ -0,0 +1,6 @@ +/** @type {import('next').NextConfig} */ +const nextConfig = { + reactStrictMode: true, +}; + +export default nextConfig; diff --git a/skyrl-tx/viz/package.json b/skyrl-tx/viz/package.json new file mode 100644 index 000000000..54256e638 --- /dev/null +++ b/skyrl-tx/viz/package.json @@ -0,0 +1,35 @@ +{ + "name": "skyrl-viz", + "version": "0.1.0", + "description": "Local visualization dashboard for SkyRL training runs", + "private": true, + "scripts": { + "dev": "bun next dev -p 3003", + "build": "bun next build", + "start": "bun next start -p 3003" + }, + "dependencies": { + "@radix-ui/react-dialog": "^1.1.15", + "@radix-ui/react-slot": "^1.2.4", + "@tanstack/react-query": "^5.90.14", + "autoprefixer": "^10.4.23", + "better-sqlite3": "^12.5.0", + "class-variance-authority": "^0.7.1", + "clsx": "^2.1.1", + "lucide-react": "^0.561.0", + "next": "^14.0.0", + "postcss": "^8.5.6", + "react": 
"^18.2.0", + "react-dom": "^18.2.0", + "recharts": "^3.5.1", + "tailwind-merge": "^3.4.0", + "tailwindcss": "3.4.17", + "tailwindcss-animate": "1.0.7" + }, + "devDependencies": { + "@types/better-sqlite3": "^7.6.13", + "@types/node": "^20.0.0", + "@types/react": "^18.2.0", + "typescript": "^5.0.0" + } +} diff --git a/skyrl-tx/viz/postcss.config.mjs b/skyrl-tx/viz/postcss.config.mjs new file mode 100644 index 000000000..a982c6414 --- /dev/null +++ b/skyrl-tx/viz/postcss.config.mjs @@ -0,0 +1,8 @@ +const config = { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +}; + +export default config; diff --git a/skyrl-tx/viz/tailwind.config.ts b/skyrl-tx/viz/tailwind.config.ts new file mode 100644 index 000000000..9f55b4a7a --- /dev/null +++ b/skyrl-tx/viz/tailwind.config.ts @@ -0,0 +1,82 @@ +import type { Config } from "tailwindcss"; + +const config: Config = { + darkMode: ["class"], + content: [ + "./app/**/*.{ts,tsx}", + "./components/**/*.{ts,tsx}", + "./lib/**/*.{ts,tsx}", + ], + theme: { + extend: { + colors: { + background: "oklch(var(--background))", + foreground: "oklch(var(--foreground))", + card: { + DEFAULT: "oklch(var(--card))", + foreground: "oklch(var(--card-foreground))", + }, + muted: { + DEFAULT: "oklch(var(--muted))", + foreground: "oklch(var(--muted-foreground))", + }, + accent: { + DEFAULT: "oklch(var(--accent))", + foreground: "oklch(var(--accent-foreground))", + }, + border: "oklch(var(--border))", + ring: "oklch(var(--ring))", + sidebar: { + DEFAULT: "oklch(var(--sidebar))", + foreground: "oklch(var(--sidebar-foreground))", + accent: "oklch(var(--sidebar-accent))", + "accent-foreground": "oklch(var(--sidebar-accent-foreground))", + border: "oklch(var(--sidebar-border))", + }, + "table-header": "oklch(var(--muted-foreground))", + chart: { + "1": "oklch(var(--chart-1))", + "2": "oklch(var(--chart-2))", + "3": "oklch(var(--chart-3))", + "4": "oklch(var(--chart-4))", + "5": "oklch(var(--chart-5))", + }, + success: "oklch(var(--success))", + 
warning: "oklch(var(--warning))", + destructive: "oklch(var(--destructive))", + }, + borderRadius: { + lg: "var(--radius)", + md: "calc(var(--radius) - 2px)", + sm: "calc(var(--radius) - 4px)", + }, + fontFamily: { + sans: ["var(--font-sans)"], + mono: ["var(--font-mono)"], + }, + keyframes: { + "fade-in": { + from: { opacity: "0", transform: "scale(0.98)" }, + to: { opacity: "1", transform: "scale(1)" }, + }, + "modal-bounce-in": { + "0%": { opacity: "0", transform: "scale(0.9)" }, + "50%": { opacity: "1", transform: "scale(1.02)" }, + "100%": { opacity: "1", transform: "scale(1)" }, + }, + "modal-bounce-out": { + "0%": { opacity: "1", transform: "scale(1)" }, + "100%": { opacity: "0", transform: "scale(0.95)" }, + }, + }, + animation: { + "fade-in": "fade-in 0.3s ease", + "modal-bounce-in": "modal-bounce-in 0.3s cubic-bezier(0.34, 1.56, 0.64, 1) forwards", + "modal-bounce-out": "modal-bounce-out 0.2s ease-out forwards", + }, + }, + }, + plugins: [require("tailwindcss-animate")], +}; + +export default config; diff --git a/skyrl-tx/viz/tsconfig.json b/skyrl-tx/viz/tsconfig.json new file mode 100644 index 000000000..455bf35f6 --- /dev/null +++ b/skyrl-tx/viz/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "lib": ["dom", "dom.iterable", "esnext"], + "allowJs": true, + "skipLibCheck": true, + "strict": true, + "noEmit": true, + "esModuleInterop": true, + "module": "esnext", + "moduleResolution": "bundler", + "resolveJsonModule": true, + "isolatedModules": true, + "jsx": "preserve", + "incremental": true, + "plugins": [{ "name": "next" }], + "paths": { "@/*": ["./*"] } + }, + "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], + "exclude": ["node_modules"] +}