Skip to content

Commit c6a0749

Browse files
committed
added local checkpoints support
1 parent 8ee2659 commit c6a0749

File tree

2 files changed

+78
-21
lines changed

2 files changed

+78
-21
lines changed

src/mjlab/scripts/play.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,46 @@ def __call__(self, obs) -> torch.Tensor:
196196
policy = runner.get_inference_policy(device=device)
197197

198198
ckpt_manager = None
199+
if TRAINED_MODE and cfg.wandb_run_path is None and resume_path is not None:
200+
import time as _time
201+
202+
ckpt_dir = resume_path.parent
203+
_runner_local = runner
204+
205+
def fetch_available_local() -> list[tuple[str, str]]:
206+
now = _time.time()
207+
entries: list[tuple[str, str, int]] = []
208+
for f in sorted(ckpt_dir.glob("*.pt")):
209+
try:
210+
step = int(f.stem.split("_")[1])
211+
except (IndexError, ValueError):
212+
step = 0
213+
s = int(now - f.stat().st_mtime)
214+
for div, unit in ((86400, "d"), (3600, "h"), (60, "m")):
215+
if s >= div:
216+
t = f"{s // div}{unit} ago"
217+
break
218+
else:
219+
t = f"{s}s ago"
220+
entries.append((f.name, t, step))
221+
entries.sort(key=lambda x: x[2])
222+
return [(name, t) for name, t, _ in entries]
223+
224+
def load_checkpoint_local(name: str):
225+
_runner_local.load(
226+
str(ckpt_dir / name),
227+
load_cfg={"actor": True},
228+
strict=True,
229+
map_location=device,
230+
)
231+
return _runner_local.get_inference_policy(device=device)
232+
233+
ckpt_manager = CheckpointManager(
234+
current_name=resume_path.name,
235+
fetch_available=fetch_available_local,
236+
load_checkpoint=load_checkpoint_local,
237+
)
238+
199239
if TRAINED_MODE and cfg.wandb_run_path is not None:
200240
from datetime import datetime, timezone
201241

src/mjlab/viewer/viser/viewer.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@
3636

3737
@dataclass
3838
class CheckpointManager:
39-
run_name: str
40-
run_url: str
41-
run_status: str
4239
current_name: str
4340
fetch_available: Callable[[], list[tuple[str, str]]]
4441
load_checkpoint: Callable[[str], PolicyProtocol]
42+
run_name: str | None = None
43+
run_url: str | None = None
44+
run_status: str | None = None
4545

4646

4747
class UpdateReason(Enum):
@@ -186,24 +186,34 @@ def _debug_viz_extra() -> None:
186186
self._scene.create_groups_gui(tabs)
187187

188188
if self._ckpt_mgr is not None:
189-
with tabs.add_tab("W&B Run", icon=viser.Icon.CLOUD):
190-
self._server.gui.add_html(
191-
f'<div style="font-size: 0.85em; line-height: 1.25;'
192-
f' padding: 0 1em 0.5em 1em;">'
193-
f"<strong>Name:</strong> {self._ckpt_mgr.run_name}<br/>"
194-
f"<strong>Status:</strong> {self._ckpt_mgr.run_status}"
195-
f"</div>"
196-
)
197-
198-
open_button = self._server.gui.add_button(
199-
"Open Run",
200-
icon=viser.Icon.EXTERNAL_LINK,
201-
)
202-
203-
@open_button.on_click
204-
def _(_) -> None:
205-
assert self._ckpt_mgr is not None
206-
webbrowser.open(self._ckpt_mgr.run_url)
189+
is_wandb = self._ckpt_mgr.run_url is not None
190+
with tabs.add_tab("Checkpoints", icon=viser.Icon.DATABASE):
191+
if is_wandb:
192+
self._server.gui.add_html(
193+
f'<div style="font-size: 0.85em; line-height: 1.25;'
194+
f' padding: 0 1em 0.5em 1em;">'
195+
f"<strong>Source:</strong> W&B<br/>"
196+
f"<strong>Run:</strong> {self._ckpt_mgr.run_name}<br/>"
197+
f"<strong>Status:</strong> {self._ckpt_mgr.run_status}"
198+
f"</div>"
199+
)
200+
201+
open_button = self._server.gui.add_button(
202+
"Open Run",
203+
icon=viser.Icon.EXTERNAL_LINK,
204+
)
205+
206+
@open_button.on_click
207+
def _(_) -> None:
208+
assert self._ckpt_mgr is not None
209+
webbrowser.open(self._ckpt_mgr.run_url)
210+
else:
211+
self._server.gui.add_html(
212+
'<div style="font-size: 0.85em; line-height: 1.25;'
213+
' padding: 0 1em 0.5em 1em;">'
214+
"<strong>Source:</strong> Local"
215+
"</div>"
216+
)
207217

208218
self._ckpt_dropdown = self._server.gui.add_dropdown(
209219
"Checkpoint",
@@ -259,6 +269,13 @@ def _handle_custom_action(self, action: ViewerAction, payload: Optional[Any]) ->
259269
print(f"[INFO]: Loading {name}...")
260270
self.policy = self._ckpt_mgr.load_checkpoint(name)
261271
self._ckpt_mgr.current_name = name
272+
self._ckpt_updating = True
273+
cur = next(
274+
(lbl for lbl in self._ckpt_dropdown.options if lbl.startswith(name)),
275+
name,
276+
)
277+
self._ckpt_dropdown.value = cur
278+
self._ckpt_updating = False
262279
self.reset_environment()
263280
print(f"[INFO]: Loaded {name}")
264281
return True

0 commit comments

Comments
 (0)