diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index d433616e7..450920725 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -25,7 +25,7 @@ SearchSpaceValidationMode, ) from autointent.metrics import DECISION_METRICS -from autointent.nodes import InferenceNode, NodeOptimizer +from autointent.nodes import InferenceNode, NodeOptimizer, OptimizationSearchSpaceConfig from autointent.utils import load_preset, load_search_space from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput @@ -94,7 +94,8 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed """ if not isinstance(search_space, list): search_space = load_search_space(search_space) - nodes = [NodeOptimizer(**node) for node in search_space] + validated_search_space = OptimizationSearchSpaceConfig(search_space).model_dump() # type: ignore[arg-type] + nodes = [NodeOptimizer(**node) for node in validated_search_space] return cls(nodes=nodes, seed=seed) @classmethod diff --git a/autointent/_presets/light_extra.yaml b/autointent/_presets/light_extra.yaml index 4d5bb51ff..cb8396391 100644 --- a/autointent/_presets/light_extra.yaml +++ b/autointent/_presets/light_extra.yaml @@ -21,6 +21,6 @@ search_space: thresh: low: 0.1 high: 0.9 - n_trials: 10 + n_trials: 10 - module_name: argmax sampler: random \ No newline at end of file diff --git a/autointent/custom_types.py b/autointent/custom_types.py index d2b360717..d74f15e84 100644 --- a/autointent/custom_types.py +++ b/autointent/custom_types.py @@ -8,6 +8,7 @@ from typing import Annotated, Literal, TypeAlias from annotated_types import Interval +from pydantic import BaseModel, Field class LogLevel(Enum): @@ -83,3 +84,21 @@ class Split: SearchSpaceValidationMode = Literal["raise", "warning", "filter"] SearchSpacePresets = Literal["light", "light_moderate", "light_extra", "heavy", "heavy_moderate", "heavy_extra"] + + +class ParamSpaceInt(BaseModel): + """Param space for optimizing int parameters for Optuna.""" + + low: int = Field(..., description="Low boundary of the search space.") + high: int = Field(..., description="High boundary of the search space.") + step: int = Field(1, description="Step of the search space.") + log: bool = Field(False, description="Whether to use a logarithmic scale.") + + +class ParamSpaceFloat(BaseModel): + """Param space for optimizing float parameters for Optuna.""" + + low: float = Field(..., description="Low boundary of the search space.") + high: float = Field(..., description="High boundary of the search space.") + step: float | None = Field(None, description="Step of the search space.") + log: bool = Field(False, description="Whether to use a logarithmic scale.") diff --git a/autointent/modules/scoring/_sklearn/sklearn_scorer.py b/autointent/modules/scoring/_sklearn/sklearn_scorer.py index 7b8b6620e..d3fd529fc 100644 --- a/autointent/modules/scoring/_sklearn/sklearn_scorer.py +++ b/autointent/modules/scoring/_sklearn/sklearn_scorer.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, Literal import numpy as np import numpy.typing as npt @@ -26,6 +26,8 @@ if hasattr(class_, "predict_proba") } +AVAILABLE_CLASSIFIERS_NAMES = tuple(AVAILABLE_CLASSIFIERS.keys()) + class SklearnScorer(BaseScorer): """ @@ -43,9 +45,9 @@ class SklearnScorer(BaseScorer): def __init__( self, - clf_name: str, + clf_name: Literal[AVAILABLE_CLASSIFIERS_NAMES], # type: ignore[valid-type] embedder_config: EmbedderConfig | str | dict[str, Any] | None = None, - **clf_args: Any, # noqa: ANN401 + **clf_args: dict[str, Any], ) -> None: """ Initialize the SklearnScorer. @@ -58,6 +60,9 @@ def __init__( self.clf_name = clf_name if AVAILABLE_CLASSIFIERS.get(self.clf_name): + if "clf_args" in clf_args: + # during inference wrong save + clf_args = clf_args["clf_args"] self._base_clf = AVAILABLE_CLASSIFIERS[self.clf_name](**clf_args) else: msg = f"Class {self.clf_name} does not exist in sklearn or does not have predict_proba method" @@ -68,9 +73,9 @@ def __init__( def from_context( cls, context: Context, - clf_name: str, + clf_name: Literal[AVAILABLE_CLASSIFIERS_NAMES], # type: ignore[valid-type] embedder_config: EmbedderConfig | str | None = None, - **clf_args: float | str | bool, + clf_args: dict[str, int | float | str | bool | list[Any]] | None = None, ) -> Self: """ Create a SklearnScorer instance using a Context object. @@ -84,10 +89,13 @@ def from_context( if embedder_config is None: embedder_config = context.resolve_embedder() + if clf_args is None: + clf_args = {} + return cls( embedder_config=embedder_config, clf_name=clf_name, - **clf_args, + **clf_args, # type: ignore[arg-type] ) def fit( diff --git a/autointent/nodes/_optimization/_node_optimizer.py b/autointent/nodes/_optimization/_node_optimizer.py index 60adfbd1d..1130f43e6 100644 --- a/autointent/nodes/_optimization/_node_optimizer.py +++ b/autointent/nodes/_optimization/_node_optimizer.py @@ -10,29 +10,14 @@ import optuna import torch from optuna.trial import Trial -from pydantic import BaseModel, Field from typing_extensions import assert_never from autointent import Dataset from autointent.context import Context -from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode +from autointent.custom_types import NodeType, ParamSpaceFloat, ParamSpaceInt, SamplerType, SearchSpaceValidationMode from autointent.nodes.info import NODES_INFO -class ParamSpaceInt(BaseModel): - low: int = Field(..., description="Low boundary of the search space.") - high: int = Field(..., description="High boundary of the search space.") - step: int = Field(1, description="Step of the search space.") - log: bool = Field(False, description="Whether to use a logarithmic scale.") - - -class ParamSpaceFloat(BaseModel): - low: float = Field(..., description="Low boundary of the search space.") - high: float = Field(..., description="High boundary of the search space.") - step: float | None = Field(None, description="Step of the search space.") - log: bool = Field(False, description="Whether to use a logarithmic scale.") - - class NodeOptimizer: """Node optimizer class.""" @@ -148,7 +133,7 @@ def objective( return target_metric - def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]: + def suggest(self, trial: Trial, search_space: dict[str, Any | list[Any]]) -> dict[str, Any]: # noqa: C901 res: dict[str, Any] = {} def is_valid_param_space( @@ -167,6 +152,20 @@ def is_valid_param_space( res[param_name] = trial.suggest_int(param_name, **param_space) elif is_valid_param_space(param_space, ParamSpaceFloat): res[param_name] = trial.suggest_float(param_name, **param_space) + elif isinstance(param_space, dict): + # sklearn_scorer clf_args + clf_args: dict[str, Any] = {} + for k, v in param_space.items(): + if isinstance(v, list): + clf_args[k] = trial.suggest_categorical(f"{param_name}_{k}", choices=v) + elif is_valid_param_space(v, ParamSpaceInt): + clf_args[k] = trial.suggest_int(f"{param_name}_{k}", **v) + elif is_valid_param_space(v, ParamSpaceFloat): + clf_args[k] = trial.suggest_float(f"{param_name}_{k}", **v) + else: + msg = f"Unsupported type of param search space: {v}" + raise TypeError(msg) + res["clf_args"] = clf_args else: msg = f"Unsupported type of param search space: {param_space}" raise TypeError(msg) diff --git a/autointent/nodes/schemes.py b/autointent/nodes/schemes.py index 4e33e1314..ce564ef70 100644 --- a/autointent/nodes/schemes.py +++ b/autointent/nodes/schemes.py @@ -1,14 +1,16 @@ """Schemes.""" +import functools import inspect +import operator from collections.abc import Iterator +from types import NoneType, UnionType from typing import Annotated, Any, Literal, TypeAlias, Union, get_args, get_origin, get_type_hints -from pydantic import BaseModel, Field, PositiveInt, RootModel +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, RootModel -from autointent.custom_types import NodeType +from autointent.custom_types import NodeType, ParamSpaceFloat, ParamSpaceInt from autointent.modules.abc import BaseModule -from autointent.nodes._optimization._node_optimizer import ParamSpaceFloat, ParamSpaceInt from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo @@ -26,17 +28,26 @@ def type_matches(target: type, tp: type) -> bool: """ Recursively check if the target type is present in the given type. - This function handles union types by unwrapping Annotated types where necessary. + This function handles union types and generic types (e.g. dict[...] by checking + their origin) after unwrapping Annotated types. - :param target: Target type - :param tp: Given type - :return: If the target type is present in the given type + :param target: Target type to check for. + :param tp: Given type which may be a union, generic, or annotated type. + :return: True if the target type is present in the given type. """ origin = get_origin(tp) - - if origin is Union: # float | list[float] + if origin is Union: return any(type_matches(target, arg) for arg in get_args(tp)) - return unwrap_annotated(tp) is target + + # Unwrap Annotated types, if any. + unwrapped = unwrap_annotated(tp) + + # If the unwrapped type is a generic type, check its origin. + generic_origin = get_origin(unwrapped) + if generic_origin is not None: + return generic_origin is target + + return unwrapped is target def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat] | None: @@ -56,7 +67,12 @@ def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat] return None -def generate_models_and_union_type_for_classes( +def to_union(types: list[type]) -> type: + """Convert a tuple of types into a union type.""" + return functools.reduce(operator.or_, types) + + +def generate_models_and_union_type_for_classes( # noqa: PLR0912, C901 classes: list[type[BaseModule]], ) -> type[BaseModel]: """Dynamically generates Pydantic models for class constructors and creates a union type.""" @@ -70,6 +86,7 @@ def generate_models_and_union_type_for_classes( fields = { "module_name": (Literal[cls.name], Field(...)), "n_trials": (PositiveInt | None, Field(None, description="Number of trials")), + "model_config": (ConfigDict, ConfigDict(extra="forbid")), } for param_name, param in init_signature.parameters.items(): @@ -78,11 +95,33 @@ def generate_models_and_union_type_for_classes( param_type: TypeAlias = type_hints.get(param_name, Any) # type: ignore[valid-type] # noqa: PYI042 field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...) - search_type = get_optuna_class(param_type) - if search_type is None: - fields[param_name] = (list[param_type], field) + if not type_matches(dict, param_type): + search_type = get_optuna_class(param_type) + if search_type is None: + fields[param_name] = (list[param_type], field) + else: + fields[param_name] = (list[param_type] | search_type, field) else: - fields[param_name] = (list[param_type] | search_type, field) + dict_key_type, dict_values_types = get_args(param_type) + is_optional = False + if dict_values_types is NoneType: # if dict is optional + is_optional = True + dict_key_type, dict_values_types = get_args(dict_key_type) + if get_origin(dict_values_types) is UnionType: + filed_types: list[type[Any]] = [] + for value in get_args(dict_values_types): + search_type = get_optuna_class(value) + if search_type is not None: + filed_types.append(search_type) + filed_types.append(list[value]) # type: ignore[valid-type] + filed_type = to_union(filed_types) + else: + filed_type = dict_values_types + + if is_optional: + fields[param_name] = (dict[dict_key_type, filed_type] | None, field) # type: ignore[valid-type] + else: + fields[param_name] = (dict[dict_key_type, filed_type], field) # type: ignore[valid-type] model_name = f"{cls.__name__}InitModel" models[cls.__name__] = type( @@ -157,7 +196,7 @@ class RegexNodeValidator(BaseModel): search_space: list[RegexpSearchSpaceType] -SearchSpaceTypes: TypeAlias = EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator +SearchSpaceTypes: TypeAlias = ScoringNodeValidator | EmbeddingNodeValidator | DecisionNodeValidator | RegexNodeValidator class OptimizationSearchSpaceConfig(RootModel[list[SearchSpaceTypes]]): diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json index e0cba086f..c23002db5 100644 --- a/docs/optimizer_config.schema.json +++ b/docs/optimizer_config.schema.json @@ -1,6 +1,7 @@ { "$defs": { "AdaptiveDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "adaptive", @@ -51,6 +52,7 @@ "type": "object" }, "ArgmaxDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "argmax", @@ -131,6 +133,7 @@ "type": "object" }, "DNNCScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "dnnc", @@ -333,6 +336,7 @@ "type": "object" }, "DescriptionScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "description", @@ -629,6 +633,7 @@ "type": "object" }, "JinoosDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "jinoos", @@ -679,6 +684,7 @@ "type": "object" }, "KNNScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "knn", @@ -758,6 +764,7 @@ "type": "object" }, "LinearScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "linear", @@ -875,6 +882,7 @@ "type": "object" }, "LogregAimedEmbeddingInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "logreg_embedding", @@ -936,6 +944,7 @@ "type": "object" }, "MLKnnScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "mlknn", @@ -1051,10 +1060,10 @@ "items": { "anyOf": [ { - "$ref": "#/$defs/EmbeddingNodeValidator" + "$ref": "#/$defs/ScoringNodeValidator" }, { - "$ref": "#/$defs/ScoringNodeValidator" + "$ref": "#/$defs/EmbeddingNodeValidator" }, { "$ref": "#/$defs/DecisionNodeValidator" @@ -1068,6 +1077,7 @@ "type": "array" }, "ParamSpaceFloat": { + "description": "Param space for optimizing float parameters for Optuna.", "properties": { "low": { "description": "Low boundary of the search space.", @@ -1107,6 +1117,7 @@ "type": "object" }, "ParamSpaceInt": { + "description": "Param space for optimizing int parameters for Optuna.", "properties": { "low": { "description": "Low boundary of the search space.", @@ -1139,6 +1150,7 @@ "type": "object" }, "RegexInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "regex", @@ -1216,6 +1228,7 @@ "type": "object" }, "RerankScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "rerank", @@ -1363,6 +1376,7 @@ "type": "object" }, "RetrievalAimedEmbeddingInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "retrieval", @@ -1509,6 +1523,7 @@ "type": "object" }, "SklearnScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "sklearn", @@ -1531,6 +1546,45 @@ }, "clf_name": { "items": { + "enum": [ + "AdaBoostClassifier", + "BaggingClassifier", + "BernoulliNB", + "CalibratedClassifierCV", + "CategoricalNB", + "ClassifierChain", + "ComplementNB", + "DecisionTreeClassifier", + "DummyClassifier", + "ExtraTreeClassifier", + "ExtraTreesClassifier", + "FixedThresholdClassifier", + "GaussianNB", + "GaussianProcessClassifier", + "GradientBoostingClassifier", + "HistGradientBoostingClassifier", + "KNeighborsClassifier", + "LabelPropagation", + "LabelSpreading", + "LinearDiscriminantAnalysis", + "LogisticRegression", + "LogisticRegressionCV", + "MLPClassifier", + "MultiOutputClassifier", + "MultinomialNB", + "NearestCentroid", + "NuSVC", + "OneVsRestClassifier", + "QuadraticDiscriminantAnalysis", + "RadiusNeighborsClassifier", + "RandomForestClassifier", + "SGDClassifier", + "SVC", + "SelfTrainingClassifier", + "StackingClassifier", + "TunedThresholdClassifierCV", + "VotingClassifier" + ], "type": "string" }, "title": "Clf Name", @@ -1557,32 +1611,70 @@ "type": "array" }, "clf_args": { - "items": { - "anyOf": [ - { - "type": "number" - }, - { - "type": "string" + "anyOf": [ + { + "additionalProperties": { + "anyOf": [ + { + "$ref": "#/$defs/ParamSpaceInt" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceFloat" + }, + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "items": { + "type": "boolean" + }, + "type": "array" + }, + { + "items": { + "items": {}, + "type": "array" + }, + "type": "array" + } + ] }, - { - "type": "boolean" - } - ] - }, - "title": "Clf Args", - "type": "array" + "type": "object" + }, + { + "type": "null" + } + ], + "default": [ + null + ], + "title": "Clf Args" } }, "required": [ "module_name", - "clf_name", - "clf_args" + "clf_name" ], "title": "SklearnScorerInitModel", "type": "object" }, "ThresholdDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "threshold", @@ -1642,6 +1734,7 @@ "type": "object" }, "TunableDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "tunable", diff --git a/docs/optimizer_search_space_config.schema.json b/docs/optimizer_search_space_config.schema.json index 67932b98b..9d3cb8774 100644 --- a/docs/optimizer_search_space_config.schema.json +++ b/docs/optimizer_search_space_config.schema.json @@ -1,6 +1,7 @@ { "$defs": { "AdaptiveDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "adaptive", @@ -51,6 +52,7 @@ "type": "object" }, "ArgmaxDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "argmax", @@ -101,6 +103,12 @@ "description": "Maximum length of input sequences.", "title": "Max Length" }, + "model_name": { + "default": "cross-encoder/ms-marco-MiniLM-L-6-v2", + "description": "Name of the hugging face model.", + "title": "Model Name", + "type": "string" + }, "device": { "anyOf": [ { @@ -114,12 +122,6 @@ "description": "Torch notation for CPU or CUDA.", "title": "Device" }, - "model_name": { - "default": "cross-encoder/ms-marco-MiniLM-L-6-v2", - "description": "Name of the hugging face model.", - "title": "Model Name", - "type": "string" - }, "train_head": { "default": false, "description": "Whether to train the head of the model. If False, LogReg will be trained.", @@ -131,6 +133,7 @@ "type": "object" }, "DNNCScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "dnnc", @@ -286,6 +289,7 @@ "type": "object" }, "DescriptionScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "description", @@ -372,6 +376,12 @@ "description": "Maximum length of input sequences.", "title": "Max Length" }, + "model_name": { + "default": "sentence-transformers/all-MiniLM-L6-v2", + "description": "Name of the hugging face model.", + "title": "Model Name", + "type": "string" + }, "device": { "anyOf": [ { @@ -385,12 +395,6 @@ "description": "Torch notation for CPU or CUDA.", "title": "Device" }, - "model_name": { - "default": "sentence-transformers/all-MiniLM-L6-v2", - "description": "Name of the hugging face model.", - "title": "Model Name", - "type": "string" - }, "default_prompt": { "anyOf": [ { @@ -582,6 +586,7 @@ "type": "object" }, "JinoosDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "jinoos", @@ -632,6 +637,7 @@ "type": "object" }, "KNNScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "knn", @@ -711,6 +717,7 @@ "type": "object" }, "LinearScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "linear", @@ -759,6 +766,7 @@ "type": "object" }, "LogregAimedEmbeddingInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "logreg_embedding", @@ -820,6 +828,7 @@ "type": "object" }, "MLKnnScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "mlknn", @@ -931,6 +940,7 @@ "type": "string" }, "ParamSpaceFloat": { + "description": "Param space for optimizing float parameters for Optuna.", "properties": { "low": { "description": "Low boundary of the search space.", @@ -970,6 +980,7 @@ "type": "object" }, "ParamSpaceInt": { + "description": "Param space for optimizing int parameters for Optuna.", "properties": { "low": { "description": "Low boundary of the search space.", @@ -1002,6 +1013,7 @@ "type": "object" }, "RegexInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "regex", @@ -1079,6 +1091,7 @@ "type": "object" }, "RerankScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "rerank", @@ -1194,21 +1207,28 @@ "type": "array" }, "rank_threshold_cutoff": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], "default": [ null ], - "items": { - "anyOf": [ - { - "type": "integer" - }, - { - "type": "null" - } - ] - }, - "title": "Rank Threshold Cutoff", - "type": "array" + "title": "Rank Threshold Cutoff" } }, "required": [ @@ -1219,6 +1239,7 @@ "type": "object" }, "RetrievalAimedEmbeddingInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "retrieval", @@ -1365,6 +1386,7 @@ "type": "object" }, "SklearnScorerInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "sklearn", @@ -1387,6 +1409,45 @@ }, "clf_name": { "items": { + "enum": [ + "AdaBoostClassifier", + "BaggingClassifier", + "BernoulliNB", + "CalibratedClassifierCV", + "CategoricalNB", + "ClassifierChain", + "ComplementNB", + "DecisionTreeClassifier", + "DummyClassifier", + "ExtraTreeClassifier", + "ExtraTreesClassifier", + "FixedThresholdClassifier", + "GaussianNB", + "GaussianProcessClassifier", + "GradientBoostingClassifier", + "HistGradientBoostingClassifier", + "KNeighborsClassifier", + "LabelPropagation", + "LabelSpreading", + "LinearDiscriminantAnalysis", + "LogisticRegression", + "LogisticRegressionCV", + "MLPClassifier", + "MultiOutputClassifier", + "MultinomialNB", + "NearestCentroid", + "NuSVC", + "OneVsRestClassifier", + "QuadraticDiscriminantAnalysis", + "RadiusNeighborsClassifier", + "RandomForestClassifier", + "SGDClassifier", + "SVC", + "SelfTrainingClassifier", + "StackingClassifier", + "TunedThresholdClassifierCV", + "VotingClassifier" + ], "type": "string" }, "title": "Clf Name", @@ -1413,20 +1474,70 @@ "type": "array" }, "clf_args": { - "items": {}, - "title": "Clf Args", - "type": "array" + "anyOf": [ + { + "additionalProperties": { + "anyOf": [ + { + "$ref": "#/$defs/ParamSpaceInt" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceFloat" + }, + { + "items": { + "type": "number" + }, + "type": "array" + }, + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "items": { + "type": "boolean" + }, + "type": "array" + }, + { + "items": { + "items": {}, + "type": "array" + }, + "type": "array" + } + ] + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": [ + null + ], + "title": "Clf Args" } }, "required": [ "module_name", - "clf_name", - "clf_args" + "clf_name" ], "title": "SklearnScorerInitModel", "type": "object" }, "ThresholdDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "threshold", @@ -1486,6 +1597,7 @@ "type": "object" }, "TunableDecisionInitModel": { + "additionalProperties": false, "properties": { "module_name": { "const": "tunable", @@ -1553,10 +1665,10 @@ "items": { "anyOf": [ { - "$ref": "#/$defs/EmbeddingNodeValidator" + "$ref": "#/$defs/ScoringNodeValidator" }, { - "$ref": "#/$defs/ScoringNodeValidator" + "$ref": "#/$defs/EmbeddingNodeValidator" }, { "$ref": "#/$defs/DecisionNodeValidator" diff --git a/tests/assets/configs/optuna.yaml b/tests/assets/configs/optuna.yaml index 56ef453a2..db63b7df8 100644 --- a/tests/assets/configs/optuna.yaml +++ b/tests/assets/configs/optuna.yaml @@ -20,9 +20,10 @@ - module_name: sklearn clf_name: - RandomForestClassifier - n_estimators: - low: 5 - high: 10 + clf_args: + n_estimators: + low: 5 + high: 10 - node_type: decision target_metric: decision_accuracy search_space: diff --git a/tests/configs/test_scoring.py b/tests/configs/test_scoring.py index e95d32be7..d22e5efa7 100644 --- a/tests/configs/test_scoring.py +++ b/tests/configs/test_scoring.py @@ -17,7 +17,6 @@ def valid_scoring_config(): "cross_encoder_config": ["cross-encoder/ms-marco-MiniLM-L-6-v2"], "embedder_config": ["sergeyzh/rubert-tiny-turbo"], "k": [5, 10], - "train_head": [False, True], }, { "module_name": "knn", @@ -25,7 +24,7 @@ def valid_scoring_config(): "k": [5, 10], "weights": ["uniform", "distance"], }, - {"module_name": "linear", "embedder_config": ["sergeyzh/rubert-tiny-turbo"], "cv": [3, 5]}, + {"module_name": "linear", "embedder_config": ["sergeyzh/rubert-tiny-turbo"]}, { "module_name": "mlknn", "embedder_config": ["sergeyzh/rubert-tiny-turbo"], @@ -46,12 +45,14 @@ def valid_scoring_config(): "weights": ["distance"], "rank_threshold_cutoff": [None, 3], }, - # { - # "module_name": "sklearn", - # "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"], - # "clf_name": ["LogisticRegression"], - # "clf_args": [{"C": 1.0}, {"C": 0.5}], - # }, + { + "module_name": "sklearn", + "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"], + "clf_name": ["LogisticRegression"], + "clf_args": { + "C": [1.0, 0.5], + }, + }, ], } ] @@ -97,3 +98,33 @@ def test_invalid_scoring_config_wrong_type(): with pytest.raises(ValidationError): OptimizationSearchSpaceConfig(invalid_config) + + +def test_valid_scoring_config_sklearn(): + """Test that a valid scoring config passes validation.""" + sklearn_scoring = [ + { + "node_type": "scoring", + "target_metric": "scoring_roc_auc", + "search_space": [ + { + "module_name": "sklearn", + "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"], + "clf_name": ["LogisticRegression"], + "clf_args": { + "C": { + "low": 0.5, + "high": 1.0, + }, + "w": [5], + }, + }, + ], + } + ] + + config = OptimizationSearchSpaceConfig(sklearn_scoring) + assert config[0].node_type == "scoring" + assert config[0].target_metric == "scoring_roc_auc" + assert isinstance(config[0].search_space, list) + assert config[0].search_space[0].module_name == "sklearn" diff --git a/tests/pipeline/test_optimization.py b/tests/pipeline/test_optimization.py index b2bebc7f6..42ef2abf8 100644 --- a/tests/pipeline/test_optimization.py +++ b/tests/pipeline/test_optimization.py @@ -42,6 +42,53 @@ def test_bayes(dataset, sampler): pipeline_optimizer.fit(dataset, refit_after=False, sampler=sampler) +# @pytest.mark.parametrize( +# "clf_args", +# [ +# {"C": [1.0], "tol": [0.5]}, +# { +# "C": { +# "low": 0.01, +# "high": 10, +# "step": 5, +# }, +# }, +# ], +# ) +# def test_multiple_sklearn_scorers(dataset, clf_args): +# project_dir = setup_environment() +# search_space = [ +# { +# "node_type": "scoring", +# "target_metric": "scoring_roc_auc", +# "search_space": [ +# { +# "module_name": "sklearn", +# "clf_name": ["LogisticRegression"], +# "clf_args": clf_args, +# "embedder_config": ["sentence-transformers/all-MiniLM-L6-v2"], +# }, +# ], +# }, +# { +# "node_type": "decision", +# "target_metric": "decision_accuracy", +# "search_space": [ +# { +# "module_name": "argmax", +# }, +# ], +# }, +# ] +# +# pipeline_optimizer = Pipeline.from_search_space(search_space) +# +# pipeline_optimizer.set_config(LoggingConfig(project_dir=project_dir, dump_modules=True, clear_ram=True)) +# pipeline_optimizer.set_config(DataConfig(scheme="ho", separate_nodes=True)) +# +# pipeline_optimizer.fit(dataset, refit_after=False) + + @pytest.mark.parametrize( "task_type", ["multiclass", "multilabel", "description"], diff --git a/user_guides/basic_usage/03_automl.py b/user_guides/basic_usage/03_automl.py index 2512dcf91..18734d7ef 100644 --- a/user_guides/basic_usage/03_automl.py +++ b/user_guides/basic_usage/03_automl.py @@ -52,7 +52,7 @@ """ # %% -preset["search_space"][1]["search_space"][0]["k"] = [1, 3] +preset["search_space"][0]["search_space"][0]["k"] = [1, 3] custom_pipeline = Pipeline.from_optimization_config(preset) # %% [markdown]