Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions autointent/_optimization_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Any

from pydantic import BaseModel, PositiveInt

from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
from .custom_types import SamplerType
from .nodes.schemes import OptimizationSearchSpaceConfig


class OptimizationConfig(BaseModel):
"""Configuration for the optimization process."""

data_config: DataConfig = DataConfig()
search_space: OptimizationSearchSpaceConfig
search_space: list[dict[str, Any]]
logging_config: LoggingConfig = LoggingConfig()
embedder_config: EmbedderConfig = EmbedderConfig()
cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()
Expand Down
2 changes: 1 addition & 1 deletion autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
optimization_config = OptimizationConfig(**dict_params)

pipeline = cls(
[NodeOptimizer(**node.model_dump()) for node in optimization_config.search_space],
[NodeOptimizer(**node) for node in optimization_config.search_space],
optimization_config.sampler,
optimization_config.seed,
)
Expand Down
2 changes: 1 addition & 1 deletion autointent/_presets/light_extra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ search_space:
thresh:
low: 0.1
high: 0.9
n_trials: 10
n_trials: 10
- module_name: argmax
sampler: random
4 changes: 4 additions & 0 deletions autointent/modules/decision/_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def __init__(self, search_space: list[FloatFromZeroToOne] | None = None) -> None
"""
self.search_space = search_space if search_space is not None else default_search_space

if any(val < 0 or val > 1 for val in self.search_space):
msg = "Unsupported items in `search_space` arg of `AdaptiveDecision` module"
raise ValueError(msg)

@classmethod
def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "AdaptiveDecision":
"""
Expand Down
4 changes: 4 additions & 0 deletions autointent/modules/decision/_jinoos.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def __init__(
"""
self.search_space = np.array(search_space) if search_space is not None else default_search_space

if any(val < 0 or val > 1 for val in self.search_space):
msg = "Items of `search_space` of `JinoosDecision` module must be floats from zero to one"
raise ValueError(msg)

@classmethod
def from_context(cls, context: Context, search_space: list[FloatFromZeroToOne] | None = None) -> "JinoosDecision":
"""
Expand Down
9 changes: 9 additions & 0 deletions autointent/modules/decision/_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,16 @@ def __init__(

:param thresh: Threshold for the scores, shape (n_classes,) or float
"""
val_error = False
self.thresh = thresh if isinstance(thresh, float) else np.array(thresh)
if isinstance(thresh, float):
val_error = val_error or thresh < 0 or thresh > 1
else:
val_error = val_error or any(val < 0 or val > 1 for val in thresh)

if val_error:
msg = "`thresh` arg of `ThresholdDecision` must contain a float from zero to one (or list of floats)."
raise ValueError(msg)

@classmethod
def from_context(
Expand Down
10 changes: 9 additions & 1 deletion autointent/modules/decision/_tunable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tunable predictor module."""

from typing import Any, Literal
from typing import Any, Literal, get_args

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -96,6 +96,14 @@ def __init__(
self.seed = seed
self.tags = tags

if self.n_optuna_trials < 0 or not isinstance(self.n_optuna_trials, int):
msg = "Unsupported value for `n_optuna_trials` of `TunableDecision` module"
raise ValueError(msg)

if self.target_metric not in get_args(MetricType):
msg = "Unsupported value for `target_metric` of `TunableDecision` module"
raise TypeError(msg)

@classmethod
def from_context(
cls, context: Context, target_metric: MetricType = "decision_accuracy", n_optuna_trials: PositiveInt = 320
Expand Down
4 changes: 4 additions & 0 deletions autointent/modules/embedding/_logreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def __init__(
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
self.cv = cv

if self.cv < 0 or not isinstance(self.cv, int):
msg = "`cv` argument of `LogregAimedEmbedding` must be a positive int"
raise ValueError(msg)

@classmethod
def from_context(
cls,
Expand Down
13 changes: 7 additions & 6 deletions autointent/modules/embedding/_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class RetrievalAimedEmbedding(BaseEmbedding):

def __init__(
self,
k: PositiveInt,
embedder_config: EmbedderConfig | str | dict[str, Any],
k: PositiveInt = 10,
) -> None:
"""
Initialize the RetrievalAimedEmbedding.
Expand All @@ -56,18 +56,19 @@ def __init__(
:param embedder_config: Config of the embedder used for creating embeddings.
"""
self.k = k
if isinstance(embedder_config, dict):
embedder_config = EmbedderConfig(**embedder_config)
if isinstance(embedder_config, str):
embedder_config = EmbedderConfig(model_name=embedder_config)
embedder_config = EmbedderConfig.from_search_config(embedder_config)
self.embedder_config = embedder_config

if self.k < 0 or not isinstance(self.k, int):
msg = "`k` argument of `RetrievalAimedEmbedding` must be a positive int"
raise ValueError(msg)

@classmethod
def from_context(
cls,
context: Context,
k: PositiveInt,
embedder_config: EmbedderConfig | str,
k: PositiveInt = 10,
) -> "RetrievalAimedEmbedding":
"""
Create an instance using a Context object.
Expand Down
4 changes: 4 additions & 0 deletions autointent/modules/scoring/_description/description.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(
self.temperature = temperature
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)

if self.temperature < 0 or not isinstance(self.temperature, float | int):
msg = "`temperature` argument of `DescriptionScorer` must be a positive float"
raise ValueError(msg)

@classmethod
def from_context(
cls,
Expand Down
4 changes: 4 additions & 0 deletions autointent/modules/scoring/_dnnc/dnnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def __init__(
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
self.k = k

if self.k < 0 or not isinstance(self.k, int):
msg = "`k` argument of `DNNCScorer` must be a positive int"
raise ValueError(msg)

@classmethod
def from_context(
cls,
Expand Down
10 changes: 9 additions & 1 deletion autointent/modules/scoring/_knn/knn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""KNNScorer class for k-nearest neighbors scoring."""

from typing import Any
from typing import Any, get_args

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -76,6 +76,14 @@ def __init__(
self.k = k
self.weights = weights

if self.k < 0 or not isinstance(self.k, int):
msg = "`k` argument of `KNNScorer` must be a positive int"
raise ValueError(msg)

if weights not in get_args(WEIGHT_TYPES):
msg = f"`weights` argument of `KNNScorer` must be a literal from a list: {get_args(WEIGHT_TYPES)}"
raise TypeError(msg)

@classmethod
def from_context(
cls,
Expand Down
10 changes: 10 additions & 0 deletions autointent/modules/scoring/_knn/rerank_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ def __init__(
self.m = k if m is None else m
self.rank_threshold_cutoff = rank_threshold_cutoff

if self.m < 0 or not isinstance(self.m, int):
msg = "`m` argument of `RerankScorer` must be a positive int"
raise ValueError(msg)

if self.rank_threshold_cutoff is not None and (
self.rank_threshold_cutoff < 0 or not isinstance(self.rank_threshold_cutoff, int)
):
msg = "`rank_threshold_cutoff` argument of `RerankScorer` must be a positive int or None"
raise ValueError(msg)

@classmethod
def from_context(
cls,
Expand Down
8 changes: 5 additions & 3 deletions autointent/modules/scoring/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(
self,
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
cv: int = 3,
n_jobs: int | None = None,
seed: int = 0,
) -> None:
"""
Expand All @@ -67,10 +66,13 @@ def __init__(
:param seed: Random seed for reproducibility, defaults to 0.
"""
self.cv = cv
self.n_jobs = n_jobs
self.seed = seed
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)

if self.cv < 0 or not isinstance(self.cv, int):
msg = "`cv` argument of `LinearScorer` must be a positive int"
raise ValueError(msg)

@classmethod
def from_context(
cls,
Expand Down Expand Up @@ -125,7 +127,7 @@ def fit(
base_clf = LogisticRegression()
clf = MultiOutputClassifier(base_clf)
else:
clf = LogisticRegressionCV(cv=self.cv, n_jobs=self.n_jobs, random_state=self.seed)
clf = LogisticRegressionCV(cv=self.cv, random_state=self.seed)

clf.fit(features, labels)

Expand Down
8 changes: 8 additions & 0 deletions autointent/modules/scoring/_mlknn/mlknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
from numpy.typing import NDArray
from pydantic import NonNegativeInt, PositiveFloat, PositiveInt
from typing_extensions import assert_never

from autointent import Context, VectorIndex
from autointent.configs import EmbedderConfig
Expand Down Expand Up @@ -77,6 +78,13 @@ def __init__(
self.s = s
self.ignore_first_neighbours = ignore_first_neighbours

if self.k < 0 or not isinstance(self.k, int):
msg = "`k` argument of `MLKnnScorer` must be a positive int"
raise ValueError(msg)

if not isinstance(self.s, float | int):
assert_never(self.s)

@classmethod
def from_context(
cls,
Expand Down
5 changes: 3 additions & 2 deletions autointent/modules/scoring/_sklearn/sklearn_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def __init__(
self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
self.clf_name = clf_name

if AVAILABLE_CLASSIFIERS.get(self.clf_name):
self._base_clf = AVAILABLE_CLASSIFIERS[self.clf_name](**clf_args)
clf_type = AVAILABLE_CLASSIFIERS.get(self.clf_name, None)
if clf_type:
self._base_clf = clf_type(**clf_args)
else:
msg = f"Class {self.clf_name} does not exist in sklearn or does not have predict_proba method"
logger.error(msg)
Expand Down
2 changes: 0 additions & 2 deletions autointent/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

from ._inference_node import InferenceNode
from ._optimization import NodeOptimizer
from .schemes import OptimizationSearchSpaceConfig

__all__ = [
"InferenceNode",
"NodeOptimizer",
"OptimizationSearchSpaceConfig",
]
64 changes: 52 additions & 12 deletions autointent/nodes/_optimization/_node_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Node optimizer."""

import gc
import itertools as it
import logging
from copy import deepcopy
from functools import partial
Expand Down Expand Up @@ -33,6 +34,9 @@ class ParamSpaceFloat(BaseModel):
log: bool = Field(False, description="Whether to use a logarithmic scale.")


logger = logging.getLogger(__name__)


class NodeOptimizer:
"""Node optimizer class."""

Expand All @@ -50,6 +54,7 @@ def __init__(
:param search_space: Search space for the optimization
:param metrics: Metrics to optimize.
"""
self._logger = logger
self.node_type = node_type
self.node_info = NODES_INFO[node_type]
self.target_metric = target_metric
Expand All @@ -58,8 +63,8 @@ def __init__(
if self.target_metric not in self.metrics:
self.metrics.append(self.target_metric)

self.validate_search_space(search_space)
self.modules_search_spaces = search_space
self._logger = logging.getLogger(__name__) # TODO solve duplicate logging messages problem

def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
"""
Expand Down Expand Up @@ -151,27 +156,27 @@ def objective(
def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]:
res: dict[str, Any] = {}

def is_valid_param_space(
param_space: dict[str, Any], space_type: type[ParamSpaceInt | ParamSpaceFloat]
) -> bool:
try:
space_type(**param_space)
return True # noqa: TRY300
except ValueError:
return False

for param_name, param_space in search_space.items():
if isinstance(param_space, list):
res[param_name] = trial.suggest_categorical(param_name, choices=param_space)
elif is_valid_param_space(param_space, ParamSpaceInt):
elif self._is_valid_param_space(param_space, ParamSpaceInt):
res[param_name] = trial.suggest_int(param_name, **param_space)
elif is_valid_param_space(param_space, ParamSpaceFloat):
elif self._is_valid_param_space(param_space, ParamSpaceFloat):
res[param_name] = trial.suggest_float(param_name, **param_space)
else:
msg = f"Unsupported type of param search space: {param_space}"
raise TypeError(msg)
return res

def _is_valid_param_space(
self, param_space: dict[str, Any], space_type: type[ParamSpaceInt | ParamSpaceFloat]
) -> bool:
try:
space_type(**param_space)
return True # noqa: TRY300
except ValueError:
return False

def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str:
"""
Get module dump directory.
Expand Down Expand Up @@ -222,3 +227,38 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
filtered_search_space.append(search_space)

self.modules_search_spaces = filtered_search_space

def validate_search_space(self, search_space: list[dict[str, Any]]) -> None:
"""Check if search space is configured correctly."""
for module_search_space in search_space:
module_search_space_no_optuna, module_name = self._reformat_search_space(deepcopy(module_search_space))

for params_combination in it.product(*module_search_space_no_optuna.values()):
module_kwargs = dict(zip(module_search_space_no_optuna.keys(), params_combination, strict=False))

self._logger.debug("validating %s module...", module_name, extra=module_kwargs)
module = self.node_info.modules_available[module_name](**module_kwargs)
self._logger.debug("%s is ok", module_name)

del module
gc.collect()

def _reformat_search_space(self, module_search_space: dict[str, Any]) -> tuple[dict[str, Any], str]:
"""Remove optuna notation from search space."""
res = {}
module_name = module_search_space.pop("module_name")

for param_name, param_space in module_search_space.items():
if param_name == "n_trials":
continue
if isinstance(param_space, list):
res[param_name] = param_space
elif self._is_valid_param_space(param_space, ParamSpaceInt) or self._is_valid_param_space(
param_space, ParamSpaceFloat
):
res[param_name] = [param_space["low"], param_space["high"]]
else:
msg = f"Unsupported type of param search space: {param_space}"
raise TypeError(msg)

return res, module_name
Loading
Loading