Skip to content

Commit 8fbdd19

Browse files
authored
impl swanlab monitor (#450)
1 parent 5109237 commit 8fbdd19

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

tests/utils/swanlab_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
import unittest
3+
4+
5+
class TestSwanlabMonitor(unittest.TestCase):
6+
@classmethod
7+
def setUpClass(cls):
8+
os.environ["SWANLAB_API_KEY"] = "xxxxxxxxxxxxxxxxxxxxx"
9+
10+
@classmethod
11+
def tearDownClass(cls):
12+
# Restore original environment variables
13+
for k, v in cls._original_env.items():
14+
if v is None:
15+
os.environ.pop(k, None)
16+
else:
17+
os.environ[k] = v
18+
19+
@unittest.skip("Requires swanlab package and network access")
20+
def test_swanlab_monitor_smoke(self):
21+
from trinity.utils.monitor import SwanlabMonitor
22+
23+
# Try creating the monitor; if swanlab isn't installed, __init__ will assert
24+
mon = SwanlabMonitor(
25+
project="trinity-smoke",
26+
group="cradle",
27+
name="swanlab-env",
28+
role="tester",
29+
)
30+
31+
# Log a minimal metric to verify basic flow
32+
mon.log({"smoke/metric": 1.0}, step=1)
33+
mon.close()
34+
35+
36+
if __name__ == "__main__":
37+
unittest.main()

trinity/utils/monitor.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
import mlflow
1717
except ImportError:
1818
mlflow = None
19+
20+
try:
21+
import swanlab
22+
except ImportError:
23+
swanlab = None
24+
1925
from torch.utils.tensorboard import SummaryWriter
2026

2127
from trinity.common.config import Config
@@ -28,6 +34,7 @@
2834
"tensorboard": "trinity.utils.monitor.TensorboardMonitor",
2935
"wandb": "trinity.utils.monitor.WandbMonitor",
3036
"mlflow": "trinity.utils.monitor.MlflowMonitor",
37+
"swanlab": "trinity.utils.monitor.SwanlabMonitor",
3138
},
3239
)
3340

@@ -232,3 +239,120 @@ def default_args(cls) -> Dict:
232239
"username": None,
233240
"password": None,
234241
}
242+
243+
244+
class SwanlabMonitor(Monitor):
245+
"""Monitor with SwanLab.
246+
247+
This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments.
248+
249+
Supported monitor_args in config.monitor.monitor_args:
250+
- api_key (Optional[str]): API key for swanlab.login(). If omitted, will read from env
251+
(SWANLAB_API_KEY, SWANLAB_APIKEY, SWANLAB_KEY, SWANLAB_TOKEN) or assume prior CLI login.
252+
- workspace (Optional[str]): Organization/username workspace.
253+
- mode (Optional[str]): "cloud" | "local" | "offline" | "disabled".
254+
- logdir (Optional[str]): Local log directory when in local/offline modes.
255+
- experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}".
256+
- description (Optional[str]): Experiment description.
257+
- tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically.
258+
- id (Optional[str]): Resume target run id (21 chars) when using resume modes.
259+
- resume (Optional[Literal['must','allow','never']|bool]): Resume policy.
260+
- reinit (Optional[bool]): Whether to re-init on repeated init() calls.
261+
"""
262+
263+
def __init__(
264+
self, project: str, group: str, name: str, role: str, config: Config = None
265+
) -> None:
266+
assert (
267+
swanlab is not None
268+
), "swanlab is not installed. Please install it to use SwanlabMonitor."
269+
270+
monitor_args = (
271+
(config.monitor.monitor_args or {})
272+
if config and getattr(config, "monitor", None)
273+
else {}
274+
)
275+
276+
# Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`.
277+
api_key = os.environ.get("SWANLAB_API_KEY")
278+
if api_key:
279+
try:
280+
swanlab.login(api_key=api_key, save=True)
281+
except Exception:
282+
# Best-effort login; continue to init which may still work if already logged in
283+
pass
284+
else:
285+
raise RuntimeError("Swanlab API key not found in environment variable SWANLAB_API_KEY.")
286+
287+
# Compose tags (ensure list and include role/group markers)
288+
tags = monitor_args.get("tags") or []
289+
if isinstance(tags, tuple):
290+
tags = list(tags)
291+
if role and role not in tags:
292+
tags.append(role)
293+
if group and group not in tags:
294+
tags.append(group)
295+
296+
# Determine experiment name
297+
exp_name = monitor_args.get("experiment_name") or f"{name}_{role}"
298+
self.exp_name = exp_name
299+
300+
# Prepare init kwargs, passing only non-None values to respect library defaults
301+
init_kwargs = {
302+
"project": project,
303+
"workspace": monitor_args.get("workspace"),
304+
"experiment_name": exp_name,
305+
"description": monitor_args.get("description"),
306+
"tags": tags or None,
307+
"logdir": monitor_args.get("logdir"),
308+
"mode": monitor_args.get("mode") or "cloud",
309+
"settings": monitor_args.get("settings"),
310+
"id": monitor_args.get("id"),
311+
"config": config.flatten() if config is not None else None,
312+
"resume": monitor_args.get("resume"),
313+
"reinit": monitor_args.get("reinit"),
314+
}
315+
# Strip None values to avoid overriding swanlab defaults
316+
init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}
317+
318+
self.logger = swanlab.init(**init_kwargs)
319+
self.console_logger = get_logger(__name__, in_ray_actor=True)
320+
321+
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
322+
# Convert pandas DataFrame to SwanLab ECharts Table
323+
headers: List[str] = list(experiences_table.columns)
324+
# Ensure rows are native Python types
325+
rows: List[List[object]] = experiences_table.astype(object).values.tolist()
326+
try:
327+
tbl = swanlab.echarts.Table()
328+
tbl.add(headers, rows)
329+
swanlab.log({table_name: tbl}, step=step)
330+
except Exception as e:
331+
self.console_logger.warning(
332+
f"Failed to log table '{table_name}' as echarts, falling back to CSV. Error: {e}"
333+
)
334+
# Fallback: log as CSV string if echarts table is unavailable
335+
csv_str = experiences_table.to_csv(index=False)
336+
swanlab.log({table_name: csv_str}, step=step)
337+
338+
def log(self, data: dict, step: int, commit: bool = False) -> None:
339+
"""Log metrics."""
340+
# SwanLab doesn't use commit flag; keep signature for compatibility
341+
swanlab.log(data, step=step)
342+
self.console_logger.info(f"Step {step}: {data}")
343+
344+
def close(self) -> None:
345+
try:
346+
# Prefer run.finish() if available
347+
if hasattr(self, "logger") and hasattr(self.logger, "finish"):
348+
self.logger.finish()
349+
else:
350+
# Fallback to global finish
351+
swanlab.finish()
352+
except Exception as e:
353+
self.console_logger.warning(f"Failed to close SwanlabMonitor: {e}")
354+
355+
@classmethod
356+
def default_args(cls) -> Dict:
357+
"""Return default arguments for the monitor."""
358+
return {}

0 commit comments

Comments
 (0)