Skip to content

Commit f5afdb5

Browse files
committed
move to separate files
1 parent 694aed2 commit f5afdb5

File tree

8 files changed

+288
-265
lines changed

8 files changed

+288
-265
lines changed

autointent/_callbacks/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from autointent._callbacks.base import OptimizerCallback
2+
from autointent._callbacks.callback_handler import CallbackHandler
3+
from autointent._callbacks.tensorboard import TensorBoardCallback
4+
from autointent._callbacks.wandb import WandbCallback
5+
6+
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}
7+
8+
9+
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
10+
"""
11+
Get the list of callbacks.
12+
13+
:param reporters: List of reporters to use.
14+
:return: Callback handler.
15+
"""
16+
if not reporters:
17+
return CallbackHandler()
18+
19+
reporters_cb = []
20+
for reporter in reporters:
21+
if reporter not in REPORTERS:
22+
msg = f"Reporter {reporter} not supported. Supported reporters {','.join(REPORTERS)}"
23+
raise ValueError(msg)
24+
reporters_cb.append(REPORTERS[reporter])
25+
return CallbackHandler(callbacks=reporters_cb)
26+
27+
28+
__all__ = [
29+
"CallbackHandler",
30+
"OptimizerCallback",
31+
"TensorBoardCallback",
32+
"WandbCallback",
33+
"get_callbacks",
34+
]

autointent/_callbacks/base.py

