Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
import mlflow
except ImportError:
mlflow = None

try:
import swanlab
except ImportError:
swanlab = None

from torch.utils.tensorboard import SummaryWriter

from trinity.common.config import Config
Expand Down Expand Up @@ -232,3 +238,118 @@ def default_args(cls) -> Dict:
"username": None,
"password": None,
}


@MONITOR.register_module("swanlab")
class SwanlabMonitor(Monitor):
"""Monitor with SwanLab.

This monitor integrates with SwanLab (https://swanlab.cn/) to track experiments.

Supported monitor_args in config.monitor.monitor_args:
- api_key (Optional[str]): API key for swanlab.login(). If omitted, will read from env
(SWANLAB_API_KEY, SWANLAB_APIKEY, SWANLAB_KEY, SWANLAB_TOKEN) or assume prior CLI login.
- workspace (Optional[str]): Organization/username workspace.
- mode (Optional[str]): "cloud" | "local" | "offline" | "disabled".
- logdir (Optional[str]): Local log directory when in local/offline modes.
- experiment_name (Optional[str]): Explicit experiment name. Defaults to "{name}_{role}".
- description (Optional[str]): Experiment description.
- tags (Optional[List[str]]): Tags to attach. Role and group are appended automatically.
- id (Optional[str]): Resume target run id (21 chars) when using resume modes.
- resume (Optional[Literal['must','allow','never']|bool]): Resume policy.
- reinit (Optional[bool]): Whether to re-init on repeated init() calls.
"""

def __init__(
self, project: str, group: str, name: str, role: str, config: Config = None
) -> None:
assert (
swanlab is not None
), "swanlab is not installed. Please install it to use SwanlabMonitor."

monitor_args = (
(config.monitor.monitor_args or {})
if config and getattr(config, "monitor", None)
else {}
)

# Optional API login via code if provided; otherwise try environment, then rely on prior `swanlab login`.
api_key = os.environ.get("SWANLAB_API_KEY")
if api_key:
try:
swanlab.login(api_key=api_key, save=True)
except Exception:
# Best-effort login; continue to init which may still work if already logged in
pass
else:
raise RuntimeError("Swanlab API key not found in environment variable SWANLAB_API_KEY.")

# Compose tags (ensure list and include role/group markers)
tags = monitor_args.get("tags") or []
if isinstance(tags, tuple):
tags = list(tags)
if role and role not in tags:
tags.append(role)
if group and group not in tags:
tags.append(group)

# Determine experiment name
exp_name = monitor_args.get("experiment_name") or f"{name}_{role}"
self.exp_name = exp_name

# Prepare init kwargs, passing only non-None values to respect library defaults
init_kwargs = {
"project": project,
"workspace": monitor_args.get("workspace"),
"experiment_name": exp_name,
"description": monitor_args.get("description"),
"tags": tags or None,
"logdir": monitor_args.get("logdir"),
"mode": monitor_args.get("mode") or "cloud",
"settings": monitor_args.get("settings"),
"id": monitor_args.get("id"),
"config": config.flatten(),
"resume": monitor_args.get("resume"),
"reinit": monitor_args.get("reinit"),
}
# Strip None values to avoid overriding swanlab defaults
init_kwargs = {k: v for k, v in init_kwargs.items() if v is not None}

self.logger = swanlab.init(**init_kwargs)
self.console_logger = get_logger(__name__, in_ray_actor=True)

def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
# Convert pandas DataFrame to SwanLab ECharts Table
headers: List[str] = list(experiences_table.columns)
# Ensure rows are native Python types
rows: List[List[object]] = experiences_table.astype(object).values.tolist()
try:
tbl = swanlab.echarts.Table()
tbl.add(headers, rows)
swanlab.log({table_name: tbl}, step=step)
except Exception:
# Fallback: log as CSV string if echarts table is unavailable
csv_str = experiences_table.to_csv(index=False)
swanlab.log({table_name: csv_str}, step=step)

def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
# SwanLab doesn't use commit flag; keep signature for compatibility
swanlab.log(data, step=step)
self.console_logger.info(f"Step {step}: {data}")

def close(self) -> None:
try:
# Prefer run.finish() if available
if hasattr(self, "logger") and hasattr(self.logger, "finish"):
self.logger.finish()
else:
# Fallback to global finish
swanlab.finish()
except Exception:
pass

@classmethod
def default_args(cls) -> Dict:
"""Return default arguments for the monitor."""
return {}