|
16 | 16 | import mlflow |
17 | 17 | except ImportError: |
18 | 18 | mlflow = None |
| 19 | + |
| 20 | +try: |
| 21 | + import swanlab |
| 22 | +except ImportError: |
| 23 | + swanlab = None |
| 24 | + |
19 | 25 | from torch.utils.tensorboard import SummaryWriter |
20 | 26 |
|
21 | 27 | from trinity.common.config import Config |
|
28 | 34 | "tensorboard": "trinity.utils.monitor.TensorboardMonitor", |
29 | 35 | "wandb": "trinity.utils.monitor.WandbMonitor", |
30 | 36 | "mlflow": "trinity.utils.monitor.MlflowMonitor", |
| 37 | + "swanlab": "trinity.utils.monitor.SwanlabMonitor", |
31 | 38 | }, |
32 | 39 | ) |
33 | 40 |
|
@@ -232,3 +239,120 @@ def default_args(cls) -> Dict: |
232 | 239 | "username": None, |
233 | 240 | "password": None, |
234 | 241 | } |
| 242 | + |
| 243 | + |
| 244 | +class SwanlabMonitor(Monitor): |
| 245 | + """Monitor with SwanLab. |
| 246 | +
|
| 247 | + This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments. |
| 248 | +
|
| 249 | + Supported monitor_args in config.monitor.monitor_args: |
| 250 | + - api_key (Optional[str]): API key for swanlab.login(). If omitted, will read from env |
| 251 | + (SWANLAB_API_KEY, SWANLAB_APIKEY, SWANLAB_KEY, SWANLAB_TOKEN) or assume prior CLI login. |
| 252 | + - workspace (Optional[str]): Organization/username workspace. |
| 253 | + - mode (Optional[str]): "cloud" | "local" | "offline" | "disabled". |
| 254 | + - logdir (Optional[str]): Local log directory when in local/offline modes. |
| 255 | + - experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}". |
| 256 | + - description (Optional[str]): Experiment description. |
| 257 | + - tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically. |
| 258 | + - id (Optional[str]): Resume target run id (21 chars) when using resume modes. |
| 259 | + - resume (Optional[Literal['must','allow','never']|bool]): Resume policy. |
| 260 | + - reinit (Optional[bool]): Whether to re-init on repeated init() calls. |
| 261 | + """ |
| 262 | + |
| 263 | + def __init__( |
| 264 | + self, project: str, group: str, name: str, role: str, config: Config = None |
| 265 | + ) -> None: |
| 266 | + assert ( |
| 267 | + swanlab is not None |
| 268 | + ), "swanlab is not installed. Please install it to use SwanlabMonitor." |
| 269 | + |
| 270 | + monitor_args = ( |
| 271 | + (config.monitor.monitor_args or {}) |
| 272 | + if config and getattr(config, "monitor", None) |
| 273 | + else {} |
| 274 | + ) |
| 275 | + |
| 276 | + # Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`. |
| 277 | + api_key = os.environ.get("SWANLAB_API_KEY") |
| 278 | + if api_key: |
| 279 | + try: |
| 280 | + swanlab.login(api_key=api_key, save=True) |
| 281 | + except Exception: |
| 282 | + # Best-effort login; continue to init which may still work if already logged in |
| 283 | + pass |
| 284 | + else: |
| 285 | + raise RuntimeError("Swanlab API key not found in environment variable SWANLAB_API_KEY.") |
| 286 | + |
| 287 | + # Compose tags (ensure list and include role/group markers) |
| 288 | + tags = monitor_args.get("tags") or [] |
| 289 | + if isinstance(tags, tuple): |
| 290 | + tags = list(tags) |
| 291 | + if role and role not in tags: |
| 292 | + tags.append(role) |
| 293 | + if group and group not in tags: |
| 294 | + tags.append(group) |
| 295 | + |
| 296 | + # Determine experiment name |
| 297 | + exp_name = monitor_args.get("experiment_name") or f"{name}_{role}" |
| 298 | + self.exp_name = exp_name |
| 299 | + |
| 300 | + # Prepare init kwargs, passing only non-None values to respect library defaults |
| 301 | + init_kwargs = { |
| 302 | + "project": project, |
| 303 | + "workspace": monitor_args.get("workspace"), |
| 304 | + "experiment_name": exp_name, |
| 305 | + "description": monitor_args.get("description"), |
| 306 | + "tags": tags or None, |
| 307 | + "logdir": monitor_args.get("logdir"), |
| 308 | + "mode": monitor_args.get("mode") or "cloud", |
| 309 | + "settings": monitor_args.get("settings"), |
| 310 | + "id": monitor_args.get("id"), |
| 311 | + "config": config.flatten() if config is not None else None, |
| 312 | + "resume": monitor_args.get("resume"), |
| 313 | + "reinit": monitor_args.get("reinit"), |
| 314 | + } |
| 315 | + # Strip None values to avoid overriding swanlab defaults |
| 316 | + init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None} |
| 317 | + |
| 318 | + self.logger = swanlab.init(**init_kwargs) |
| 319 | + self.console_logger = get_logger(__name__, in_ray_actor=True) |
| 320 | + |
| 321 | + def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int): |
| 322 | + # Convert pandas DataFrame to SwanLab ECharts Table |
| 323 | + headers: List[str] = list(experiences_table.columns) |
| 324 | + # Ensure rows are native Python types |
| 325 | + rows: List[List[object]] = experiences_table.astype(object).values.tolist() |
| 326 | + try: |
| 327 | + tbl = swanlab.echarts.Table() |
| 328 | + tbl.add(headers, rows) |
| 329 | + swanlab.log({table_name: tbl}, step=step) |
| 330 | + except Exception as e: |
| 331 | + self.console_logger.warning( |
| 332 | + f"Failed to log table '{table_name}' as echarts, falling back to CSV. Error: {e}" |
| 333 | + ) |
| 334 | + # Fallback: log as CSV string if echarts table is unavailable |
| 335 | + csv_str = experiences_table.to_csv(index=False) |
| 336 | + swanlab.log({table_name: csv_str}, step=step) |
| 337 | + |
| 338 | + def log(self, data: dict, step: int, commit: bool = False) -> None: |
| 339 | + """Log metrics.""" |
| 340 | + # SwanLab doesn't use commit flag; keep signature for compatibility |
| 341 | + swanlab.log(data, step=step) |
| 342 | + self.console_logger.info(f"Step {step}: {data}") |
| 343 | + |
| 344 | + def close(self) -> None: |
| 345 | + try: |
| 346 | + # Prefer run.finish() if available |
| 347 | + if hasattr(self, "logger") and hasattr(self.logger, "finish"): |
| 348 | + self.logger.finish() |
| 349 | + else: |
| 350 | + # Fallback to global finish |
| 351 | + swanlab.finish() |
| 352 | + except Exception as e: |
| 353 | + self.console_logger.warning(f"Failed to close SwanlabMonitor: {e}") |
| 354 | + |
| 355 | + @classmethod |
| 356 | + def default_args(cls) -> Dict: |
| 357 | + """Return default arguments for the monitor.""" |
| 358 | + return {} |
0 commit comments