Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 99 additions & 1 deletion src/mjlab/scripts/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mjlab.utils.torch import configure_torch_backends
from mjlab.utils.wrappers import VideoRecorder
from mjlab.viewer import NativeMujocoViewer, ViserPlayViewer
from mjlab.viewer.viser.viewer import CheckpointManager


@dataclass(frozen=True)
Expand Down Expand Up @@ -194,6 +195,103 @@ def __call__(self, obs) -> torch.Tensor:
)
policy = runner.get_inference_policy(device=device)

ckpt_manager = None
if TRAINED_MODE and cfg.wandb_run_path is None and resume_path is not None:
import time as _time

ckpt_dir = resume_path.parent
_runner_local = runner

def fetch_available_local() -> list[tuple[str, str]]:
now = _time.time()
entries: list[tuple[str, str, int]] = []
for f in sorted(ckpt_dir.glob("*.pt")):
try:
step = int(f.stem.split("_")[1])
except (IndexError, ValueError):
step = 0
s = int(now - f.stat().st_mtime)
for div, unit in ((86400, "d"), (3600, "h"), (60, "m")):
if s >= div:
t = f"{s // div}{unit} ago"
break
else:
t = f"{s}s ago"
entries.append((f.name, t, step))
entries.sort(key=lambda x: x[2])
return [(name, t) for name, t, _ in entries]

def load_checkpoint_local(name: str):
_runner_local.load(
str(ckpt_dir / name),
load_cfg={"actor": True},
strict=True,
map_location=device,
)
return _runner_local.get_inference_policy(device=device)

ckpt_manager = CheckpointManager(
current_name=resume_path.name,
fetch_available=fetch_available_local,
load_checkpoint=load_checkpoint_local,
)

if TRAINED_MODE and cfg.wandb_run_path is not None:
from datetime import datetime, timezone

import wandb

def parse_wandb_dt(value: str | datetime) -> datetime:
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00"))
return value

api = wandb.Api()
run_path = str(cfg.wandb_run_path)
wandb_run = api.run(run_path)

def fetch_available() -> list[tuple[str, str]]:
run = api.run(run_path)
now = datetime.now(tz=timezone.utc)
entries: list[tuple[str, str, int]] = []
for f in run.files():
if not f.name.endswith(".pt"):
continue
step = int(f.name.split("_")[1].split(".")[0])
s = int((now - parse_wandb_dt(f.updated_at)).total_seconds())
for div, unit in ((86400, "d"), (3600, "h"), (60, "m")):
if s >= div:
t = f"{s // div}{unit} ago"
break
else:
t = f"{s}s ago"
entries.append((f.name, t, step))
entries.sort(key=lambda x: x[2])
return [(name, t) for name, t, _ in entries]

_log_root = log_root_path # type: ignore[possibly-undefined]
_runner = runner # type: ignore[possibly-undefined]

def load_checkpoint(name: str):
path, _ = get_wandb_checkpoint_path(_log_root, Path(run_path), name)
_runner.load(
str(path),
load_cfg={"actor": True},
strict=True,
map_location=device,
)
return _runner.get_inference_policy(device=device)

assert resume_path is not None
ckpt_manager = CheckpointManager(
run_name=parse_wandb_dt(wandb_run.created_at).strftime("%Y-%m-%d_%H-%M-%S"),
run_url=wandb_run.url,
run_status=wandb_run.state,
current_name=resume_path.name,
fetch_available=fetch_available,
load_checkpoint=load_checkpoint,
)

