|
36 | 36 |
|
37 | 37 | @dataclass |
38 | 38 | class CheckpointManager: |
39 | | - run_name: str |
40 | | - run_url: str |
41 | | - run_status: str |
42 | 39 | current_name: str |
43 | 40 | fetch_available: Callable[[], list[tuple[str, str]]] |
44 | 41 | load_checkpoint: Callable[[str], PolicyProtocol] |
| 42 | + run_name: str | None = None |
| 43 | + run_url: str | None = None |
| 44 | + run_status: str | None = None |
45 | 45 |
|
46 | 46 |
|
47 | 47 | class UpdateReason(Enum): |
@@ -186,24 +186,34 @@ def _debug_viz_extra() -> None: |
186 | 186 | self._scene.create_groups_gui(tabs) |
187 | 187 |
|
188 | 188 | if self._ckpt_mgr is not None: |
189 | | - with tabs.add_tab("W&B Run", icon=viser.Icon.CLOUD): |
190 | | - self._server.gui.add_html( |
191 | | - f'<div style="font-size: 0.85em; line-height: 1.25;' |
192 | | - f' padding: 0 1em 0.5em 1em;">' |
193 | | - f"<strong>Name:</strong> {self._ckpt_mgr.run_name}<br/>" |
194 | | - f"<strong>Status:</strong> {self._ckpt_mgr.run_status}" |
195 | | - f"</div>" |
196 | | - ) |
197 | | - |
198 | | - open_button = self._server.gui.add_button( |
199 | | - "Open Run", |
200 | | - icon=viser.Icon.EXTERNAL_LINK, |
201 | | - ) |
202 | | - |
203 | | - @open_button.on_click |
204 | | - def _(_) -> None: |
205 | | - assert self._ckpt_mgr is not None |
206 | | - webbrowser.open(self._ckpt_mgr.run_url) |
| 189 | + is_wandb = self._ckpt_mgr.run_url is not None |
| 190 | + with tabs.add_tab("Checkpoints", icon=viser.Icon.DATABASE): |
| 191 | + if is_wandb: |
| 192 | + self._server.gui.add_html( |
| 193 | + f'<div style="font-size: 0.85em; line-height: 1.25;' |
| 194 | + f' padding: 0 1em 0.5em 1em;">' |
| 195 | + f"<strong>Source:</strong> W&B<br/>" |
| 196 | + f"<strong>Run:</strong> {self._ckpt_mgr.run_name}<br/>" |
| 197 | + f"<strong>Status:</strong> {self._ckpt_mgr.run_status}" |
| 198 | + f"</div>" |
| 199 | + ) |
| 200 | + |
| 201 | + open_button = self._server.gui.add_button( |
| 202 | + "Open Run", |
| 203 | + icon=viser.Icon.EXTERNAL_LINK, |
| 204 | + ) |
| 205 | + |
| 206 | + @open_button.on_click |
| 207 | + def _(_) -> None: |
| 208 | + assert self._ckpt_mgr is not None |
| 209 | + webbrowser.open(self._ckpt_mgr.run_url) |
| 210 | + else: |
| 211 | + self._server.gui.add_html( |
| 212 | + '<div style="font-size: 0.85em; line-height: 1.25;' |
| 213 | + ' padding: 0 1em 0.5em 1em;">' |
| 214 | + "<strong>Source:</strong> Local" |
| 215 | + "</div>" |
| 216 | + ) |
207 | 217 |
|
208 | 218 | self._ckpt_dropdown = self._server.gui.add_dropdown( |
209 | 219 | "Checkpoint", |
@@ -259,6 +269,13 @@ def _handle_custom_action(self, action: ViewerAction, payload: Optional[Any]) -> |
259 | 269 | print(f"[INFO]: Loading {name}...") |
260 | 270 | self.policy = self._ckpt_mgr.load_checkpoint(name) |
261 | 271 | self._ckpt_mgr.current_name = name |
| 272 | + self._ckpt_updating = True |
| 273 | + cur = next( |
| 274 | + (lbl for lbl in self._ckpt_dropdown.options if lbl.startswith(name)), |
| 275 | + name, |
| 276 | + ) |
| 277 | + self._ckpt_dropdown.value = cur |
| 278 | + self._ckpt_updating = False |
262 | 279 | self.reset_environment() |
263 | 280 | print(f"[INFO]: Loaded {name}") |
264 | 281 | return True |
|
0 commit comments