# logger.py
import html
from datetime import datetime
from pathlib import Path
import pandas as pd
from omegaconf import DictConfig, OmegaConf
from torch.utils.tensorboard.writer import SummaryWriter
class Logger:
    """
    A unified logger that supports TensorBoard, Weights & Biases (wandb), and SwanLab.

    The backend is chosen from ``config.logger.type`` (``"swanlab"``,
    ``"wandb"`` or ``"tensorboard"``). Any other value disables logging:
    ``logger_type`` is set to ``"none"`` and every ``log_*`` call becomes a
    no-op.
    """

    def __init__(self, config: "DictConfig"):
        """
        Start the backend selected by ``config.logger.type``.

        Args:
            config: Run configuration. Reads ``config.logger.type``,
                ``config.logger.project_name`` and, if present and not
                ``None``, ``config.logger.run_name``. The whole config is
                converted to a plain container and handed to wandb/swanlab
                as the run's ``config``.
        """
        self.config = config
        self.logger_type = config.logger.type
        self.writer = None

        # Resolved, plain-Python copy of the config for the hosted backends.
        config_dict = OmegaConf.to_container(config, resolve=True)
        init_kwargs = dict(
            project=config.logger.project_name,
            config=config_dict,
            reinit=True,
        )
        # A run name is optional; only forward it when one was configured.
        run_name = getattr(config.logger, "run_name", None)
        if run_name is not None:
            init_kwargs["name"] = run_name

        if self.logger_type == "swanlab":
            # Imported lazily so only the selected backend must be installed.
            import swanlab

            swanlab.init(**init_kwargs)
            self.writer = swanlab
        elif self.logger_type == "wandb":
            import wandb

            wandb.init(**init_kwargs)
            self.writer = wandb
        elif self.logger_type == "tensorboard":
            # One timestamped directory per run under ./tensorboard.
            current_time = datetime.now().strftime(r"%Y%m%d-%H%M%S")
            log_dir = Path.cwd() / "tensorboard" / current_time
            log_dir.mkdir(parents=True, exist_ok=True)
            self.writer = SummaryWriter(log_dir=str(log_dir))
        else:
            print(f"Unknown logger type '{self.logger_type}'. Logging will be disabled.")
            self.logger_type = "none"

    def log_metrics(self, metrics: dict, step: int):
        """
        Log a mapping of metric name to value at ``step``.

        wandb and swanlab receive the whole mapping in one call; TensorBoard
        receives one ``add_scalar`` call per entry. Does nothing when logging
        is disabled.
        """
        if self.logger_type in ("wandb", "swanlab"):
            self.writer.log(metrics, step=step)
        elif self.logger_type == "tensorboard":
            for tag, value in metrics.items():
                self.writer.add_scalar(tag, value, step)

    def log_text(self, key: str, text: str, step: int):
        """
        Log a text entry under ``key`` at ``step``.

        Does nothing when logging is disabled.
        """
        if self.logger_type == "wandb":
            # wandb can directly record HTML
            self.writer.log({key: self.writer.Html(text)}, step=step)
        elif self.logger_type == "swanlab":
            self.writer.log({key: self.writer.Text(text)}, step=step)
        elif self.logger_type == "tensorboard":
            # TensorBoard treats text as markdown, so escape the content and
            # wrap it in <pre> to show it verbatim.
            escaped_text = html.escape(text)
            self.writer.add_text(key, f"<pre>{escaped_text}</pre>", step)

    def log_table(self, key: str, table: dict):
        """
        Log a table under ``key``.

        Args:
            key: Name the table is logged under (for every backend).
            table: Mapping of column name to a sequence of column values.

        Does nothing when logging is disabled.
        """
        if self.logger_type == "wandb":
            self.writer.log({key: self.writer.Table(dataframe=pd.DataFrame(table))})
        elif self.logger_type == "swanlab":
            etable = self.writer.echarts.Table()
            headers = list(table.keys())
            # swanlab wants rows, the input is column-oriented: transpose.
            rows = [list(row) for row in zip(*table.values())]
            etable.add(headers, rows)
            # Log under the caller's key, as the other backends do
            # (this was previously hard-coded to "completions").
            self.writer.log({key: etable})
        elif self.logger_type == "tensorboard":
            # TensorBoard does not support direct table logging, so we convert to text
            # NOTE(review): the generated HTML is escaped, so it presumably
            # shows up as raw markup rather than a rendered table - confirm
            # this is intended.
            table_str = pd.DataFrame(table).to_html(index=False)
            self.writer.add_text(key, f"<pre>{html.escape(table_str)}</pre>")

    def close(self):
        """Finish the wandb/swanlab run or close the TensorBoard writer."""
        if not self.writer:
            return
        if self.logger_type in ("wandb", "swanlab"):
            self.writer.finish()
        elif self.logger_type == "tensorboard":
            self.writer.close()