|
1 | 1 | """Base class for reporters (W&B, TensorBoard, etc).""" |
2 | 2 |
|
3 | | -import os |
4 | 3 | from abc import ABC, abstractmethod |
5 | 4 | from pathlib import Path |
6 | 5 | from typing import Any |
@@ -58,264 +57,3 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None: |
58 | 57 |
|
59 | 58 | :param metrics: Final metrics. |
60 | 59 | """ |
61 | | - |
62 | | - |
63 | | -class CallbackHandler(OptimizerCallback): |
64 | | - """Internal class that just calls the list of callbacks in order.""" |
65 | | - |
66 | | - callbacks: list[OptimizerCallback] |
67 | | - |
68 | | - def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None: |
69 | | - """Initialize the callback handler.""" |
70 | | - if not callbacks: |
71 | | - self.callbacks = [] |
72 | | - return |
73 | | - |
74 | | - self.callbacks = [cb() for cb in callbacks] |
75 | | - |
76 | | - def start_run(self, run_name: str, dirpath: Path) -> None: |
77 | | - """ |
78 | | - Start a new run. |
79 | | -
|
80 | | - :param run_name: Name of the run. |
81 | | - :param dirpath: Path to the directory where the logs will be saved. |
82 | | - """ |
83 | | - self.call_events("start_run", run_name=run_name, dirpath=dirpath) |
84 | | - |
85 | | - def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: |
86 | | - """ |
87 | | - Start a new module. |
88 | | -
|
89 | | - :param module_name: Name of the module. |
90 | | - :param num: Number of the module. |
91 | | - :param module_kwargs: Module parameters. |
92 | | - """ |
93 | | - self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs) |
94 | | - |
95 | | - def log_value(self, **kwargs: dict[str, Any]) -> None: |
96 | | - """ |
97 | | - Log data. |
98 | | -
|
99 | | - :param kwargs: Data to log. |
100 | | - """ |
101 | | - self.call_events("log_value", **kwargs) |
102 | | - |
103 | | - def end_module(self) -> None: |
104 | | - """End a module.""" |
105 | | - self.call_events("end_module") |
106 | | - |
107 | | - def end_run(self) -> None: |
108 | | - """End a run.""" |
109 | | - self.call_events("end_run") |
110 | | - |
111 | | - def log_final_metrics(self, metrics: dict[str, Any]) -> None: |
112 | | - """ |
113 | | - Log final metrics. |
114 | | -
|
115 | | - :param metrics: Final metrics. |
116 | | - """ |
117 | | - self.call_events("log_final_metrics", metrics=metrics) |
118 | | - |
119 | | - def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401 |
120 | | - for callback in self.callbacks: |
121 | | - getattr(callback, event)(**kwargs) |
122 | | - |
123 | | - |
124 | | -class WandbCallback(OptimizerCallback): |
125 | | - """ |
126 | | - Wandb callback. |
127 | | -
|
128 | | - This callback logs the optimization process to W&B. |
129 | | - To specify the project name, set the `WANDB_PROJECT` environment variable. Default is `autointent`. |
130 | | - """ |
131 | | - |
132 | | - name = "wandb" |
133 | | - |
134 | | - def __init__(self) -> None: |
135 | | - """Initialize the callback.""" |
136 | | - try: |
137 | | - import wandb |
138 | | - except ImportError: |
139 | | - msg = "Please install wandb to use this callback. `pip install wandb`" |
140 | | - raise ImportError(msg) from None |
141 | | - |
142 | | - self.wandb = wandb |
143 | | - |
144 | | - def start_run(self, run_name: str, dirpath: Path) -> None: |
145 | | - """ |
146 | | - Start a new run. |
147 | | -
|
148 | | - :param run_name: Name of the run. |
149 | | - :param dirpath: Path to the directory where the logs will be saved. (Not used for this callback) |
150 | | - """ |
151 | | - self.project_name = os.getenv("WANDB_PROJECT", "autointent") |
152 | | - self.group = run_name |
153 | | - self.dirpath = dirpath |
154 | | - |
155 | | - def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: |
156 | | - """ |
157 | | - Start a new module. |
158 | | -
|
159 | | - :param module_name: Name of the module. |
160 | | - :param num: Number of the module. |
161 | | - :param module_kwargs: Module parameters. |
162 | | - """ |
163 | | - self.wandb.init( |
164 | | - project=self.project_name, |
165 | | - group=self.group, |
166 | | - name=f"{module_name}_{num}", |
167 | | - config=module_kwargs, |
168 | | - ) |
169 | | - |
170 | | - def log_value(self, **kwargs: dict[str, Any]) -> None: |
171 | | - """ |
172 | | - Log data. |
173 | | -
|
174 | | - :param kwargs: Data to log. |
175 | | - """ |
176 | | - self.wandb.log(kwargs) |
177 | | - |
178 | | - def log_final_metrics(self, metrics: dict[str, Any]) -> None: |
179 | | - """ |
180 | | - Log final metrics. |
181 | | -
|
182 | | - :param metrics: Final metrics. |
183 | | - """ |
184 | | - self.wandb.init( |
185 | | - project=self.project_name, |
186 | | - group=self.group, |
187 | | - name="final_metrics", |
188 | | - config=metrics, |
189 | | - ) |
190 | | - self.wandb.log(metrics) |
191 | | - self.wandb.finish() |
192 | | - |
193 | | - def end_module(self) -> None: |
194 | | - """End a module.""" |
195 | | - self.wandb.finish() |
196 | | - |
197 | | - def end_run(self) -> None: |
198 | | - pass |
199 | | - |
200 | | - |
201 | | -class TensorBoardCallback(OptimizerCallback): |
202 | | - """ |
203 | | - TensorBoard callback. |
204 | | -
|
205 | | - This callback logs the optimization process to TensorBoard. |
206 | | - """ |
207 | | - |
208 | | - name = "tensorboard" |
209 | | - |
210 | | - def __init__(self) -> None: |
211 | | - """Initialize the callback.""" |
212 | | - try: |
213 | | - from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined] |
214 | | - |
215 | | - self.writer = SummaryWriter |
216 | | - except ImportError: |
217 | | - try: |
218 | | - from tensorboardX import SummaryWriter # type: ignore[no-redef] |
219 | | - |
220 | | - self.writer = SummaryWriter |
221 | | - except ImportError: |
222 | | - msg = ( |
223 | | - "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or" |
224 | | - " install tensorboardX." |
225 | | - ) |
226 | | - raise ImportError(msg) from None |
227 | | - |
228 | | - def start_run(self, run_name: str, dirpath: Path) -> None: |
229 | | - """ |
230 | | - Start a new run. |
231 | | -
|
232 | | - :param run_name: Name of the run. |
233 | | - :param dirpath: Path to the directory where the logs will be saved. |
234 | | - """ |
235 | | - self.run_name = run_name |
236 | | - self.dirpath = dirpath |
237 | | - |
238 | | - def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: |
239 | | - """ |
240 | | - Start a new module. |
241 | | -
|
242 | | - :param module_name: Name of the module. |
243 | | - :param num: Number of the module. |
244 | | - :param module_kwargs: Module parameters. |
245 | | - """ |
246 | | - module_run_name = f"{self.run_name}_{module_name}_{num}" |
247 | | - log_dir = Path(self.dirpath) / module_run_name |
248 | | - self.module_writer = self.writer(log_dir=log_dir) # type: ignore[no-untyped-call] |
249 | | - |
250 | | - self.module_writer.add_text("module_info", f"Starting module {module_name}_{num}") # type: ignore[no-untyped-call] |
251 | | - for key, value in module_kwargs.items(): |
252 | | - self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call] |
253 | | - |
254 | | - def log_value(self, **kwargs: dict[str, Any]) -> None: |
255 | | - """ |
256 | | - Log data. |
257 | | -
|
258 | | - :param kwargs: Data to log. |
259 | | - """ |
260 | | - if self.module_writer is None: |
261 | | - msg = "start_run must be called before log_value." |
262 | | - raise RuntimeError(msg) |
263 | | - |
264 | | - for key, value in kwargs.items(): |
265 | | - if isinstance(value, int | float): |
266 | | - self.module_writer.add_scalar(key, value) |
267 | | - else: |
268 | | - self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call] |
269 | | - |
270 | | - def log_final_metrics(self, metrics: dict[str, Any]) -> None: |
271 | | - """ |
272 | | - Log final metrics. |
273 | | -
|
274 | | - :param metrics: Final metrics. |
275 | | - """ |
276 | | - if self.module_writer is None: |
277 | | - msg = "start_run must be called before log_final_metrics." |
278 | | - raise RuntimeError(msg) |
279 | | - |
280 | | - log_dir = Path(self.dirpath) / "final_metrics" |
281 | | - self.module_writer = self.writer(log_dir=log_dir) # type: ignore[no-untyped-call] |
282 | | - |
283 | | - for key, value in metrics.items(): |
284 | | - if isinstance(value, int | float): |
285 | | - self.module_writer.add_scalar(key, value) # type: ignore[no-untyped-call] |
286 | | - else: |
287 | | - self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call] |
288 | | - |
289 | | - def end_module(self) -> None: |
290 | | - """End a module.""" |
291 | | - if self.module_writer is None: |
292 | | - msg = "start_run must be called before end_module." |
293 | | - raise RuntimeError(msg) |
294 | | - |
295 | | - self.module_writer.add_text("module_info", "Ending module") # type: ignore[no-untyped-call] |
296 | | - self.module_writer.close() # type: ignore[no-untyped-call] |
297 | | - |
298 | | - def end_run(self) -> None: |
299 | | - pass |
300 | | - |
301 | | - |
302 | | -REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]} |
303 | | - |
304 | | - |
305 | | -def get_callbacks(reporters: list[str] | None) -> CallbackHandler: |
306 | | - """ |
307 | | - Get the list of callbacks. |
308 | | -
|
309 | | - :param reporters: List of reporters to use. |
310 | | - :return: Callback handler. |
311 | | - """ |
312 | | - if not reporters: |
313 | | - return CallbackHandler() |
314 | | - |
315 | | - reporters_cb = [] |
316 | | - for reporter in reporters: |
317 | | - if reporter not in REPORTERS: |
318 | | - msg = f"Reporter {reporter} not supported. Supported reporters {','.join(REPORTERS)}" |
319 | | - raise ValueError(msg) |
320 | | - reporters_cb.append(REPORTERS[reporter]) |
321 | | - return CallbackHandler(callbacks=reporters_cb) |
0 commit comments