Skip to content

Commit 3f9f0e3

Browse files
committed
impl swanlab monitor
1 parent 38ba481 commit 3f9f0e3

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

trinity/utils/monitor.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
import mlflow
1717
except ImportError:
1818
mlflow = None
19+
20+
try:
21+
import swanlab
22+
except ImportError:
23+
swanlab = None
24+
1925
from torch.utils.tensorboard import SummaryWriter
2026

2127
from trinity.common.config import Config
@@ -232,3 +238,118 @@ def default_args(cls) -> Dict:
232238
"username": None,
233239
"password": None,
234240
}
241+
242+
243+
@MONITOR.register_module("swanlab")
244+
class SwanlabMonitor(Monitor):
245+
"""Monitor with SwanLab.
246+
247+
This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments.
248+
249+
Supported monitor_args in config.monitor.monitor_args:
250+
- api_key (Optional[str]): API key for swanlab.login(). If omitted, will read from env
251+
(SWANLAB_API_KEY, SWANLAB_APIKEY, SWANLAB_KEY, SWANLAB_TOKEN) or assume prior CLI login.
252+
- workspace (Optional[str]): Organization/username workspace.
253+
- mode (Optional[str]): "cloud" | "local" | "offline" | "disabled".
254+
- logdir (Optional[str]): Local log directory when in local/offline modes.
255+
- experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}".
256+
- description (Optional[str]): Experiment description.
257+
- tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically.
258+
- id (Optional[str]): Resume target run id (21 chars) when using resume modes.
259+
- resume (Optional[Literal['must','allow','never']|bool]): Resume policy.
260+
- reinit (Optional[bool]): Whether to re-init on repeated init() calls.
261+
"""
262+
263+
def __init__(
264+
self, project: str, group: str, name: str, role: str, config: Config = None
265+
) -> None:
266+
assert (
267+
swanlab is not None
268+
), "swanlab is not installed. Please install it to use SwanlabMonitor."
269+
270+
monitor_args = (
271+
(config.monitor.monitor_args or {})
272+
if config and getattr(config, "monitor", None)
273+
else {}
274+
)
275+
276+
# Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`.
277+
api_key = os.environ.get("SWANLAB_API_KEY")
278+
if api_key:
279+
try:
280+
swanlab.login(api_key=api_key, save=True)
281+
except Exception:
282+
# Best-effort login; continue to init which may still work if already logged in
283+
pass
284+
else:
285+
raise RuntimeError("Swanlab API key not found in environment variable SWANLAB_API_KEY.")
286+
287+
# Compose tags (ensure list and include role/group markers)
288+
tags = monitor_args.get("tags") or []
289+
if isinstance(tags, tuple):
290+
tags = list(tags)
291+
if role and role not in tags:
292+
tags.append(role)
293+
if group and group not in tags:
294+
tags.append(group)
295+
296+
# Determine experiment name
297+
exp_name = monitor_args.get("experiment_name") or f"{name}_{role}"
298+
self.exp_name = exp_name
299+
300+
# Prepare init kwargs, passing only non-None values to respect library defaults
301+
init_kwargs = {
302+
"project": project,
303+
"workspace": monitor_args.get("workspace"),
304+
"experiment_name": exp_name,
305+
"description": monitor_args.get("description"),
306+
"tags": tags or None,
307+
"logdir": monitor_args.get("logdir"),
308+
"mode": monitor_args.get("mode") or "cloud",
309+
"settings": monitor_args.get("settings"),
310+
"id": monitor_args.get("id"),
311+
"config": config.flatten(),
312+
"resume": monitor_args.get("resume"),
313+
"reinit": monitor_args.get("reinit"),
314+
}
315+
# Strip None values to avoid overriding swanlab defaults
316+
init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}
317+
318+
self.logger = swanlab.init(**init_kwargs)
319+
self.console_logger = get_logger(__name__, in_ray_actor=True)
320+
321+
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
322+
# Convert pandas DataFrame to SwanLab ECharts Table
323+
headers: List[str] = list(experiences_table.columns)
324+
# Ensure rows are native Python types
325+
rows: List[List[object]] = experiences_table.astype(object).values.tolist()
326+
try:
327+
tbl = swanlab.echarts.Table()
328+
tbl.add(headers, rows)
329+
swanlab.log({table_name: tbl}, step=step)
330+
except Exception:
331+
# Fallback: log as CSV string if echarts table is unavailable
332+
csv_str = experiences_table.to_csv(index=False)
333+
swanlab.log({table_name: csv_str}, step=step)
334+
335+
def log(self, data: dict, step: int, commit: bool = False) -> None:
336+
"""Log metrics."""
337+
# SwanLab doesn't use commit flag; keep signature for compatibility
338+
swanlab.log(data, step=step)
339+
self.console_logger.info(f"Step {step}: {data}")
340+
341+
def close(self) -> None:
342+
try:
343+
# Prefer run.finish() if available
344+
if hasattr(self, "logger") and hasattr(self.logger, "finish"):
345+
self.logger.finish()
346+
else:
347+
# Fallback to global finish
348+
swanlab.finish()
349+
except Exception:
350+
pass
351+
352+
@classmethod
353+
def default_args(cls) -> Dict:
354+
"""Return default arguments for the monitor."""
355+
return {}

0 commit comments

Comments
 (0)