
Commit 329af9e: fix sklearn

1 parent 426af0d

9 files changed: +432 −84 lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 3 additions & 2 deletions
@@ -25,7 +25,7 @@
     SearchSpaceValidationMode,
 )
 from autointent.metrics import DECISION_METRICS
-from autointent.nodes import InferenceNode, NodeOptimizer
+from autointent.nodes import InferenceNode, NodeOptimizer, OptimizationSearchSpaceConfig
 from autointent.utils import load_preset, load_search_space

 from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput
@@ -94,7 +94,8 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
         """
         if not isinstance(search_space, list):
             search_space = load_search_space(search_space)
-        nodes = [NodeOptimizer(**node) for node in search_space]
+        validated_search_space = OptimizationSearchSpaceConfig(search_space).model_dump()  # type: ignore[arg-type]
+        nodes = [NodeOptimizer(**node) for node in validated_search_space]
         return cls(nodes=nodes, seed=seed)

     @classmethod
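
In practice this means a raw search space is validated by pydantic before any NodeOptimizer is constructed, so schema errors surface early. A minimal usage sketch, assuming the class defined in this file is exported as autointent.Pipeline (the public name is outside this hunk) and an illustrative YAML path:

from autointent import Pipeline  # assumed public export

# Malformed entries now fail inside OptimizationSearchSpaceConfig with a
# pydantic ValidationError instead of surfacing later as an unexpected
# keyword argument deep inside NodeOptimizer.
pipeline = Pipeline.from_search_space("search_space.yaml", seed=42)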

autointent/custom_types.py

Lines changed: 19 additions & 0 deletions
@@ -8,6 +8,7 @@
 from typing import Annotated, Literal, TypeAlias

 from annotated_types import Interval
+from pydantic import BaseModel, Field


 class LogLevel(Enum):
@@ -83,3 +84,21 @@ class Split:
 SearchSpaceValidationMode = Literal["raise", "warning", "filter"]

 SearchSpacePresets = Literal["light", "light_moderate", "light_extra", "heavy", "heavy_moderate", "heavy_extra"]
+
+
+class ParamSpaceInt(BaseModel):
+    """Param space for optimizing int parameters for Optuna."""
+
+    low: int = Field(..., description="Low boundary of the search space.")
+    high: int = Field(..., description="High boundary of the search space.")
+    step: int = Field(1, description="Step of the search space.")
+    log: bool = Field(False, description="Whether to use a logarithmic scale.")
+
+
+class ParamSpaceFloat(BaseModel):
+    """Param space for optimizing float parameters for Optuna."""
+
+    low: float = Field(..., description="Low boundary of the search space.")
+    high: float = Field(..., description="High boundary of the search space.")
+    step: float | None = Field(None, description="Step of the search space.")
+    log: bool = Field(False, description="Whether to use a logarithmic scale.")
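
Moving these models into custom_types lets other modules (e.g. schemes.py below) import them without reaching into optimizer internals. A minimal sketch of how such a model validates a dict-style range before it is splatted into an Optuna suggest call; the parameter name "C" is illustrative:

from autointent.custom_types import ParamSpaceFloat

raw = {"low": 1e-3, "high": 10.0, "log": True}  # as written in a search-space config
space = ParamSpaceFloat.model_validate(raw)     # raises on missing or mistyped fields
print(space.model_dump())
# {'low': 0.001, 'high': 10.0, 'step': None, 'log': True}

# Later, inside the optimizer:
#     trial.suggest_float("C", **space.model_dump())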

autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 13 additions & 5 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Any
+from typing import Any, Literal

 import numpy as np
 import numpy.typing as npt
@@ -26,6 +26,8 @@
     if hasattr(class_, "predict_proba")
 }

+AVAILABLE_CLASSIFIERS_NAMES = tuple(AVAILABLE_CLASSIFIERS.keys())
+

 class SklearnScorer(BaseScorer):
     """
@@ -45,7 +47,7 @@ def __init__(
         self,
         clf_name: str,
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
-        **clf_args: Any,  # noqa: ANN401
+        **clf_args: dict[str, Any],
     ) -> None:
         """
         Initialize the SklearnScorer.
@@ -58,6 +60,9 @@
         self.clf_name = clf_name

         if AVAILABLE_CLASSIFIERS.get(self.clf_name):
+            if "clf_args" in clf_args:
+                # workaround: clf_args get saved nested during inference
+                clf_args = clf_args["clf_args"]
             self._base_clf = AVAILABLE_CLASSIFIERS[self.clf_name](**clf_args)
         else:
             msg = f"Class {self.clf_name} does not exist in sklearn or does not have predict_proba method"
@@ -68,9 +73,9 @@
     def from_context(
         cls,
         context: Context,
-        clf_name: str,
+        clf_name: Literal[AVAILABLE_CLASSIFIERS_NAMES],  # type: ignore[valid-type]
         embedder_config: EmbedderConfig | str | None = None,
-        **clf_args: float | str | bool,
+        clf_args: dict[str, int | float | str | bool | list[Any]] | None = None,
     ) -> Self:
         """
         Create a SklearnScorer instance using a Context object.
@@ -84,10 +89,13 @@
         if embedder_config is None:
             embedder_config = context.resolve_embedder()

+        if clf_args is None:
+            clf_args = {}
+
         return cls(
             embedder_config=embedder_config,
             clf_name=clf_name,
-            **clf_args,
+            **clf_args,  # type: ignore[arg-type]
         )

     def fit(
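
To make the new calling convention concrete, here is a hedged sketch of constructing the scorer directly; the import path follows the file location above, and the classifier name and hyperparameters are illustrative:

from autointent.modules.scoring._sklearn.sklearn_scorer import SklearnScorer

# Plain kwargs are still forwarded to the sklearn class by __init__:
scorer = SklearnScorer(clf_name="LogisticRegression", C=1.0, max_iter=1000)

# from_context now takes the same values as one clf_args dict instead of
# loose kwargs (context is whatever Context object the pipeline provides):
#     SklearnScorer.from_context(
#         context,
#         clf_name="LogisticRegression",
#         clf_args={"C": 1.0, "max_iter": 1000},
#     )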

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 16 additions & 17 deletions
@@ -10,29 +10,14 @@
 import optuna
 import torch
 from optuna.trial import Trial
-from pydantic import BaseModel, Field
 from typing_extensions import assert_never

 from autointent import Dataset
 from autointent.context import Context
-from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode
+from autointent.custom_types import NodeType, ParamSpaceFloat, ParamSpaceInt, SamplerType, SearchSpaceValidationMode
 from autointent.nodes.info import NODES_INFO


-class ParamSpaceInt(BaseModel):
-    low: int = Field(..., description="Low boundary of the search space.")
-    high: int = Field(..., description="High boundary of the search space.")
-    step: int = Field(1, description="Step of the search space.")
-    log: bool = Field(False, description="Whether to use a logarithmic scale.")
-
-
-class ParamSpaceFloat(BaseModel):
-    low: float = Field(..., description="Low boundary of the search space.")
-    high: float = Field(..., description="High boundary of the search space.")
-    step: float | None = Field(None, description="Step of the search space.")
-    log: bool = Field(False, description="Whether to use a logarithmic scale.")
-
-
 class NodeOptimizer:
     """Node optimizer class."""

@@ -148,7 +133,7 @@ def objective(

         return target_metric

-    def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]:
+    def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]:  # noqa: C901
         res: dict[str, Any] = {}

         def is_valid_param_space(
@@ -167,6 +152,20 @@ def is_valid_param_space(
                 res[param_name] = trial.suggest_int(param_name, **param_space)
             elif is_valid_param_space(param_space, ParamSpaceFloat):
                 res[param_name] = trial.suggest_float(param_name, **param_space)
+            elif isinstance(param_space, dict):
+                # sklearn_scorer clf_args
+                clf_args: dict[str, Any] = {}
+                for k, v in param_space.items():
+                    if isinstance(v, list):
+                        clf_args[k] = trial.suggest_categorical(f"{param_name}_{k}", choices=v)
+                    elif is_valid_param_space(v, ParamSpaceInt):
+                        clf_args[k] = trial.suggest_int(f"{param_name}_{k}", **v)
+                    elif is_valid_param_space(v, ParamSpaceFloat):
+                        clf_args[k] = trial.suggest_float(f"{param_name}_{k}", **v)
+                    else:
+                        msg = f"Unsupported type of param search space: {v}"
+                        raise TypeError(msg)
+                res["clf_args"] = clf_args
             else:
                 msg = f"Unsupported type of param search space: {param_space}"
                 raise TypeError(msg)
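
The shape the new branch accepts is easiest to see on an example. Below is a hedged sketch of a scoring-module search space entry; the parameter names and ranges are illustrative, only the nested-dict handling is the point:

# Each nested value is a categorical list, a ParamSpaceInt dict, or a
# ParamSpaceFloat dict (the latter two checked via is_valid_param_space).
search_space = {
    "clf_name": ["LogisticRegression"],                       # suggest_categorical
    "clf_args": {
        "C": {"low": 1e-3, "high": 10.0, "log": True},        # suggest_float
        "max_iter": {"low": 100, "high": 1000, "step": 100},  # suggest_int
        "fit_intercept": [True, False],                       # suggest_categorical
    },
}
# Nested parameters get flattened trial names ("clf_args_C", ...) and are
# re-packed into res["clf_args"]; note the result key is hardcoded, so this
# branch is effectively specific to the sklearn scorer's clf_args today.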

autointent/nodes/schemes.py

Lines changed: 54 additions & 15 deletions
@@ -1,14 +1,16 @@
 """Schemes."""

+import functools
 import inspect
+import operator
 from collections.abc import Iterator
+from types import NoneType, UnionType
 from typing import Annotated, Any, Literal, TypeAlias, Union, get_args, get_origin, get_type_hints

-from pydantic import BaseModel, Field, PositiveInt, RootModel
+from pydantic import BaseModel, ConfigDict, Field, PositiveInt, RootModel

-from autointent.custom_types import NodeType
+from autointent.custom_types import NodeType, ParamSpaceFloat, ParamSpaceInt
 from autointent.modules.abc import BaseModule
-from autointent.nodes._optimization._node_optimizer import ParamSpaceFloat, ParamSpaceInt
 from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo
@@ -26,17 +28,26 @@ def type_matches(target: type, tp: type) -> bool:
     """
     Recursively check if the target type is present in the given type.

-    This function handles union types by unwrapping Annotated types where necessary.
+    This function handles union types and generic types (e.g. dict[...]) by checking
+    their origin after unwrapping Annotated types.

-    :param target: Target type
-    :param tp: Given type
-    :return: If the target type is present in the given type
+    :param target: Target type to check for.
+    :param tp: Given type which may be a union, generic, or annotated type.
+    :return: True if the target type is present in the given type.
     """
     origin = get_origin(tp)
-
-    if origin is Union:  # float | list[float]
+    if origin is Union:
         return any(type_matches(target, arg) for arg in get_args(tp))
-    return unwrap_annotated(tp) is target
+
+    # Unwrap Annotated types, if any.
+    unwrapped = unwrap_annotated(tp)
+
+    # If the unwrapped type is a generic type, check its origin.
+    generic_origin = get_origin(unwrapped)
+    if generic_origin is not None:
+        return generic_origin is target
+
+    return unwrapped is target


 def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat] | None:
@@ -56,7 +67,12 @@ def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat]
     return None


-def generate_models_and_union_type_for_classes(
+def to_union(types: list[type]) -> type:
+    """Convert a list of types into a union type."""
+    return functools.reduce(operator.or_, types)
+
+
+def generate_models_and_union_type_for_classes(  # noqa: PLR0912, C901
     classes: list[type[BaseModule]],
 ) -> type[BaseModel]:
     """Dynamically generates Pydantic models for class constructors and creates a union type."""
@@ -70,6 +86,7 @@ def generate_models_and_union_type_for_classes(
         fields = {
             "module_name": (Literal[cls.name], Field(...)),
             "n_trials": (PositiveInt | None, Field(None, description="Number of trials")),
+            "model_config": (ConfigDict, ConfigDict(extra="forbid")),
         }

         for param_name, param in init_signature.parameters.items():
@@ -78,11 +95,33 @@

             param_type: TypeAlias = type_hints.get(param_name, Any)  # type: ignore[valid-type]  # noqa: PYI042
             field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...)
-            search_type = get_optuna_class(param_type)
-            if search_type is None:
-                fields[param_name] = (list[param_type], field)
+            if not type_matches(dict, param_type):
+                search_type = get_optuna_class(param_type)
+                if search_type is None:
+                    fields[param_name] = (list[param_type], field)
+                else:
+                    fields[param_name] = (list[param_type] | search_type, field)
             else:
-                fields[param_name] = (list[param_type] | search_type, field)
+                dict_key_type, dict_values_types = get_args(param_type)
+                is_optional = False
+                if dict_values_types is NoneType:  # if dict is optional
+                    is_optional = True
+                    dict_key_type, dict_values_types = get_args(dict_key_type)
+                if get_origin(dict_values_types) is UnionType:
+                    field_types: list[type[Any]] = []
+                    for value in get_args(dict_values_types):
+                        field_types.append(list[value])  # type: ignore[valid-type]
+                        search_type = get_optuna_class(value)
+                        if search_type is not None:
+                            field_types.append(search_type)
+                    field_type = to_union(field_types)
+                else:
+                    field_type = dict_values_types
+
+                if is_optional:
+                    fields[param_name] = (dict[dict_key_type, field_type] | None, field)  # type: ignore[valid-type]
+                else:
+                    fields[param_name] = (dict[dict_key_type, field_type], field)  # type: ignore[valid-type]

         model_name = f"{cls.__name__}InitModel"
         models[cls.__name__] = type(
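
Both helpers are small enough to illustrate standalone. A sketch of to_union and of how an optional dict annotation unpacks, mirroring the branch above; it uses only stdlib typing introspection, no autointent imports:

import functools
import operator
from types import NoneType
from typing import get_args

def to_union(types: list[type]) -> type:
    """Fold a list of types into a single `|` union."""
    return functools.reduce(operator.or_, types)

print(to_union([int, float, str]))  # int | float | str

# An optional dict annotation, as get_type_hints would return it:
tp = dict[str, int | float] | None
dict_part, none_part = get_args(tp)
print(none_part is NoneType)        # True: the dict is optional
print(get_args(dict_part))          # (<class 'str'>, int | float)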