# Handle "auto" viewer selection.
if cfg.viewer == "auto":
has_display = bool(os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"))
Expand All @@ -205,7 +303,7 @@ def __call__(self, obs) -> torch.Tensor:
if resolved_viewer == "native":
NativeMujocoViewer(env, policy).run()
elif resolved_viewer == "viser":
ViserPlayViewer(env, policy).run()
ViserPlayViewer(env, policy, checkpoint_manager=ckpt_manager).run()
else:
raise RuntimeError(f"Unsupported viewer backend: {resolved_viewer}")

Expand Down
1 change: 1 addition & 0 deletions src/mjlab/viewer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class ViewerAction(Enum):
TOGGLE_PLOTS = "toggle_plots"
TOGGLE_DEBUG_VIS = "toggle_debug_vis"
TOGGLE_SHOW_ALL_ENVS = "toggle_show_all_envs"
FETCH_CHECKPOINT = "fetch_checkpoint"
CUSTOM = "custom"


Expand Down
112 changes: 112 additions & 0 deletions src/mjlab/viewer/viser/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
from __future__ import annotations

import time
import webbrowser
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from threading import Lock
from typing import Any, Optional

import viser
from typing_extensions import override
Expand All @@ -19,6 +23,7 @@
EnvProtocol,
PolicyProtocol,
VerbosityLevel,
ViewerAction,
)
from mjlab.viewer.viser.overlays import (
ViserCameraOverlays,
Expand All @@ -29,6 +34,16 @@
from mjlab.viewer.viser.scene import ViserMujocoScene


@dataclass
class CheckpointManager:
current_name: str
fetch_available: Callable[[], list[tuple[str, str]]]
load_checkpoint: Callable[[str], PolicyProtocol]
run_name: str | None = None
run_url: str | None = None
run_status: str | None = None


class UpdateReason(Enum):
ACTION = auto()
ENV_SWITCH = auto()
Expand All @@ -45,8 +60,10 @@ def __init__(
frame_rate: float = 60.0,
verbosity: VerbosityLevel = VerbosityLevel.SILENT,
viser_server: viser.ViserServer | None = None,
checkpoint_manager: CheckpointManager | None = None,
) -> None:
super().__init__(env, policy, frame_rate, verbosity)
self._ckpt_mgr = checkpoint_manager
self._term_overlays: ViserTermOverlays | None = None
self._camera_overlays: ViserCameraOverlays | None = None
self._debug_overlays: ViserDebugOverlays | None = None
Expand Down Expand Up @@ -168,6 +185,101 @@ def _debug_viz_extra() -> None:
# Groups tab (geoms and sites).
self._scene.create_groups_gui(tabs)

if self._ckpt_mgr is not None:
is_wandb = self._ckpt_mgr.run_url is not None
with tabs.add_tab("Checkpoints", icon=viser.Icon.DATABASE):
if is_wandb:
self._server.gui.add_html(
f'<div style="font-size: 0.85em; line-height: 1.25;'
f' padding: 0 1em 0.5em 1em;">'
f"<strong>Source:</strong> W&B<br/>"
f"<strong>Run:</strong> {self._ckpt_mgr.run_name}<br/>"
f"<strong>Status:</strong> {self._ckpt_mgr.run_status}"
f"</div>"
)

open_button = self._server.gui.add_button(
"Open Run",
icon=viser.Icon.EXTERNAL_LINK,
)

@open_button.on_click
def _(_) -> None:
assert self._ckpt_mgr is not None
webbrowser.open(self._ckpt_mgr.run_url)
else:
self._server.gui.add_html(
'<div style="font-size: 0.85em; line-height: 1.25;'
' padding: 0 1em 0.5em 1em;">'
"<strong>Source:</strong> Local"
"</div>"
)

self._ckpt_dropdown = self._server.gui.add_dropdown(
"Checkpoint",
options=[self._ckpt_mgr.current_name],
initial_value=self._ckpt_mgr.current_name,
)

self._ckpt_updating = False

@self._ckpt_dropdown.on_update
def _(_) -> None:
if not self._ckpt_updating:
self._actions.append((ViewerAction.FETCH_CHECKPOINT, "selected"))

ckpt_buttons = self._server.gui.add_button_group(
"",
options=["Refresh", "Use Latest"],
)

@ckpt_buttons.on_click
def _(event) -> None:
if event.target.value == "Refresh":
self._actions.append((ViewerAction.FETCH_CHECKPOINT, "refresh"))
else:
self._actions.append((ViewerAction.FETCH_CHECKPOINT, "latest"))

self._actions.append((ViewerAction.FETCH_CHECKPOINT, "refresh"))

@override
def _handle_custom_action(self, action: ViewerAction, payload: Optional[Any]) -> bool:
if action != ViewerAction.FETCH_CHECKPOINT or self._ckpt_mgr is None:
return action == ViewerAction.FETCH_CHECKPOINT

if payload in ("refresh", "latest"):
entries = self._ckpt_mgr.fetch_available()
labels = [f"{n} ({t})" if t else n for n, t in entries]
self._ckpt_updating = True
self._ckpt_dropdown.options = labels
cur = next(
(lbl for lbl in labels if lbl.startswith(self._ckpt_mgr.current_name)),
self._ckpt_mgr.current_name,
)
self._ckpt_dropdown.value = cur
self._ckpt_updating = False
if payload == "refresh":
return True
payload = entries[-1][0]
else:
payload = self._ckpt_dropdown.value.split(" (")[0]

name = payload
if name != self._ckpt_mgr.current_name:
print(f"[INFO]: Loading {name}...")
self.policy = self._ckpt_mgr.load_checkpoint(name)
self._ckpt_mgr.current_name = name
self._ckpt_updating = True
cur = next(
(lbl for lbl in self._ckpt_dropdown.options if lbl.startswith(name)),
name,
)
self._ckpt_dropdown.value = cur
self._ckpt_updating = False
self.reset_environment()
print(f"[INFO]: Loaded {name}")
return True

@override
def _process_actions(self) -> None:
"""Process queued actions and sync UI state."""
Expand Down