Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion skyrl-train/skyrl_train/utils/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

# TODO(tgriggs): Test all backends.
class Tracking:
supported_backends = ["wandb", "mlflow", "swanlab", "tensorboard", "console"]
supported_backends = ["wandb", "mlflow", "swanlab", "tensorboard", "console", "tviz"]

def __init__(self, project_name, experiment_name, backends: Union[str, List[str]] = "console", config=None):
if isinstance(backends, str):
Expand Down Expand Up @@ -73,6 +73,11 @@ def __init__(self, project_name, experiment_name, backends: Union[str, List[str]
self.console_logger = ConsoleLogger()
self.logger["console"] = self.console_logger

if "tviz" in backends:
from skyrl_train.utils.tviz_tracker import TvizTracker

self.logger["tviz"] = TvizTracker(experiment_name=experiment_name, config=config)

def log(self, data, step, commit=False):
for logger_name, logger_instance in self.logger.items():
if logger_name == "wandb":
Expand All @@ -94,6 +99,8 @@ def __del__(self):
self.logger["tensorboard"].finish()
if "mlflow" in self.logger:
self.logger["mlflow"].finish()
if "tviz" in self.logger:
self.logger["tviz"].finish()
except Exception as e:
logger.warning(f"Attempted to finish tracking but got error {e}")

Expand Down
185 changes: 185 additions & 0 deletions skyrl-train/skyrl_train/utils/tviz_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""TViz tracking backend for local visualization of SkyRL training runs."""

import json
import os
import sqlite3
import uuid
from pathlib import Path
from typing import Any


# Maps every accepted metric name to the ``training_steps`` column it is
# stored in. Aliases are grouped by target column; a column's own name is
# listed as an alias of itself where the original table accepted it.
METRIC_MAPPING = {
    **dict.fromkeys(
        (
            "reward/avg_raw_reward",
            "reward/avg_reward",
            "reward/mean_positive_reward",
            "reward_mean",
        ),
        "reward_mean",
    ),
    **dict.fromkeys(("policy/loss", "critic/loss", "loss"), "loss"),
    **dict.fromkeys(("policy/kl_divergence", "kl_divergence", "kl"), "kl_divergence"),
    **dict.fromkeys(("policy/entropy", "entropy"), "entropy"),
    **dict.fromkeys(("trainer/learning_rate", "learning_rate", "lr"), "learning_rate"),
    **dict.fromkeys(
        (
            "generate/avg_num_tokens",
            "generate/avg_tokens_non_zero_rewards",
            "ac_tokens_per_turn",
        ),
        "ac_tokens_per_turn",
    ),
    **dict.fromkeys(("timing/total", "time_total"), "time_total"),
    **dict.fromkeys(("timing/generate",), "sampling_time_mean"),
}

# Schema for the training-visualization tables (added to tinker.db).
# TvizTracker executes this script the first time it writes to the database.
# Every statement uses IF NOT EXISTS, so running it against a database that
# already has these objects changes nothing.
# UNIQUE(run_id, step) is the constraint that lets the tracker's
# INSERT OR REPLACE overwrite an existing step row instead of adding a second.
SCHEMA = """
CREATE TABLE IF NOT EXISTS training_runs (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
type TEXT DEFAULT 'rl',
modality TEXT DEFAULT 'text',
config TEXT,
started_at TEXT DEFAULT CURRENT_TIMESTAMP,
ended_at TEXT
);

CREATE TABLE IF NOT EXISTS training_steps (
id INTEGER PRIMARY KEY AUTOINCREMENT,
run_id TEXT NOT NULL,
step INTEGER NOT NULL,
reward_mean REAL,
reward_std REAL,
loss REAL,
kl_divergence REAL,
entropy REAL,
learning_rate REAL,
ac_tokens_per_turn REAL,
ob_tokens_per_turn REAL,
total_ac_tokens INTEGER,
total_turns INTEGER,
sampling_time_mean REAL,
time_total REAL,
frac_mixed REAL,
frac_all_good REAL,
frac_all_bad REAL,
extras TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (run_id) REFERENCES training_runs(id),
UNIQUE(run_id, step)
);

CREATE INDEX IF NOT EXISTS idx_training_steps_run_id ON training_steps(run_id);
"""


def get_tinker_db_path() -> Path:
    """Return the filesystem path of the tinker SQLite database.

    If the ``TX_DATABASE_URL`` environment variable holds a ``sqlite:///``
    URL, the path is everything after that scheme prefix. In every other
    case (variable unset, empty, or naming a non-SQLite database) the default
    ``skyrl-tx/tx/tinker/tinker.db`` is returned, located four directory
    levels above this file.
    """
    sqlite_prefix = "sqlite:///"
    db_url = os.environ.get("TX_DATABASE_URL", "")
    if db_url.startswith(sqlite_prefix):
        # Strip only the leading scheme. str.replace() would also delete any
        # later occurrence of the same text inside the path itself.
        return Path(db_url.removeprefix(sqlite_prefix))
    return Path(__file__).parent.parent.parent.parent / "skyrl-tx" / "tx" / "tinker" / "tinker.db"


class TvizTracker:
"""TViz tracking backend that writes metrics to the tinker database."""

def __init__(self, experiment_name: str, config: dict[str, Any] | None = None):
self.db_path = get_tinker_db_path()
self.db_path.parent.mkdir(parents=True, exist_ok=True)

self.run_id = str(uuid.uuid4())[:8]
self.run_name = experiment_name
self._config = config
self._conn: sqlite3.Connection | None = None
self._initialized = False

def _get_conn(self) -> sqlite3.Connection:
if self._conn is None:
self._conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
return self._conn

def _ensure_initialized(self) -> None:
if self._initialized:
return

conn = self._get_conn()
conn.executescript(SCHEMA)

config_json = None
if self._config is not None:
try:
from omegaconf import OmegaConf
if hasattr(self._config, "_content"):
self._config = OmegaConf.to_container(self._config, resolve=True)
except ImportError:
pass
config_json = json.dumps(self._config)

conn.execute(
"INSERT INTO training_runs (id, name, type, modality, config) VALUES (?, ?, ?, ?, ?)",
(self.run_id, self.run_name, "rl", "text", config_json),
)
conn.commit()
self._initialized = True
print(f"TViz dashboard: http://localhost:3003/training-run/{self.run_id}")

def _map_metrics(self, data: dict[str, Any]) -> dict[str, Any]:
mapped = {}
for key, value in data.items():
if not isinstance(value, (int, float)):
continue
if key in METRIC_MAPPING:
tviz_key = METRIC_MAPPING[key]
if tviz_key not in mapped:
mapped[tviz_key] = value
else:
mapped[key] = value
return mapped

def log(self, data: dict[str, Any], step: int) -> None:
self._ensure_initialized()
metrics = self._map_metrics(data)

known_cols = [
"reward_mean", "reward_std", "loss", "kl_divergence", "entropy",
"learning_rate", "ac_tokens_per_turn", "ob_tokens_per_turn",
"total_ac_tokens", "total_turns", "sampling_time_mean", "time_total",
"frac_mixed", "frac_all_good", "frac_all_bad",
]

values = {col: metrics.pop(col, None) for col in known_cols}
extras = json.dumps(metrics) if metrics else None

conn = self._get_conn()
conn.execute(
"""
INSERT OR REPLACE INTO training_steps
(run_id, step, reward_mean, reward_std, loss, kl_divergence, entropy,
learning_rate, ac_tokens_per_turn, ob_tokens_per_turn, total_ac_tokens,
total_turns, sampling_time_mean, time_total, frac_mixed, frac_all_good,
frac_all_bad, extras)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
self.run_id, step,
values["reward_mean"], values["reward_std"], values["loss"],
values["kl_divergence"], values["entropy"], values["learning_rate"],
values["ac_tokens_per_turn"], values["ob_tokens_per_turn"],
values["total_ac_tokens"], values["total_turns"],
values["sampling_time_mean"], values["time_total"],
values["frac_mixed"], values["frac_all_good"], values["frac_all_bad"],
extras,
),
)
conn.commit()

def finish(self) -> None:
if self._conn is not None:
self._conn.execute(
"UPDATE training_runs SET ended_at = CURRENT_TIMESTAMP WHERE id = ?",
(self.run_id,),
)
self._conn.commit()
self._conn.close()
self._conn = None
2 changes: 1 addition & 1 deletion skyrl-tx/tx/loaders/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def chat(tokenizer: PreTrainedTokenizer, dataset: Dataset, batch_size: int) -> L
"text": batch["input_ids"][:, :-1],
"attention_mask": batch["attention_mask"][:, :-1],
"target": batch["input_ids"][:, 1:],
}, {"shape": batch["input_ids"].shape, "tokens": batch["attention_mask"].sum()}
}, {"shape": batch["input_ids"].shape, "tokens": batch["attention_mask"].sum()}t
41 changes: 41 additions & 0 deletions skyrl-tx/tx/tinker/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,44 @@ class SamplingSessionDB(SQLModel, table=True):
base_model: str | None = None
model_path: str | None = None
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))


class TrainingRunDB(SQLModel, table=True):
    """Training run for visualization dashboard.

    ORM model for the ``training_runs`` table. The same table name and
    columns are also declared in raw SQL by the TViz tracker
    (skyrl_train/utils/tviz_tracker.py), which inserts one row per run.
    """

    __tablename__ = "training_runs"

    # Run identifier; has no default, so the writer must supply it.
    id: str = Field(primary_key=True)
    # Run name.
    name: str
    # Run type and data modality; default to "rl" and "text".
    type: str = Field(default="rl")
    modality: str = Field(default="text")
    # Run configuration stored as JSON; None when not provided.
    config: dict | None = Field(default=None, sa_type=JSON)
    # Timezone-aware timestamps. started_at defaults to the current UTC time
    # at object creation; ended_at stays None until it is set.
    started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
    ended_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True))


class TrainingStepDB(SQLModel, table=True):
    """Training step metrics for visualization dashboard.

    ORM model for the ``training_steps`` table: one row of metrics per
    (run, step). Every metric column is optional and defaults to None.
    """

    # NOTE(review): the raw-SQL schema in skyrl_train/utils/tviz_tracker.py
    # declares UNIQUE(run_id, step) and its INSERT OR REPLACE relies on it.
    # This model declares no such constraint, so a table created from this
    # model would presumably accept duplicate (run_id, step) rows. Confirm
    # which side creates the table, or add the constraint here.

    __tablename__ = "training_steps"

    # Auto-incrementing surrogate key; None until the row is inserted.
    id: int | None = Field(default=None, primary_key=True, sa_column_kwargs={"autoincrement": True})
    # Owning run (foreign key to training_runs.id), indexed.
    run_id: str = Field(foreign_key="training_runs.id", index=True)
    # Step number within the run.
    step: int
    # Dedicated metric columns; None when the metric was not logged.
    reward_mean: float | None = None
    reward_std: float | None = None
    loss: float | None = None
    kl_divergence: float | None = None
    entropy: float | None = None
    learning_rate: float | None = None
    ac_tokens_per_turn: float | None = None
    ob_tokens_per_turn: float | None = None
    total_ac_tokens: int | None = None
    total_turns: int | None = None
    sampling_time_mean: float | None = None
    time_total: float | None = None
    frac_mixed: float | None = None
    frac_all_good: float | None = None
    frac_all_bad: float | None = None
    # JSON column for metrics that have no dedicated column above.
    extras: dict | None = Field(default=None, sa_type=JSON)
    # Timezone-aware creation time; defaults to the current UTC time.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
42 changes: 42 additions & 0 deletions skyrl-tx/viz/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# dependencies
node_modules/
.pnp
.pnp.js

# build
.next/
out/
dist/
build/

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env*.local
.env

# vercel
.vercel

# typescript
*.tsbuildinfo
next-env.d.ts

# lock files (use bun.lock)
package-lock.json
yarn.lock

# database (local development)
*.db

# python
__pycache__/
*.pyc
.venv/
23 changes: 23 additions & 0 deletions skyrl-tx/viz/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SkyRL Dashboard

Local visualization dashboard for SkyRL training runs.

## Setup

```bash
cd skyrl-tx/viz
bun install
bun dev
```

Dashboard runs at http://localhost:3003

## Architecture

- Next.js App Router frontend
- Reads from `tinker.db` (SQLite)
- Tables: `training_runs`, `training_steps`, `sessions`, `models`, `futures`, `checkpoints`

## Environment Variables

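- `TINKER_DB_PATH` - Path to the tinker database read by the dashboard (default: `./tx/tinker/tinker.db`). The Python `tviz` tracking backend locates its database separately, from `TX_DATABASE_URL` (a `sqlite:///...` URL), so both must point at the same file.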
11 changes: 11 additions & 0 deletions skyrl-tx/viz/app/(dashboard)/layout.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { Sidebar } from "@/components/ui/sidebar";

type DashboardLayoutProps = {
  children: React.ReactNode;
};

/**
 * Layout for the (dashboard) route group: renders the Sidebar next to a
 * scrollable main area that holds the current page.
 */
export default function DashboardLayout(props: DashboardLayoutProps) {
  const { children } = props;

  return (
    <div className="flex min-h-screen bg-background">
      <Sidebar />
      {/* pt-14 on mobile for fixed header, md:pt-0 for desktop */}
      <main className="flex-1 overflow-auto pt-14 md:pt-0">{children}</main>
    </div>
  );
}
5 changes: 5 additions & 0 deletions skyrl-tx/viz/app/(dashboard)/page.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import Dashboard from "@/components/Dashboard";

/** Index page of the (dashboard) route group; renders the Dashboard component. */
export default function HomePage() {
  return <Dashboard />;
}
Loading