Skip to content

Commit cd6167a

Browse files
authored
log all metrics in callback (#142)
* log all metrics * log only essential metrics
1 parent 8fd616d commit cd6167a

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

autointent/_callbacks/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
from typing import Literal
2+
13
from autointent._callbacks.base import OptimizerCallback
24
from autointent._callbacks.callback_handler import CallbackHandler
35
from autointent._callbacks.tensorboard import TensorBoardCallback
46
from autointent._callbacks.wandb import WandbCallback
57

68
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}
79

10+
REPORTERS_NAMES = list(REPORTERS.keys())
11+
812

913
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
1014
"""
@@ -26,6 +30,7 @@ def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
2630

2731

2832
__all__ = [
33+
"REPORTERS_NAMES",
2934
"CallbackHandler",
3035
"OptimizerCallback",
3136
"TensorBoardCallback",

autointent/_callbacks/wandb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
7979
name="final_metrics",
8080
config=metrics,
8181
)
82-
self.wandb.log(metrics)
82+
83+
self.wandb.log(metrics.get("pipeline_metrics", {}))
8384
self.wandb.finish()
8485

8586
def end_module(self) -> None:

autointent/_pipeline/_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def fit(
170170
context.data_handler.test_labels(),
171171
predictions,
172172
)
173-
context.callback_handler.log_final_metrics(context.optimization_info.pipeline_metrics)
173+
context.callback_handler.log_final_metrics(context.optimization_info.dump_evaluation_results())
174174

175175
return context
176176

autointent/configs/_optimization.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
from pathlib import Path
44

5-
from pydantic import BaseModel, Field, PositiveInt
5+
from pydantic import BaseModel, Field, PositiveInt, field_validator
66

7+
from autointent._callbacks import REPORTERS_NAMES
78
from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme
89

910
from ._name import get_run_name
@@ -58,6 +59,18 @@ def dump_dir(self) -> Path:
5859
self._dump_dir = self.dirpath / "modules_dumps"
5960
return self._dump_dir
6061

62+
@field_validator("report_to")
63+
@classmethod
64+
def validate_report_to(cls, v: list[str] | None) -> list[str] | None:
65+
"""Validate the report_to field."""
66+
if v is None:
67+
return None
68+
for reporter in v:
69+
if reporter not in REPORTERS_NAMES:
70+
msg = f"Reporter {reporter} is not supported. Supported reporters: {REPORTERS_NAMES}"
71+
raise ValueError(msg)
72+
return v
73+
6174

6275
class VectorIndexConfig(BaseModel):
6376
"""Configuration for the vector index."""

0 commit comments

Comments
 (0)