Skip to content

Commit 4d87a34

Browse files
committed
init
1 parent 0c1a8f9 commit 4d87a34

File tree

20 files changed

+556
-20
lines changed

20 files changed

+556
-20
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,5 @@ poetry.lock
177177
indexes_dirnames.json
178178
tests_logs
179179
tests/logs
180+
runs/
181+
vector_db*

autointent/_callbacks/__init__.py

Whitespace-only changes.

autointent/_callbacks/base.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""Base class for reporters (W&B, TensorBoard, etc)."""
2+
3+
import os
4+
from abc import ABC, abstractmethod
5+
from pathlib import Path
6+
from typing import Any
7+
8+
9+
class OptimizerCallback(ABC):
    """
    Base class for reporters (W&B, TensorBoard, etc).

    Subclasses implement hooks for the start and end of a run, for the start and
    end of each module, and for logging values.
    """

    # Implementation inspired by TrainerCallback from HuggingFace Transformers. https://github.com/huggingface/transformers/blob/91b8ab18b778ae9e2f8191866e018cd1dc7097be/src/transformers/trainer_callback.py#L260
    # Short identifier of the reporter; concrete reporters are registered under it in REPORTERS.
    name: str

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the callback."""

    @abstractmethod
    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """

    @abstractmethod
    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        """

    @abstractmethod
    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data.

        :param kwargs: Data to log, passed as ``name=value`` keyword arguments. Values may be of any type.
        """

    @abstractmethod
    def end_module(self) -> None:
        """End a module."""

    @abstractmethod
    def end_run(self) -> None:
        """End a run."""
53+
54+
55+
class CallbackHandler(OptimizerCallback):
    """Internal class that just calls the list of callbacks in order."""

    # Instantiated callbacks, in the order in which their classes were passed to __init__.
    callbacks: list[OptimizerCallback]

    def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
        """
        Initialize the callback handler.

        :param callbacks: Callback classes to instantiate; each one is called with no arguments.
            If None or empty, the handler holds no callbacks and every event is a no-op.
        """
        self.callbacks = [callback_cls() for callback_cls in callbacks or []]

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """
        self.call_events("start_run", run_name=run_name, dirpath=dirpath)

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        """
        self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs)

    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data.

        :param kwargs: Data to log; forwarded unchanged to every callback.
        """
        self.call_events("log_value", **kwargs)

    def end_module(self) -> None:
        """End a module."""
        self.call_events("end_module")

    def end_run(self) -> None:
        """End a run."""
        self.call_events("end_run")

    def call_events(self, event: str, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Call the method named ``event`` on every registered callback, in order.

        :param event: Name of the callback method to call (e.g. ``"start_run"``).
        :param kwargs: Keyword arguments passed through to that method.
        """
        for callback in self.callbacks:
            getattr(callback, event)(**kwargs)
105+
106+
107+
class WandbCallback(OptimizerCallback):
    """
    Wandb callback.

    This callback logs the optimization process to W&B.
    To specify the project name, set the `WANDB_PROJECT` environment variable. Default is `autointent`.
    """

    name = "wandb"

    def __init__(self) -> None:
        """
        Initialize the callback.

        :raises ImportError: If the ``wandb`` package is not installed.
        """
        # Imported lazily so that wandb is only required when this reporter is actually used.
        try:
            import wandb
        except ImportError:
            msg = "Please install wandb to use this callback. `pip install wandb`"
            raise ImportError(msg) from None

        self.wandb = wandb

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        No W&B run is opened here; the values are only remembered for ``start_module``.

        :param run_name: Name of the run. Used as the W&B group for the runs of all modules.
        :param dirpath: Path to the directory where the logs will be saved.
            (Stored, but not otherwise used by this callback)
        """
        self.project_name = os.getenv("WANDB_PROJECT", "autointent")
        self.group = run_name
        self.dirpath = dirpath

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module.

        Calls ``wandb.init`` with the run named ``{module_name}_{num}``.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters. Passed to W&B as the run config.
        """
        self.wandb.init(
            project=self.project_name,
            group=self.group,
            name=f"{module_name}_{num}",
            config=module_kwargs,
        )

    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data.

        :param kwargs: Data to log. Passed to ``wandb.log`` as a single dictionary.
        """
        self.wandb.log(kwargs)

    def end_module(self) -> None:
        """End a module by calling ``wandb.finish``."""
        self.wandb.finish()

    def end_run(self) -> None:
        """End a run. Nothing to do: ``end_module`` already calls ``wandb.finish``."""
167+
168+
169+
class TensorBoardCallback(OptimizerCallback):
    """
    TensorBoard callback.

    This callback logs the optimization process to TensorBoard.
    A separate writer (with its own log directory) is created for every module.
    """

    name = "tensorboard"

    def __init__(self) -> None:
        """
        Initialize the callback.

        :raises ImportError: If neither ``torch.utils.tensorboard`` nor ``tensorboardX`` can be imported.
        """
        # No writer exists until start_module is called. Initialising the attribute here makes the
        # guards in log_value / end_module raise the intended RuntimeError instead of AttributeError.
        self.module_writer: Any = None

        # Prefer the writer bundled with PyTorch and fall back to tensorboardX.
        try:
            from torch.utils.tensorboard import SummaryWriter  # type: ignore[attr-defined]

            self.writer = SummaryWriter
        except ImportError:
            try:
                from tensorboardX import SummaryWriter  # type: ignore[no-redef]

                self.writer = SummaryWriter
            except ImportError:
                msg = (
                    "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                    " install tensorboardX."
                )
                raise ImportError(msg) from None

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        No writer is created here; the values are only remembered for ``start_module``.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """
        self.run_name = run_name
        self.dirpath = dirpath

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module.

        Creates a writer logging to ``dirpath / "{run_name}_{module_name}_{num}"`` and records the
        module parameters as text entries.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        """
        module_run_name = f"{self.run_name}_{module_name}_{num}"
        log_dir = Path(self.dirpath) / module_run_name
        self.module_writer = self.writer(log_dir=log_dir)  # type: ignore[no-untyped-call]

        self.module_writer.add_text("module_info", f"Starting module {module_name}_{num}")  # type: ignore[no-untyped-call]
        for key, value in module_kwargs.items():
            self.module_writer.add_text(f"module_params/{key}", str(value))  # type: ignore[no-untyped-call]

    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data.

        ``int`` and ``float`` values (which includes ``bool``) are logged as scalars; any other value
        is converted with ``str`` and logged as text.

        :param kwargs: Data to log.
        :raises RuntimeError: If no module writer exists because ``start_module`` was not called.
        """
        if self.module_writer is None:
            # The writer is created in start_module, not start_run.
            msg = "start_module must be called before log_value."
            raise RuntimeError(msg)

        for key, value in kwargs.items():
            if isinstance(value, int | float):
                self.module_writer.add_scalar(key, value)
            else:
                self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]

    def end_module(self) -> None:
        """
        End a module.

        Writes a closing text entry and closes the module writer.

        :raises RuntimeError: If no module writer exists because ``start_module`` was not called.
        """
        if self.module_writer is None:
            msg = "start_module must be called before end_module."
            raise RuntimeError(msg)

        self.module_writer.add_text("module_info", "Ending module")  # type: ignore[no-untyped-call]
        self.module_writer.close()  # type: ignore[no-untyped-call]

    def end_run(self) -> None:
        """End a run. Nothing to do: each module writer is closed in ``end_module``."""
249+
250+
251+
# Registry of the available reporter classes, keyed by each class's ``name`` attribute.
REPORTERS = {reporter_cls.name: reporter_cls for reporter_cls in (WandbCallback, TensorBoardCallback)}
252+
253+
254+
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
    """
    Build a callback handler for the requested reporters.

    :param reporters: Names of the reporters to use (keys of ``REPORTERS``). None or an empty list
        yields a handler without callbacks.
    :return: Callback handler wrapping one callback per requested reporter, in the given order.
    :raises ValueError: If a name is not present in ``REPORTERS``.
    """
    if not reporters:
        return CallbackHandler()

    selected = []
    for reporter in reporters:
        try:
            selected.append(REPORTERS[reporter])
        except KeyError:
            # Unknown name: report it together with the supported ones.
            msg = f"Reporter {reporter} not supported. Supported reporters {','.join(REPORTERS)}"
            raise ValueError(msg) from None
    return CallbackHandler(callbacks=selected)

autointent/_pipeline/_pipeline.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,19 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str) -> "
6262
"""
6363
Create pipeline optimizer from dictionary search space.
6464
65-
:param config: Dictionary config
65+
:param search_space: Dictionary config
6666
"""
6767
if isinstance(search_space, Path | str):
6868
search_space = load_search_space(search_space)
69-
if isinstance(search_space, list):
70-
nodes = [NodeOptimizer(**node) for node in search_space]
69+
nodes = [NodeOptimizer(**node) for node in search_space]
7170
return cls(nodes)
7271

7372
@classmethod
7473
def default_optimizer(cls, multilabel: bool) -> "Pipeline":
7574
"""
7675
Create pipeline optimizer with default search space for given classification task.
7776
78-
:param multilabel: Wether the task multi-label, or single-label.
77+
:param multilabel: Whether the task multi-label, or single-label.
7978
"""
8079
return cls.from_search_space(load_default_search_space(multilabel))
8180

@@ -87,13 +86,19 @@ def _fit(self, context: Context) -> None:
8786
"""
8887
self.context = context
8988
self._logger.info("starting pipeline optimization...")
89+
# TODO what's difference between self.context.logging_config and self.logging_config
90+
self.context.callback_handler.start_run(
91+
run_name=self.context.logging_config.get_run_name(),
92+
dirpath=self.context.logging_config.get_dirpath(),
93+
)
9094
for node_type in NodeType:
9195
node_optimizer = self.nodes.get(node_type, None)
9296
if node_optimizer is not None:
9397
node_optimizer.fit(context) # type: ignore[union-attr]
9498
if not context.vector_index_config.save_db:
9599
self._logger.info("removing vector database from file system...")
96100
context.vector_index_client.delete_db()
101+
self.context.callback_handler.end_run()
97102

98103
def _is_inference(self) -> bool:
99104
"""
@@ -109,6 +114,7 @@ def fit(self, dataset: Dataset, force_multilabel: bool = False, init_for_inferen
109114
110115
:param dataset: Dataset for optimization
111116
:param force_multilabel: Whether to force multilabel or not
117+
:param init_for_inference: Whether to initialize pipeline for inference
112118
:return: Context
113119
"""
114120
if self._is_inference():

autointent/configs/_optimization_cli.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class LoggingConfig:
4545
"""Whether to dump the modules or not"""
4646
clear_ram: bool = False
4747
"""Whether to clear the RAM after dumping the modules"""
48+
report_to: list[str] | None = None
49+
"""List of callbacks to report to. If None, no callbacks will be used"""
4850

4951
def __post_init__(self) -> None:
5052
"""Define the run name, directory path and dump directory."""
@@ -63,6 +65,18 @@ def define_dirpath(self) -> None:
6365
raise ValueError
6466
self.dirpath = dirpath / self.run_name
6567

68+
def get_dirpath(self) -> Path:
    """
    Get the directory path.

    :return: The value of ``self.dirpath``.
    :raises ValueError: If ``self.dirpath`` is None.
    """
    if self.dirpath is None:
        msg = "dirpath is not defined"
        raise ValueError(msg)
    return self.dirpath
73+
74+
def get_run_name(self) -> str:
    """
    Get the run name.

    :return: The value of ``self.run_name``.
    :raises ValueError: If ``self.run_name`` is None.
    """
    if self.run_name is None:
        msg = "run_name is not defined"
        raise ValueError(msg)
    return self.run_name
79+
6680
def define_dump_dir(self) -> None:
6781
"""Define the dump directory. If None, the modules will not be dumped."""
6882
if self.dump_dir is None:

autointent/context/_context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import yaml
99

1010
from autointent import Dataset
11+
from autointent._callbacks.base import CallbackHandler, get_callbacks
1112
from autointent.configs import (
1213
DataConfig,
1314
EmbedderConfig,
@@ -32,6 +33,7 @@ class Context:
3233
data_handler: DataHandler
3334
vector_index_client: VectorIndexClient
3435
optimization_info: OptimizationInfo
36+
callback_handler = CallbackHandler()
3537

3638
def __init__(self, seed: int = 42) -> None:
3739
"""
@@ -49,6 +51,7 @@ def configure_logging(self, config: LoggingConfig) -> None:
4951
:param config: Logging configuration settings.
5052
"""
5153
self.logging_config = config
54+
self.callback_handler = get_callbacks(config.report_to)
5255
self.optimization_info = OptimizationInfo()
5356

5457
def configure_vector_index(self, config: VectorIndexConfig, embedder_config: EmbedderConfig | None = None) -> None:

0 commit comments

Comments
 (0)