Skip to content

Commit 419aa8c

Browse files
authored
Add reporters (#82)
* init * update test * add tensorboard * fix mypy * add wandb to ignore * add final metrics * move to separate files * update ruff
1 parent a4287b0 commit 419aa8c

File tree

18 files changed

+624
-13
lines changed

18 files changed

+624
-13
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,5 @@ poetry.lock
177177
indexes_dirnames.json
178178
tests_logs
179179
tests/logs
180+
runs/
181+
vector_db*

autointent/_callbacks/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from autointent._callbacks.base import OptimizerCallback
2+
from autointent._callbacks.callback_handler import CallbackHandler
3+
from autointent._callbacks.tensorboard import TensorBoardCallback
4+
from autointent._callbacks.wandb import WandbCallback
5+
6+
# Registry of the available reporter callbacks, keyed by each class's ``name`` attribute.
REPORTERS = {
    reporter_cls.name: reporter_cls
    for reporter_cls in (WandbCallback, TensorBoardCallback)
}
7+
8+
9+
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
    """
    Build the callback handler for the requested reporters.

    :param reporters: Names of the reporters to enable. ``None`` or an empty list enables none.
    :return: Callback handler wrapping the callback classes registered under the given names.
    :raises ValueError: If one of the names is not a key of ``REPORTERS``.
    """
    if reporters:
        # Validate in order so the first unsupported name is the one reported.
        for reporter in reporters:
            if reporter not in REPORTERS:
                msg = f"Reporter {reporter} not supported. Supported reporters {','.join(REPORTERS)}"
                raise ValueError(msg)

        # The handler receives the classes and instantiates them itself.
        selected_classes = [REPORTERS[reporter] for reporter in reporters]
        return CallbackHandler(callbacks=selected_classes)

    return CallbackHandler()
26+
27+
28+
# Names exported by ``from autointent._callbacks import *``.
__all__ = [
    "CallbackHandler",
    "OptimizerCallback",
    "TensorBoardCallback",
    "WandbCallback",
    "get_callbacks",
]

autointent/_callbacks/base.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Base class for reporters (W&B, TensorBoard, etc)."""
2+
3+
from abc import ABC, abstractmethod
4+
from pathlib import Path
5+
from typing import Any
6+
7+
8+
class OptimizerCallback(ABC):
    """
    Abstract interface for reporters (W&B, TensorBoard, etc).

    Every hook below is abstract, so a concrete reporter has to implement all of them,
    including ``__init__``.
    """

    # Implementation inspired by TrainerCallback from HuggingFace Transformers. https://github.com/huggingface/transformers/blob/91b8ab18b778ae9e2f8191866e018cd1dc7097be/src/transformers/trainer_callback.py#L260

    # Identifier of the reporter; declared here, assigned by concrete subclasses.
    name: str

    @abstractmethod
    def __init__(self) -> None:
        """Set up the reporter."""

    @abstractmethod
    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Begin a new run.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """

    @abstractmethod
    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Begin a new module.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        """

    # NOTE(review): ``**kwargs: dict[str, Any]`` annotates every keyword *value* as a dict;
    # if implementations receive plain metric values, ``**kwargs: Any`` may be the intended
    # annotation - confirm and change together with the overrides.
    @abstractmethod
    def log_value(self, **kwargs: dict[str, Any]) -> None:
        """
        Log data.

        :param kwargs: Data to log, passed as keyword arguments.
        """

    @abstractmethod
    def end_module(self) -> None:
        """Finish the current module."""

    @abstractmethod
    def end_run(self) -> None:
        """Finish the current run."""

    @abstractmethod
    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """
        Log the final metrics.

        :param metrics: Final metrics.
        """
autointent/_callbacks/callback_handler.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from pathlib import Path
2+
from typing import Any
3+
4+
from autointent._callbacks.base import OptimizerCallback
5+
6+
7+
class CallbackHandler(OptimizerCallback):
    """Internal dispatcher that forwards every event to its callbacks, in list order."""

    # Instantiated callbacks that receive each event.
    callbacks: list[OptimizerCallback]

    def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
        """
        Instantiate the given callback classes.

        :param callbacks: Callback classes to instantiate (each is called with no arguments).
            ``None`` or an empty list yields a handler with no callbacks.
        """
        self.callbacks = [callback_cls() for callback_cls in callbacks] if callbacks else []

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Forward the start of a run to every callback.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """
        self.call_events("start_run", run_name=run_name, dirpath=dirpath)

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Forward the start of a module to every callback.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        """
        self.call_events(
            "start_module",
            module_name=module_name,
            num=num,
            module_kwargs=module_kwargs,
        )

    def log_value(self, **kwargs: dict[str, Any]) -> None:
        """
        Forward data to log to every callback.

        :param kwargs: Data to log.
        """
        self.call_events("log_value", **kwargs)

    def end_module(self) -> None:
        """Forward the end of a module to every callback."""
        self.call_events("end_module")

    def end_run(self) -> None:
        """Forward the end of a run to every callback."""
        self.call_events("end_run")

    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """
        Forward the final metrics to every callback.

        :param metrics: Final metrics.
        """
        self.call_events("log_final_metrics", metrics=metrics)

    def call_events(self, event: str, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Invoke the method named ``event`` on each callback.

        :param event: Name of the callback method to call.
        :param kwargs: Keyword arguments passed through to that method unchanged.
        """
        for registered in self.callbacks:
            hook = getattr(registered, event)
            hook(**kwargs)
autointent/_callbacks/tensorboard.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from pathlib import Path
2+
from typing import Any
3+
4+
from autointent._callbacks.base import OptimizerCallback
5+
6+
7+
class TensorBoardCallback(OptimizerCallback):
    """
    TensorBoard callback.

    This callback logs the optimization process to TensorBoard. Every module gets its own
    ``SummaryWriter`` in a sub-directory of the run's ``dirpath``; the final metrics are written
    to a separate ``final_metrics`` sub-directory.
    """

    name = "tensorboard"

    def __init__(self) -> None:
        """
        Initialize the callback.

        :raises ImportError: If neither ``torch.utils.tensorboard`` nor ``tensorboardX`` can be imported.
        """
        try:
            from torch.utils.tensorboard import SummaryWriter  # type: ignore[attr-defined]

            self.writer = SummaryWriter
        except ImportError:
            try:
                from tensorboardX import SummaryWriter  # type: ignore[no-redef]

                self.writer = SummaryWriter
            except ImportError:
                msg = (
                    "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                    " install tensorboardX."
                )
                raise ImportError(msg) from None

        # Declare the run state up front so that calling a lifecycle method out of order hits the
        # explicit RuntimeError guards below instead of failing with an AttributeError.
        self.run_name: str | None = None
        self.dirpath: Path | None = None
        self.module_writer: Any = None

    @staticmethod
    def _write_entries(writer: Any, entries: dict[str, Any]) -> None:  # noqa: ANN401
        """Write ``int``/``float`` values as scalars and every other value as text."""
        for key, value in entries.items():
            if isinstance(value, int | float):
                writer.add_scalar(key, value)
            else:
                writer.add_text(key, str(value))

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """
        self.run_name = run_name
        self.dirpath = dirpath

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module and open a writer for it.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters; each one is logged as text.
        :raises RuntimeError: If ``start_run`` has not been called yet.
        """
        if self.run_name is None or self.dirpath is None:
            msg = "start_run must be called before start_module."
            raise RuntimeError(msg)

        module_run_name = f"{self.run_name}_{module_name}_{num}"
        log_dir = Path(self.dirpath) / module_run_name
        self.module_writer = self.writer(log_dir=log_dir)  # type: ignore[no-untyped-call]

        self.module_writer.add_text("module_info", f"Starting module {module_name}_{num}")
        for key, value in module_kwargs.items():
            self.module_writer.add_text(f"module_params/{key}", str(value))

    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data to the current module's writer.

        :param kwargs: Data to log; numeric values become scalars, everything else text.
        :raises RuntimeError: If no module writer exists yet (``start_module`` was not called).
        """
        if self.module_writer is None:
            msg = "start_module must be called before log_value."
            raise RuntimeError(msg)

        self._write_entries(self.module_writer, kwargs)

    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """
        Log final metrics to a dedicated ``final_metrics`` log directory.

        :param metrics: Final metrics.
        :raises RuntimeError: If ``start_run`` has not been called yet.
        """
        if self.dirpath is None:
            msg = "start_run must be called before log_final_metrics."
            raise RuntimeError(msg)

        log_dir = Path(self.dirpath) / "final_metrics"
        # Use a dedicated writer and always close it: nothing else would close it afterwards,
        # and the per-module writer must not be replaced by it.
        final_writer = self.writer(log_dir=log_dir)  # type: ignore[no-untyped-call]
        try:
            self._write_entries(final_writer, metrics)
        finally:
            final_writer.close()

    def end_module(self) -> None:
        """
        End a module and close its writer.

        :raises RuntimeError: If no module writer exists yet (``start_module`` was not called).
        """
        if self.module_writer is None:
            msg = "start_module must be called before end_module."
            raise RuntimeError(msg)

        self.module_writer.add_text("module_info", "Ending module")
        self.module_writer.close()

    def end_run(self) -> None:
        """End a run. Nothing to do here: writers are closed per module."""

autointent/_callbacks/wandb.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import os
2+
from pathlib import Path
3+
from typing import Any
4+
5+
from autointent._callbacks.base import OptimizerCallback
6+
7+
8+
class WandbCallback(OptimizerCallback):
    """
    Wandb callback.

    Reports the optimization process to W&B: one W&B run per module, grouped under the
    optimization run's name, plus one extra run called ``final_metrics``.
    The project name is read from the `WANDB_PROJECT` environment variable and defaults to `autointent`.
    """

    name = "wandb"

    def __init__(self) -> None:
        """
        Initialize the callback.

        :raises ImportError: If ``wandb`` is not installed.
        """
        try:
            import wandb
        except ImportError:
            install_hint = "Please install wandb to use this callback. `pip install wandb`"
            raise ImportError(install_hint) from None
        else:
            # Keep a handle on the lazily imported module for the other methods.
            self.wandb = wandb

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        :param run_name: Name of the run; used as the W&B group.
        :param dirpath: Path to the directory where the logs will be saved. (Stored, but not otherwise
            used by this callback)
        """
        self.dirpath = dirpath
        self.group = run_name
        self.project_name = os.getenv("WANDB_PROJECT", "autointent")

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module by opening a W&B run for it.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters; passed to W&B as the run config.
        """
        init_options = {
            "project": self.project_name,
            "group": self.group,
            "name": f"{module_name}_{num}",
            "config": module_kwargs,
        }
        self.wandb.init(**init_options)

    def log_value(self, **kwargs: dict[str, Any]) -> None:
        """
        Log data to the active W&B run.

        :param kwargs: Data to log.
        """
        self.wandb.log(kwargs)

    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """
        Log final metrics in a dedicated ``final_metrics`` W&B run.

        :param metrics: Final metrics; passed both as the run config and as logged data.
        """
        init_options = {
            "project": self.project_name,
            "group": self.group,
            "name": "final_metrics",
            "config": metrics,
        }
        self.wandb.init(**init_options)
        self.wandb.log(metrics)
        self.wandb.finish()

    def end_module(self) -> None:
        """End a module by finishing the active W&B run."""
        self.wandb.finish()

    def end_run(self) -> None:
        """End a run. Nothing to do here: W&B runs are finished per module."""

0 commit comments

Comments
 (0)