diff --git a/autointent/_callbacks/__init__.py b/autointent/_callbacks/__init__.py index 1ad7437d7..200556eda 100644 --- a/autointent/_callbacks/__init__.py +++ b/autointent/_callbacks/__init__.py @@ -1,3 +1,5 @@ +from typing import Literal + from autointent._callbacks.base import OptimizerCallback from autointent._callbacks.callback_handler import CallbackHandler from autointent._callbacks.tensorboard import TensorBoardCallback @@ -5,6 +7,8 @@ REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]} +REPORTERS_NAMES = list(REPORTERS.keys()) + def get_callbacks(reporters: list[str] | None) -> CallbackHandler: """ @@ -26,6 +30,7 @@ def get_callbacks(reporters: list[str] | None) -> CallbackHandler: __all__ = [ + "REPORTERS_NAMES", "CallbackHandler", "OptimizerCallback", "TensorBoardCallback", diff --git a/autointent/_callbacks/wandb.py b/autointent/_callbacks/wandb.py index b89c7c9ee..cf38c530c 100644 --- a/autointent/_callbacks/wandb.py +++ b/autointent/_callbacks/wandb.py @@ -79,7 +79,8 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None: name="final_metrics", config=metrics, ) - self.wandb.log(metrics) + + self.wandb.log(metrics.get("pipeline_metrics", {})) self.wandb.finish() def end_module(self) -> None: diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index c98fcb852..478d4435e 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -170,7 +170,7 @@ def fit( context.data_handler.test_labels(), predictions, ) - context.callback_handler.log_final_metrics(context.optimization_info.pipeline_metrics) + context.callback_handler.log_final_metrics(context.optimization_info.dump_evaluation_results()) return context diff --git a/autointent/configs/_optimization.py b/autointent/configs/_optimization.py index ba3d8093c..d174da827 100644 --- a/autointent/configs/_optimization.py +++ b/autointent/configs/_optimization.py @@ -2,8 +2,9 @@ from pathlib import Path -from pydantic import BaseModel, Field, PositiveInt +from pydantic import BaseModel, Field, PositiveInt, field_validator +from autointent._callbacks import REPORTERS_NAMES from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme from ._name import get_run_name @@ -58,6 +59,18 @@ def dump_dir(self) -> Path: self._dump_dir = self.dirpath / "modules_dumps" return self._dump_dir + @field_validator("report_to") + @classmethod + def validate_report_to(cls, v: list[str] | None) -> list[str] | None: + """Validate the report_to field.""" + if v is None: + return None + for reporter in v: + if reporter not in REPORTERS_NAMES: + msg = f"Reporter {reporter} is not supported. Supported reporters: {REPORTERS_NAMES}" + raise ValueError(msg) + return v + class VectorIndexConfig(BaseModel): """Configuration for the vector index."""