Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions autointent/_callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Literal

from autointent._callbacks.base import OptimizerCallback
from autointent._callbacks.callback_handler import CallbackHandler
from autointent._callbacks.tensorboard import TensorBoardCallback
from autointent._callbacks.wandb import WandbCallback

REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}

REPORTERS_NAMES = list(REPORTERS.keys())


def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
"""
Expand All @@ -26,6 +30,7 @@ def get_callbacks(reporters: list[str] | None) -> CallbackHandler:


__all__ = [
"REPORTERS_NAMES",
"CallbackHandler",
"OptimizerCallback",
"TensorBoardCallback",
Expand Down
3 changes: 2 additions & 1 deletion autointent/_callbacks/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
name="final_metrics",
config=metrics,
)
self.wandb.log(metrics)

self.wandb.log(metrics.get("pipeline_metrics", {}))
self.wandb.finish()

def end_module(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def fit(
context.data_handler.test_labels(),
predictions,
)
context.callback_handler.log_final_metrics(context.optimization_info.pipeline_metrics)
context.callback_handler.log_final_metrics(context.optimization_info.dump_evaluation_results())

return context

Expand Down
15 changes: 14 additions & 1 deletion autointent/configs/_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from pathlib import Path

from pydantic import BaseModel, Field, PositiveInt
from pydantic import BaseModel, Field, PositiveInt, field_validator

from autointent._callbacks import REPORTERS_NAMES
from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme

from ._name import get_run_name
Expand Down Expand Up @@ -58,6 +59,18 @@ def dump_dir(self) -> Path:
self._dump_dir = self.dirpath / "modules_dumps"
return self._dump_dir

@field_validator("report_to")
@classmethod
def validate_report_to(cls, v: list[str] | None) -> list[str] | None:
"""Validate the report_to field."""
if v is None:
return None
for reporter in v:
if reporter not in REPORTERS_NAMES:
msg = f"Reporter {reporter} is not supported. Supported reporters: {REPORTERS_NAMES}"
raise ValueError(msg)
return v


class VectorIndexConfig(BaseModel):
"""Configuration for the vector index."""
Expand Down