Lines changed: 0 additions & 262 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Base class for reporters (W&B, TensorBoard, etc)."""
22

3-
import os
43
from abc import ABC, abstractmethod
54
from pathlib import Path
65
from typing import Any
@@ -58,264 +57,3 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
5857
5958
:param metrics: Final metrics.
6059
"""
61-
62-
63-
class CallbackHandler(OptimizerCallback):
64-
"""Internal class that just calls the list of callbacks in order."""
65-
66-
callbacks: list[OptimizerCallback]
67-
68-
def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
69-
"""Initialize the callback handler."""
70-
if not callbacks:
71-
self.callbacks = []
72-
return
73-
74-
self.callbacks = [cb() for cb in callbacks]
75-
76-
def start_run(self, run_name: str, dirpath: Path) -> None:
77-
"""
78-
Start a new run.
79-
80-
:param run_name: Name of the run.
81-
:param dirpath: Path to the directory where the logs will be saved.
82-
"""
83-
self.call_events("start_run", run_name=run_name, dirpath=dirpath)
84-
85-
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
86-
"""
87-
Start a new module.
88-
89-
:param module_name: Name of the module.
90-
:param num: Number of the module.
91-
:param module_kwargs: Module parameters.
92-
"""
93-
self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs)
94-
95-
def log_value(self, **kwargs: dict[str, Any]) -> None:
96-
"""
97-
Log data.
98-
99-
:param kwargs: Data to log.
100-
"""
101-
self.call_events("log_value", **kwargs)
102-
103-
def end_module(self) -> None:
104-
"""End a module."""
105-
self.call_events("end_module")
106-
107-
def end_run(self) -> None:
108-
"""End a run."""
109-
self.call_events("end_run")
110-
111-
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
112-
"""
113-
Log final metrics.
114-
115-
:param metrics: Final metrics.
116-
"""
117-
self.call_events("log_final_metrics", metrics=metrics)
118-
119-
def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401
120-
for callback in self.callbacks:
121-
getattr(callback, event)(**kwargs)
122-
123-
124-
class WandbCallback(OptimizerCallback):
125-
"""
126-
Wandb callback.
127-
128-
This callback logs the optimization process to W&B.
129-
To specify the project name, set the `WANDB_PROJECT` environment variable. Default is `autointent`.
130-
"""
131-
132-
name = "wandb"
133-
134-
def __init__(self) -> None:
135-
"""Initialize the callback."""
136-
try:
137-
import wandb
138-
except ImportError:
139-
msg = "Please install wandb to use this callback. `pip install wandb`"
140-
raise ImportError(msg) from None
141-
142-
self.wandb = wandb
143-
144-
def start_run(self, run_name: str, dirpath: Path) -> None:
145-
"""
146-
Start a new run.
147-
148-
:param run_name: Name of the run.
149-
:param dirpath: Path to the directory where the logs will be saved. (Not used for this callback)
150-
"""
151-
self.project_name = os.getenv("WANDB_PROJECT", "autointent")
152-
self.group = run_name
153-
self.dirpath = dirpath
154-
155-
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
156-
"""
157-
Start a new module.
158-
159-
:param module_name: Name of the module.
160-
:param num: Number of the module.
161-
:param module_kwargs: Module parameters.
162-
"""
163-
self.wandb.init(
164-
project=self.project_name,
165-
group=self.group,
166-
name=f"{module_name}_{num}",
167-
config=module_kwargs,
168-
)
169-
170-
def log_value(self, **kwargs: dict[str, Any]) -> None:
171-
"""
172-
Log data.
173-
174-
:param kwargs: Data to log.
175-
"""
176-
self.wandb.log(kwargs)
177-
178-
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
179-
"""
180-
Log final metrics.
181-
182-
:param metrics: Final metrics.
183-
"""
184-
self.wandb.init(
185-
project=self.project_name,
186-
group=self.group,
187-
name="final_metrics",
188-
config=metrics,
189-
)
190-
self.wandb.log(metrics)
191-
self.wandb.finish()
192-
193-
def end_module(self) -> None:
194-
"""End a module."""
195-
self.wandb.finish()
196-
197-
def end_run(self) -> None:
198-
pass
199-
200-
201-
class TensorBoardCallback(OptimizerCallback):
202-
"""
203-
TensorBoard callback.
204-
205-
This callback logs the optimization process to TensorBoard.
206-
"""
207-
208-
name = "tensorboard"
209-
210-
def __init__(self) -> None:
211-
"""Initialize the callback."""
212-
try:
213-
from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined]
214-
215-
self.writer = SummaryWriter
216-
except ImportError:
217-
try:
218-
from tensorboardX import SummaryWriter # type: ignore[no-redef]
219-
220-
self.writer = SummaryWriter
221-
except ImportError:
222-
msg = (
223-
"TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
224-
" install tensorboardX."
225-
)
226-
raise ImportError(msg) from None
227-
228-
def start_run(self, run_name: str, dirpath: Path) -> None:
229-
"""
230-
Start a new run.
231-
232-
:param run_name: Name of the run.
233-
:param dirpath: Path to the directory where the logs will be saved.
234-
"""
235-
self.run_name = run_name
236-
self.dirpath = dirpath
237-
238-
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
239-
"""
240-
Start a new module.
241-
242-
:param module_name: Name of the module.
243-
:param num: Number of the module.
244-
:param module_kwargs: Module parameters.
245-
"""
246-
module_run_name = f"{self.run_name}_{module_name}_{num}"
247-
log_dir = Path(self.dirpath) / module_run_name
248-
self.module_writer = self.writer(log_dir=log_dir) # type: ignore[no-untyped-call]
249-
250-
self.module_writer.add_text("module_info", f"Starting module {module_name}_{num}") # type: ignore[no-untyped-call]
251-
for key, value in module_kwargs.items():
252-
self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call]
253-
254-
def log_value(self, **kwargs: dict[str, Any]) -> None:
255-
"""
256-
Log data.
257-
258-
:param kwargs: Data to log.
259-
"""
260-
if self.module_writer is None:
261-
msg = "start_run must be called before log_value."
262-
raise RuntimeError(msg)
263-
264-
for key, value in kwargs.items():
265-
if isinstance(value, int | float):
266-
self.module_writer.add_scalar(key, value)
267-
else:
268-
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
269-
270-
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
271-
"""
272-
Log final metrics.
273-
274-
:param metrics: Final metrics.
275-
"""
276-
if self.module_writer is None:
277-
msg = "start_run must be called before log_final_metrics."
278-
raise RuntimeError(msg)
279-
280-
log_dir = Path(self.dirpath) / "final_metrics"
281-
self.module_writer = self.writer(log_dir=log_dir) # type: ignore[no-untyped-call]
282-
283-
for key, value in metrics.items():
284-
if isinstance(value, int | float):
285-
self.module_writer.add_scalar(key, value) # type: ignore[no-untyped-call]
286-
else:
287-
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
288-
289-
def end_module(self) -> None:
290-
"""End a module."""
291-
if self.module_writer is None:
292-
msg = "start_run must be called before end_module."
293-
raise RuntimeError(msg)
294-
295-
self.module_writer.add_text("module_info", "Ending module") # type: ignore[no-untyped-call]
296-
self.module_writer.close() # type: ignore[no-untyped-call]
297-
298-
def end_run(self) -> None:
299-
pass
300-
301-
302-
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}
303-
304-
305-
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
306-
"""
307-
Get the list of callbacks.
308-
309-
:param reporters: List of reporters to use.
310-
:return: Callback handler.
311-
"""
312-
if not reporters:
313-
return CallbackHandler()
314-
315-
reporters_cb = []
316-
for reporter in reporters:
317-
if reporter not in REPORTERS:
318-
msg = f"Reporter {reporter} not supported. Supported reporters {','.join(REPORTERS)}"
319-
raise ValueError(msg)
320-
reporters_cb.append(REPORTERS[reporter])
321-
return CallbackHandler(callbacks=reporters_cb)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from pathlib import Path
2+
from typing import Any
3+
4+
from autointent._callbacks.base import OptimizerCallback
5+
6+
7+
class CallbackHandler(OptimizerCallback):
8+
"""Internal class that just calls the list of callbacks in order."""
9+
10+
callbacks: list[OptimizerCallback]
11+
12+
def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
13+
"""Initialize the callback handler."""
14+
if not callbacks:
15+
self.callbacks = []
16+
return
17+
18+
self.callbacks = [cb() for cb in callbacks]
19+
20+
def start_run(self, run_name: str, dirpath: Path) -> None:
21+
"""
22+
Start a new run.
23+
24+
:param run_name: Name of the run.
25+
:param dirpath: Path to the directory where the logs will be saved.
26+
"""
27+
self.call_events("start_run", run_name=run_name, dirpath=dirpath)
28+
29+
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
30+
"""
31+
Start a new module.
32+
33+
:param module_name: Name of the module.
34+
:param num: Number of the module.
35+
:param module_kwargs: Module parameters.
36+
"""
37+
self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs)
38+
39+
def log_value(self, **kwargs: dict[str, Any]) -> None:
40+
"""
41+
Log data.
42+
43+
:param kwargs: Data to log.
44+
"""
45+
self.call_events("log_value", **kwargs)
46+
47+
def end_module(self) -> None:
48+
"""End a module."""
49+
self.call_events("end_module")
50+
51+
def end_run(self) -> None:
52+
"""End a run."""
53+
self.call_events("end_run")
54+
55+
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
56+
"""
57+
Log final metrics.
58+
59+
:param metrics: Final metrics.
60+
"""
61+
self.call_events("log_final_metrics", metrics=metrics)
62+
63+
def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401
64+
for callback in self.callbacks:
65+
getattr(callback, event)(**kwargs)

0 commit comments

Comments
 (0)