diff --git a/autointent/_callbacks/__init__.py b/autointent/_callbacks/__init__.py index 200556eda..376b12b46 100644 --- a/autointent/_callbacks/__init__.py +++ b/autointent/_callbacks/__init__.py @@ -7,7 +7,7 @@ REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]} -REPORTERS_NAMES = list(REPORTERS.keys()) +REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())] # type: ignore[valid-type] def get_callbacks(reporters: list[str] | None) -> CallbackHandler: diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py index 478d4435e..54bbd5e07 100644 --- a/autointent/_pipeline/_pipeline.py +++ b/autointent/_pipeline/_pipeline.py @@ -3,7 +3,7 @@ import json import logging from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, get_args import numpy as np import yaml @@ -13,7 +13,7 @@ from autointent.custom_types import ListOfGenericLabels, NodeType, SamplerType from autointent.metrics import DECISION_METRICS from autointent.nodes import InferenceNode, NodeOptimizer -from autointent.nodes.schemes import OptimizationConfig +from autointent.nodes.schemes import OptimizationConfig, OptimizationSearchSpaceConfig from autointent.utils import load_default_search_space, load_search_space from ._schemas import InferencePipelineOutput, InferencePipelineUtteranceOutput @@ -28,17 +28,24 @@ class Pipeline: def __init__( self, nodes: list[NodeOptimizer] | list[InferenceNode], + sampler: SamplerType = "brute", seed: int = 42, ) -> None: """ Initialize the pipeline optimizer. 
:param nodes: list of nodes + :param sampler: sampler type :param seed: random seed """ self._logger = logging.getLogger(__name__) self.nodes = {node.node_type: node for node in nodes} self.seed = seed + if sampler not in get_args(SamplerType): + msg = f"Sampler should be one of {get_args(SamplerType)}" + raise ValueError(msg) + + self.sampler = sampler if isinstance(nodes[0], NodeOptimizer): self.logging_config = LoggingConfig(dump_dir=None) @@ -74,10 +81,34 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed """ if isinstance(search_space, Path | str): search_space = load_search_space(search_space) - validated_search_space = OptimizationConfig(search_space).model_dump() # type: ignore[arg-type] + validated_search_space = OptimizationSearchSpaceConfig(search_space).model_dump() # type: ignore[arg-type] nodes = [NodeOptimizer(**node) for node in validated_search_space] return cls(nodes=nodes, seed=seed) + @classmethod + def from_optimization_config(cls, config: dict[str, Any] | Path | str) -> "Pipeline": + """ + Create pipeline optimizer from optimization config. 
+ + :param config: Optimization config + :return: + """ + if isinstance(config, Path | str): + with Path(config).open() as file: + loaded_config = yaml.safe_load(file) + else: + loaded_config = config + optimization_config = OptimizationConfig(**loaded_config) + pipeline = cls( + [NodeOptimizer(**node.model_dump()) for node in optimization_config.task_config.search_space], + optimization_config.task_config.sampler, + optimization_config.seed, + ) + pipeline.set_config(optimization_config.logging_config) + pipeline.set_config(optimization_config.vector_index_config) + pipeline.set_config(optimization_config.data_config) + return pipeline + @classmethod def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline": """ @@ -90,7 +121,7 @@ def default_optimizer(cls, multilabel: bool, seed: int = 42) -> "Pipeline": """ return cls.from_search_space(search_space=load_default_search_space(multilabel), seed=seed) - def _fit(self, context: Context, sampler: SamplerType = "brute") -> None: + def _fit(self, context: Context, sampler: SamplerType) -> None: """ Optimize the pipeline. @@ -99,7 +130,7 @@ def _fit(self, context: Context, sampler: SamplerType = "brute") -> None: self.context = context self._logger.info("starting pipeline optimization...") self.context.callback_handler.start_run( - run_name=self.context.logging_config.run_name, + run_name=self.context.logging_config.get_run_name(), dirpath=self.context.logging_config.dirpath, ) for node_type in NodeType: @@ -123,7 +154,7 @@ def fit( self, dataset: Dataset, refit_after: bool = False, - sampler: SamplerType = "brute", + sampler: SamplerType | None = None, ) -> Context: """ Optimize the pipeline from dataset. @@ -148,6 +179,9 @@ def fit( "Test data is not provided. Final test metrics won't be calculated after pipeline optimization." 
) + if sampler is None: + sampler = self.sampler or "brute" + self._fit(context, sampler) if context.is_ram_to_clear(): diff --git a/autointent/configs/__init__.py b/autointent/configs/__init__.py index 4ee9775c9..d267b03f3 100644 --- a/autointent/configs/__init__.py +++ b/autointent/configs/__init__.py @@ -4,7 +4,6 @@ from ._optimization import ( DataConfig, LoggingConfig, - TaskConfig, VectorIndexConfig, ) @@ -13,6 +12,5 @@ "InferenceNodeConfig", "InferenceNodeConfig", "LoggingConfig", - "TaskConfig", "VectorIndexConfig", ] diff --git a/autointent/configs/_optimization.py b/autointent/configs/_optimization.py index d174da827..910af2513 100644 --- a/autointent/configs/_optimization.py +++ b/autointent/configs/_optimization.py @@ -2,10 +2,10 @@ from pathlib import Path -from pydantic import BaseModel, Field, PositiveInt, field_validator +from pydantic import BaseModel, Field, PositiveInt from autointent._callbacks import REPORTERS_NAMES -from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme +from autointent.custom_types import FloatFromZeroToOne, ValidationScheme from ._name import get_run_name @@ -13,67 +13,65 @@ class DataConfig(BaseModel): """Configuration for the data used in the optimization process.""" - scheme: ValidationScheme = "ho" + scheme: ValidationScheme = Field("ho", description="Validation scheme to use.") """Hold-out or cross-validation.""" - n_folds: PositiveInt = 3 + n_folds: PositiveInt = Field(3, description="Number of folds in cross-validation.") """Number of folds in cross-validation.""" - validation_size: FloatFromZeroToOne = 0.2 + validation_size: FloatFromZeroToOne = Field( + 0.2, + description=( + "Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)." 
+ ), + ) """Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split).""" - separation_ratio: FloatFromZeroToOne | None = 0.5 + separation_ratio: FloatFromZeroToOne | None = Field( + 0.5, description="Set to float to prevent data leak between scoring and decision nodes." + ) """Set to float to prevent data leak between scoring and decision nodes.""" -class TaskConfig(BaseModel): - """Configuration for the task to optimize.""" - - search_space_path: Path | None = None - """Path to the search space configuration file. If None, the default search space will be used""" - sampler: SamplerType = "brute" - - class LoggingConfig(BaseModel): """Configuration for the logging.""" - project_dir: Path = Field(default_factory=lambda: Path.cwd() / "runs") + _dirpath: Path | None = None + _dump_dir: Path | None = None + + project_dir: Path | str | None = Field(None, description="Path to the directory with different runs.") """Path to the directory with different runs.""" - run_name: str = Field(default_factory=get_run_name) + run_name: str | None = Field(None, description="Name of the run. If None, a random name will be generated.") """Name of the run. If None, a random name will be generated""" - dump_modules: bool = False + dump_modules: bool = Field(False, description="Whether to dump the modules or not") """Whether to dump the modules or not""" - clear_ram: bool = False + clear_ram: bool = Field(False, description="Whether to clear the RAM after dumping the modules") """Whether to clear the RAM after dumping the modules""" - report_to: list[str] | None = None + report_to: list[REPORTERS_NAMES] | None = Field( # type: ignore[valid-type] + None, description="List of callbacks to report to. If None, no callbacks will be used" + ) """List of callbacks to report to. 
If None, no callbacks will be used""" @property def dirpath(self) -> Path: """Path to the directory where the logs will be saved.""" - if not hasattr(self, "_dirpath"): - self._dirpath = self.project_dir / self.run_name + if self._dirpath is None: + project_dir = Path.cwd() / "runs" if self.project_dir is None else Path(self.project_dir) + self._dirpath = project_dir / self.get_run_name() return self._dirpath @property def dump_dir(self) -> Path: """Path to the directory where the modules will be dumped.""" - if not hasattr(self, "_dump_dir"): + if self._dump_dir is None: self._dump_dir = self.dirpath / "modules_dumps" return self._dump_dir - @field_validator("report_to") - @classmethod - def validate_report_to(cls, v: list[str] | None) -> list[str] | None: - """Validate the report_to field.""" - if v is None: - return None - for reporter in v: - if reporter not in REPORTERS_NAMES: - msg = f"Reporter {reporter} is not supported. Supported reporters: {REPORTERS_NAMES}" - raise ValueError(msg) - return v + def get_run_name(self) -> str: + if self.run_name is None: + self.run_name = get_run_name() + return self.run_name class VectorIndexConfig(BaseModel): """Configuration for the vector index.""" - save_db: bool = False + save_db: bool = Field(False, description="Whether to save the vector index database or not") """Whether to save the vector index database or not""" diff --git a/autointent/nodes/__init__.py b/autointent/nodes/__init__.py index 0f77b66a7..1529ee331 100644 --- a/autointent/nodes/__init__.py +++ b/autointent/nodes/__init__.py @@ -2,10 +2,10 @@ from ._inference_node import InferenceNode from ._optimization import NodeOptimizer -from .schemes import OptimizationConfig +from .schemes import OptimizationSearchSpaceConfig __all__ = [ "InferenceNode", "NodeOptimizer", - "OptimizationConfig", + "OptimizationSearchSpaceConfig", ] diff --git a/autointent/nodes/schemes.py b/autointent/nodes/schemes.py index 64a2c7809..4f91572f5 100644 --- 
a/autointent/nodes/schemes.py +++ b/autointent/nodes/schemes.py @@ -6,7 +6,8 @@ from pydantic import BaseModel, Field, PositiveInt, RootModel -from autointent.custom_types import NodeType +from autointent.configs import DataConfig, LoggingConfig, VectorIndexConfig +from autointent.custom_types import NodeType, SamplerType from autointent.modules.abc import BaseModule from autointent.nodes._optimization._node_optimizer import ParamSpaceFloat, ParamSpaceInt from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo @@ -160,7 +161,7 @@ class RegexNodeValidator(BaseModel): SearchSpaceTypes: TypeAlias = EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator -class OptimizationConfig(RootModel[list[SearchSpaceTypes]]): +class OptimizationSearchSpaceConfig(RootModel[list[SearchSpaceTypes]]): """Optimizer configuration.""" def __iter__( @@ -178,3 +179,21 @@ def __getitem__(self, item: int) -> SearchSpaceTypes: :return: Item """ return self.root[item] + + +class TaskConfig(BaseModel): + """Configuration for the task to optimize.""" + + search_space: OptimizationSearchSpaceConfig + """Path to the search space configuration file. 
If None, the default search space will be used""" + sampler: SamplerType = "brute" + + +class OptimizationConfig(BaseModel): + """Configuration for the optimization process.""" + + data_config: DataConfig = DataConfig() + task_config: TaskConfig + logging_config: LoggingConfig = LoggingConfig() + vector_index_config: VectorIndexConfig = VectorIndexConfig() + seed: PositiveInt = 42 diff --git a/docs/optimizer_config.schema.json b/docs/optimizer_config.schema.json index f17fc96c4..f0eb70c38 100644 --- a/docs/optimizer_config.schema.json +++ b/docs/optimizer_config.schema.json @@ -211,6 +211,53 @@ "title": "DNNCScorerInitModel", "type": "object" }, + "DataConfig": { + "description": "Configuration for the data used in the optimization process.", + "properties": { + "scheme": { + "default": "ho", + "description": "Validation scheme to use.", + "enum": [ + "ho", + "cv" + ], + "title": "Scheme", + "type": "string" + }, + "n_folds": { + "default": 3, + "description": "Number of folds in cross-validation.", + "exclusiveMinimum": 0, + "title": "N Folds", + "type": "integer" + }, + "validation_size": { + "default": 0.2, + "description": "Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split).", + "maximum": 1.0, + "minimum": 0.0, + "title": "Validation Size", + "type": "number" + }, + "separation_ratio": { + "anyOf": [ + { + "maximum": 1.0, + "minimum": 0.0, + "type": "number" + }, + { + "type": "null" + } + ], + "default": 0.5, + "description": "Set to float to prevent data leak between scoring and decision nodes.", + "title": "Separation Ratio" + } + }, + "title": "DataConfig", + "type": "object" + }, "DecisionNodeValidator": { "description": "Search space configuration for the Decision node.", "properties": { @@ -755,6 +802,75 @@ "title": "LinearScorerInitModel", "type": "object" }, + "LoggingConfig": { + "description": "Configuration for the logging.", + "properties": { + "project_dir": { + "anyOf": [ + { + "format": 
"path", + "type": "string" + }, + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Path to the directory with different runs.", + "title": "Project Dir" + }, + "run_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the run. If None, a random name will be generated.", + "title": "Run Name" + }, + "dump_modules": { + "default": false, + "description": "Whether to dump the modules or not", + "title": "Dump Modules", + "type": "boolean" + }, + "clear_ram": { + "default": false, + "description": "Whether to clear the RAM after dumping the modules", + "title": "Clear Ram", + "type": "boolean" + }, + "report_to": { + "anyOf": [ + { + "items": { + "enum": [ + "wandb", + "tensorboard" + ], + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "List of callbacks to report to. If None, no callbacks will be used", + "title": "Report To" + } + }, + "title": "LoggingConfig", + "type": "object" + }, "LogregAimedEmbeddingInitModel": { "properties": { "module_name": { @@ -927,6 +1043,27 @@ "title": "NodeType", "type": "string" }, + "OptimizationSearchSpaceConfig": { + "description": "Optimizer configuration.", + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbeddingNodeValidator" + }, + { + "$ref": "#/$defs/ScoringNodeValidator" + }, + { + "$ref": "#/$defs/DecisionNodeValidator" + }, + { + "$ref": "#/$defs/RegexNodeValidator" + } + ] + }, + "title": "OptimizationSearchSpaceConfig", + "type": "array" + }, "ParamSpaceFloat": { "properties": { "low": { @@ -1435,6 +1572,29 @@ "title": "SklearnScorerInitModel", "type": "object" }, + "TaskConfig": { + "description": "Configuration for the task to optimize.", + "properties": { + "search_space": { + "$ref": "#/$defs/OptimizationSearchSpaceConfig" + }, + "sampler": { + "default": "brute", + "enum": [ + "brute", + "tpe", + "random" + ], + "title": "Sampler", + 
"type": "string" + } + }, + "required": [ + "search_space" + ], + "title": "TaskConfig", + "type": "object" + }, "ThresholdDecisionInitModel": { "properties": { "module_name": { @@ -1556,25 +1716,61 @@ ], "title": "TunableDecisionInitModel", "type": "object" + }, + "VectorIndexConfig": { + "description": "Configuration for the vector index.", + "properties": { + "save_db": { + "default": false, + "description": "Whether to save the vector index database or not", + "title": "Save Db", + "type": "boolean" + } + }, + "title": "VectorIndexConfig", + "type": "object" } }, - "description": "Optimizer configuration.", - "items": { - "anyOf": [ - { - "$ref": "#/$defs/EmbeddingNodeValidator" - }, - { - "$ref": "#/$defs/ScoringNodeValidator" - }, - { - "$ref": "#/$defs/DecisionNodeValidator" - }, - { - "$ref": "#/$defs/RegexNodeValidator" + "description": "Configuration for the optimization process.", + "properties": { + "data_config": { + "$ref": "#/$defs/DataConfig", + "default": { + "scheme": "ho", + "n_folds": 3, + "validation_size": 0.2, + "separation_ratio": 0.5 } - ] + }, + "task_config": { + "$ref": "#/$defs/TaskConfig" + }, + "logging_config": { + "$ref": "#/$defs/LoggingConfig", + "default": { + "project_dir": null, + "run_name": null, + "dump_modules": false, + "clear_ram": false, + "report_to": null + } + }, + "vector_index_config": { + "$ref": "#/$defs/VectorIndexConfig", + "default": { + "save_db": false + } + }, + "seed": { + "default": 42, + "exclusiveMinimum": 0, + "title": "Seed", + "type": "integer" + } }, + "required": [ + "task_config" + ], "title": "OptimizationConfig", - "type": "array" + "type": "object" } \ No newline at end of file diff --git a/docs/optimizer_search_space_config.schema.json b/docs/optimizer_search_space_config.schema.json new file mode 100644 index 000000000..c0409c1fc --- /dev/null +++ b/docs/optimizer_search_space_config.schema.json @@ -0,0 +1,1580 @@ +{ + "$defs": { + "AdaptiveDecisionInitModel": { + "properties": { + 
"module_name": { + "const": "adaptive", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "search_space": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "items": { + "maximum": 1.0, + "minimum": 0.0, + "type": "number" + }, + "type": "array" + }, + { + "type": "null" + } + ] + }, + "title": "Search Space", + "type": "array" + } + }, + "required": [ + "module_name" + ], + "title": "AdaptiveDecisionInitModel", + "type": "object" + }, + "ArgmaxDecisionInitModel": { + "properties": { + "module_name": { + "const": "argmax", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + } + }, + "required": [ + "module_name" + ], + "title": "ArgmaxDecisionInitModel", + "type": "object" + }, + "CrossEncoderConfig": { + "properties": { + "batch_size": { + "default": 32, + "description": "Batch size for model inference.", + "exclusiveMinimum": 0, + "title": "Batch Size", + "type": "integer" + }, + "max_length": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Maximum length of input sequences.", + "title": "Max Length" + }, + "model_name": { + "description": "Name of the hugging face model.", + "title": "Model Name", + "type": "string" + }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Torch notation for CPU or CUDA.", + "title": "Device" + }, + "train_head": { + "default": false, + "description": "Whether to train the head of the model. 
If False, LogReg will be trained.", + "title": "Train Head", + "type": "boolean" + } + }, + "required": [ + "model_name" + ], + "title": "CrossEncoderConfig", + "type": "object" + }, + "DNNCScorerInitModel": { + "properties": { + "module_name": { + "const": "dnnc", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "cross_encoder_config": { + "items": { + "anyOf": [ + { + "$ref": "#/$defs/CrossEncoderConfig" + }, + { + "type": "string" + } + ] + }, + "title": "Cross Encoder Config", + "type": "array" + }, + "k": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0, + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "title": "K" + }, + "embedder_config": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "title": "Embedder Config", + "type": "array" + } + }, + "required": [ + "module_name", + "cross_encoder_config", + "k" + ], + "title": "DNNCScorerInitModel", + "type": "object" + }, + "DecisionNodeValidator": { + "description": "Search space configuration for the Decision node.", + "properties": { + "node_type": { + "$ref": "#/$defs/NodeType", + "default": "decision" + }, + "target_metric": { + "enum": [ + "decision_accuracy", + "decision_f1", + "decision_precision", + "decision_recall", + "decision_roc_auc" + ], + "title": "Target Metric", + "type": "string" + }, + "metrics": { + "anyOf": [ + { + "items": { + "enum": [ + "decision_accuracy", + "decision_f1", + "decision_precision", + "decision_recall", + "decision_roc_auc" + ], + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Metrics" + }, + "search_space": { + "items": { + "anyOf": [ + { + "$ref": 
"#/$defs/ArgmaxDecisionInitModel" + }, + { + "$ref": "#/$defs/JinoosDecisionInitModel" + }, + { + "$ref": "#/$defs/ThresholdDecisionInitModel" + }, + { + "$ref": "#/$defs/TunableDecisionInitModel" + }, + { + "$ref": "#/$defs/AdaptiveDecisionInitModel" + } + ] + }, + "title": "Search Space", + "type": "array" + } + }, + "required": [ + "target_metric", + "search_space" + ], + "title": "DecisionNodeValidator", + "type": "object" + }, + "DescriptionScorerInitModel": { + "properties": { + "module_name": { + "const": "description", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "temperature": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0.0, + "type": "number" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceFloat" + } + ], + "title": "Temperature" + }, + "embedder_config": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "title": "Embedder Config", + "type": "array" + } + }, + "required": [ + "module_name", + "temperature" + ], + "title": "DescriptionScorerInitModel", + "type": "object" + }, + "EmbedderConfig": { + "properties": { + "batch_size": { + "default": 32, + "description": "Batch size for model inference.", + "exclusiveMinimum": 0, + "title": "Batch Size", + "type": "integer" + }, + "max_length": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Maximum length of input sequences.", + "title": "Max Length" + }, + "model_name": { + "description": "Name of the hugging face model.", + "title": "Model Name", + "type": "string" + }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": 
"Torch notation for CPU or CUDA.", + "title": "Device" + }, + "default_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Default prompt for the model. This is used when no task specific prompt is not provided.", + "title": "Default Prompt" + }, + "classifier_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for classifier.", + "title": "Classifier Prompt" + }, + "cluster_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for clustering.", + "title": "Cluster Prompt" + }, + "sts_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for finding most similar sentences.", + "title": "Sts Prompt" + }, + "query_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for query.", + "title": "Query Prompt" + }, + "passage_prompt": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Prompt for passage.", + "title": "Passage Prompt" + }, + "use_cache": { + "default": false, + "description": "Whether to use embeddings caching.", + "title": "Use Cache", + "type": "boolean" + } + }, + "required": [ + "model_name" + ], + "title": "EmbedderConfig", + "type": "object" + }, + "EmbeddingNodeValidator": { + "description": "Search space configuration for the Embedding node.", + "properties": { + "node_type": { + "$ref": "#/$defs/NodeType", + "default": "embedding" + }, + "target_metric": { + "enum": [ + "retrieval_hit_rate", + "retrieval_map", + "retrieval_mrr", + "retrieval_ndcg", + "retrieval_precision", + "retrieval_hit_rate_intersecting", + "retrieval_hit_rate_macro", + "retrieval_map_intersecting", + "retrieval_map_macro", + "retrieval_mrr_intersecting", + 
"retrieval_mrr_macro", + "retrieval_ndcg_intersecting", + "retrieval_ndcg_macro", + "retrieval_precision_intersecting", + "retrieval_precision_macro", + "scoring_accuracy", + "scoring_f1", + "scoring_log_likelihood", + "scoring_precision", + "scoring_recall", + "scoring_roc_auc", + "scoring_hit_rate", + "scoring_map", + "scoring_neg_coverage", + "scoring_neg_ranking_loss" + ], + "title": "Target Metric", + "type": "string" + }, + "metrics": { + "anyOf": [ + { + "items": { + "enum": [ + "retrieval_hit_rate", + "retrieval_map", + "retrieval_mrr", + "retrieval_ndcg", + "retrieval_precision", + "retrieval_hit_rate_intersecting", + "retrieval_hit_rate_macro", + "retrieval_map_intersecting", + "retrieval_map_macro", + "retrieval_mrr_intersecting", + "retrieval_mrr_macro", + "retrieval_ndcg_intersecting", + "retrieval_ndcg_macro", + "retrieval_precision_intersecting", + "retrieval_precision_macro", + "scoring_accuracy", + "scoring_f1", + "scoring_log_likelihood", + "scoring_precision", + "scoring_recall", + "scoring_roc_auc", + "scoring_hit_rate", + "scoring_map", + "scoring_neg_coverage", + "scoring_neg_ranking_loss" + ], + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Metrics" + }, + "search_space": { + "items": { + "anyOf": [ + { + "$ref": "#/$defs/RetrievalAimedEmbeddingInitModel" + }, + { + "$ref": "#/$defs/LogregAimedEmbeddingInitModel" + } + ] + }, + "title": "Search Space", + "type": "array" + } + }, + "required": [ + "target_metric", + "search_space" + ], + "title": "EmbeddingNodeValidator", + "type": "object" + }, + "JinoosDecisionInitModel": { + "properties": { + "module_name": { + "const": "jinoos", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "search_space": { + "default": [ + null + ], + "items": { + 
"anyOf": [ + { + "items": { + "maximum": 1.0, + "minimum": 0.0, + "type": "number" + }, + "type": "array" + }, + { + "type": "null" + } + ] + }, + "title": "Search Space", + "type": "array" + } + }, + "required": [ + "module_name" + ], + "title": "JinoosDecisionInitModel", + "type": "object" + }, + "KNNScorerInitModel": { + "properties": { + "module_name": { + "const": "knn", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "k": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0, + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "title": "K" + }, + "weights": { + "items": { + "enum": [ + "uniform", + "distance", + "closest" + ], + "type": "string" + }, + "title": "Weights", + "type": "array" + }, + "embedder_config": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "title": "Embedder Config", + "type": "array" + } + }, + "required": [ + "module_name", + "k", + "weights" + ], + "title": "KNNScorerInitModel", + "type": "object" + }, + "LinearScorerInitModel": { + "properties": { + "module_name": { + "const": "linear", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "embedder_config": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "title": "Embedder Config", + "type": "array" + } + }, + "required": [ + "module_name" + ], + "title": "LinearScorerInitModel", + "type": "object" + }, + 
"LogregAimedEmbeddingInitModel": { + "properties": { + "module_name": { + "const": "logreg_embedding", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "embedder_config": { + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + } + ] + }, + "title": "Embedder Config", + "type": "array" + }, + "cv": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0, + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "default": [ + 3 + ], + "title": "Cv" + } + }, + "required": [ + "module_name", + "embedder_config" + ], + "title": "LogregAimedEmbeddingInitModel", + "type": "object" + }, + "MLKnnScorerInitModel": { + "properties": { + "module_name": { + "const": "mlknn", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "k": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0, + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "title": "K" + }, + "s": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0.0, + "type": "number" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceFloat" + } + ], + "default": [ + 1.0 + ], + "title": "S" + }, + "ignore_first_neighbours": { + "anyOf": [ + { + "items": { + "minimum": 0, + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "default": [ + 0 + ], + "title": "Ignore First Neighbours" + }, + "embedder_config": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + }, + { + "type": "null" + } + ] 
+ }, + "title": "Embedder Config", + "type": "array" + } + }, + "required": [ + "module_name", + "k" + ], + "title": "MLKnnScorerInitModel", + "type": "object" + }, + "NodeType": { + "description": "Enumeration of node types in the AutoIntent pipeline.", + "enum": [ + "regex", + "embedding", + "scoring", + "decision" + ], + "title": "NodeType", + "type": "string" + }, + "ParamSpaceFloat": { + "properties": { + "low": { + "description": "Low boundary of the search space.", + "title": "Low", + "type": "number" + }, + "high": { + "description": "High boundary of the search space.", + "title": "High", + "type": "number" + }, + "step": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Step of the search space.", + "title": "Step" + }, + "log": { + "default": false, + "description": "Whether to use a logarithmic scale.", + "title": "Log", + "type": "boolean" + } + }, + "required": [ + "low", + "high" + ], + "title": "ParamSpaceFloat", + "type": "object" + }, + "ParamSpaceInt": { + "properties": { + "low": { + "description": "Low boundary of the search space.", + "title": "Low", + "type": "integer" + }, + "high": { + "description": "High boundary of the search space.", + "title": "High", + "type": "integer" + }, + "step": { + "default": 1, + "description": "Step of the search space.", + "title": "Step", + "type": "integer" + }, + "log": { + "default": false, + "description": "Whether to use a logarithmic scale.", + "title": "Log", + "type": "boolean" + } + }, + "required": [ + "low", + "high" + ], + "title": "ParamSpaceInt", + "type": "object" + }, + "RegexInitModel": { + "properties": { + "module_name": { + "const": "regex", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + } + }, + "required": [ + "module_name" + ], + 
"title": "RegexInitModel", + "type": "object" + }, + "RegexNodeValidator": { + "description": "Search space configuration for the Regexp node.", + "properties": { + "node_type": { + "$ref": "#/$defs/NodeType", + "default": "regex" + }, + "target_metric": { + "enum": [ + "regex_partial_accuracy", + "regex_partial_precision" + ], + "title": "Target Metric", + "type": "string" + }, + "metrics": { + "anyOf": [ + { + "items": { + "enum": [ + "regex_partial_accuracy", + "regex_partial_precision" + ], + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Metrics" + }, + "search_space": { + "items": { + "$ref": "#/$defs/RegexInitModel" + }, + "title": "Search Space", + "type": "array" + } + }, + "required": [ + "target_metric", + "search_space" + ], + "title": "RegexNodeValidator", + "type": "object" + }, + "RerankScorerInitModel": { + "properties": { + "module_name": { + "const": "rerank", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "k": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "title": "K" + }, + "weights": { + "items": { + "enum": [ + "uniform", + "distance", + "closest" + ], + "type": "string" + }, + "title": "Weights", + "type": "array" + }, + "cross_encoder_config": { + "items": { + "anyOf": [ + { + "$ref": "#/$defs/CrossEncoderConfig" + }, + { + "type": "string" + } + ] + }, + "title": "Cross Encoder Config", + "type": "array" + }, + "embedder_config": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "title": "Embedder Config", + "type": "array" + }, + "m": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "type": 
"integer" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "default": [ + null + ], + "title": "M" + }, + "rank_threshold_cutoff": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ] + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "default": [ + null + ], + "title": "Rank Threshold Cutoff" + } + }, + "required": [ + "module_name", + "k", + "weights", + "cross_encoder_config" + ], + "title": "RerankScorerInitModel", + "type": "object" + }, + "RetrievalAimedEmbeddingInitModel": { + "properties": { + "module_name": { + "const": "retrieval", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "k": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0, + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "title": "K" + }, + "embedder_config": { + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + } + ] + }, + "title": "Embedder Config", + "type": "array" + } + }, + "required": [ + "module_name", + "k", + "embedder_config" + ], + "title": "RetrievalAimedEmbeddingInitModel", + "type": "object" + }, + "ScoringNodeValidator": { + "description": "Search space configuration for the Scoring node.", + "properties": { + "node_type": { + "$ref": "#/$defs/NodeType", + "default": "scoring" + }, + "target_metric": { + "enum": [ + "scoring_accuracy", + "scoring_f1", + "scoring_log_likelihood", + "scoring_precision", + "scoring_recall", + "scoring_roc_auc", + "scoring_hit_rate", + "scoring_map", + "scoring_neg_coverage", + "scoring_neg_ranking_loss" + ], + "title": "Target Metric", + "type": "string" + }, + "metrics": { + "anyOf": [ + { + "items": { + "enum": [ + 
"scoring_accuracy", + "scoring_f1", + "scoring_log_likelihood", + "scoring_precision", + "scoring_recall", + "scoring_roc_auc", + "scoring_hit_rate", + "scoring_map", + "scoring_neg_coverage", + "scoring_neg_ranking_loss" + ], + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Metrics" + }, + "search_space": { + "items": { + "anyOf": [ + { + "$ref": "#/$defs/DNNCScorerInitModel" + }, + { + "$ref": "#/$defs/KNNScorerInitModel" + }, + { + "$ref": "#/$defs/LinearScorerInitModel" + }, + { + "$ref": "#/$defs/DescriptionScorerInitModel" + }, + { + "$ref": "#/$defs/RerankScorerInitModel" + }, + { + "$ref": "#/$defs/SklearnScorerInitModel" + }, + { + "$ref": "#/$defs/MLKnnScorerInitModel" + } + ] + }, + "title": "Search Space", + "type": "array" + } + }, + "required": [ + "target_metric", + "search_space" + ], + "title": "ScoringNodeValidator", + "type": "object" + }, + "SklearnScorerInitModel": { + "properties": { + "module_name": { + "const": "sklearn", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "clf_name": { + "default": [ + "LogisticRegression" + ], + "items": { + "type": "string" + }, + "title": "Clf Name", + "type": "array" + }, + "clf_args": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "type": "object" + }, + { + "type": "null" + } + ] + }, + "title": "Clf Args", + "type": "array" + }, + "embedder_config": { + "default": [ + null + ], + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbedderConfig" + }, + { + "type": "string" + }, + { + "type": "null" + } + ] + }, + "title": "Embedder Config", + "type": "array" + } + }, + "required": [ + "module_name" + ], + "title": "SklearnScorerInitModel", + "type": "object" + }, + "ThresholdDecisionInitModel": { + "properties": { + "module_name": { + 
"const": "threshold", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "thresh": { + "anyOf": [ + { + "items": { + "anyOf": [ + { + "maximum": 1.0, + "minimum": 0.0, + "type": "number" + }, + { + "items": { + "maximum": 1.0, + "minimum": 0.0, + "type": "number" + }, + "type": "array" + } + ] + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceFloat" + } + ], + "default": [ + 0.5 + ], + "title": "Thresh" + } + }, + "required": [ + "module_name" + ], + "title": "ThresholdDecisionInitModel", + "type": "object" + }, + "TunableDecisionInitModel": { + "properties": { + "module_name": { + "const": "tunable", + "title": "Module Name", + "type": "string" + }, + "n_trials": { + "anyOf": [ + { + "exclusiveMinimum": 0, + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of trials", + "title": "N Trials" + }, + "target_metric": { + "default": [ + "decision_accuracy" + ], + "items": { + "enum": [ + "decision_accuracy", + "decision_f1", + "decision_roc_auc", + "decision_precision", + "decision_recall" + ], + "type": "string" + }, + "title": "Target Metric", + "type": "array" + }, + "n_optuna_trials": { + "anyOf": [ + { + "items": { + "exclusiveMinimum": 0, + "type": "integer" + }, + "type": "array" + }, + { + "$ref": "#/$defs/ParamSpaceInt" + } + ], + "default": [ + 320 + ], + "title": "N Optuna Trials" + } + }, + "required": [ + "module_name" + ], + "title": "TunableDecisionInitModel", + "type": "object" + } + }, + "description": "Optimizer configuration.", + "items": { + "anyOf": [ + { + "$ref": "#/$defs/EmbeddingNodeValidator" + }, + { + "$ref": "#/$defs/ScoringNodeValidator" + }, + { + "$ref": "#/$defs/DecisionNodeValidator" + }, + { + "$ref": "#/$defs/RegexNodeValidator" + } + ] + }, + "title": 
"OptimizationSearchSpaceConfig", + "type": "array" +} \ No newline at end of file diff --git a/scripts/generate_json_schema_config.py b/scripts/generate_json_schema_config.py index a80f8cf64..16e5035d6 100644 --- a/scripts/generate_json_schema_config.py +++ b/scripts/generate_json_schema_config.py @@ -1,10 +1,18 @@ import json from pathlib import Path -from autointent.nodes.schemes import OptimizationConfig +from autointent.nodes.schemes import OptimizationConfig, OptimizationSearchSpaceConfig -def generate_json_schema() -> None: +def generate_json_schema_search_space_config() -> None: + """Generate the JSON schema for the optimizer search space config.""" + schema = OptimizationSearchSpaceConfig.model_json_schema() + path = Path(__file__).parent.parent / "docs" / "optimizer_search_space_config.schema.json" + with path.open("w") as f: + json.dump(schema, f, indent=4) + + +def generate_json_schema_optimizer_config() -> None: """Generate the JSON schema for the optimizer config.""" schema = OptimizationConfig.model_json_schema() path = Path(__file__).parent.parent / "docs" / "optimizer_config.schema.json" @@ -13,4 +21,5 @@ def generate_json_schema() -> None: if __name__ == "__main__": - generate_json_schema() + generate_json_schema_search_space_config() + generate_json_schema_optimizer_config() diff --git a/tests/assets/configs/full_training.yaml b/tests/assets/configs/full_training.yaml new file mode 100644 index 000000000..ba15a9196 --- /dev/null +++ b/tests/assets/configs/full_training.yaml @@ -0,0 +1,28 @@ +task_config: + search_space: + - node_type: embedding + target_metric: retrieval_hit_rate + search_space: + - module_name: retrieval + k: [10] + embedder_config: + - model_name: sentence-transformers/all-MiniLM-L6-v2 + - node_type: scoring + target_metric: scoring_roc_auc + search_space: + - module_name: linear + - node_type: decision + target_metric: decision_accuracy + search_space: + - module_name: argmax + sampler: brute +data_config: + scheme: ho + n_folds: 3 + 
validation_size: 0.2 + separation_ratio: null +logging_config: + run_name: full_training +vector_index_config: + save_db: false +seed: 42 \ No newline at end of file diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py index 6cb03aa8f..4bf9a8b95 100644 --- a/tests/callback/test_callback.py +++ b/tests/callback/test_callback.py @@ -88,7 +88,7 @@ def test_pipeline_callbacks(dataset): context.callback_handler = CallbackHandler([DummyCallback]) context.set_dataset(dataset, DataConfig(scheme="ho", separate_nodes=True)) - pipeline_optimizer._fit(context) + pipeline_optimizer._fit(context, "brute") dummy_callback = context.callback_handler.callbacks[0] diff --git a/tests/configs/test_combined_config.py b/tests/configs/test_combined_config.py index 7a243576f..bef606f48 100644 --- a/tests/configs/test_combined_config.py +++ b/tests/configs/test_combined_config.py @@ -2,7 +2,7 @@ from pydantic import ValidationError from autointent.nodes.schemes import ( - OptimizationConfig, + OptimizationSearchSpaceConfig, ) from tests.conftest import get_search_space @@ -44,7 +44,7 @@ def valid_optimizer_config(): def test_valid_optimizer_config(valid_optimizer_config): """Test that a valid optimizer config passes validation.""" - config = OptimizationConfig(valid_optimizer_config) + config = OptimizationSearchSpaceConfig(valid_optimizer_config) assert config[0].node_type == "scoring" assert config[1].node_type == "embedding" @@ -55,7 +55,7 @@ def test_valid_optimizer_config(valid_optimizer_config): ) def test_optimizer_config(task_type): search_space = get_search_space(task_type) - config = OptimizationConfig(search_space) + config = OptimizationSearchSpaceConfig(search_space) assert config @@ -72,7 +72,7 @@ def test_invalid_optimizer_config_missing_field(): ] with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) def test_invalid_optimizer_config_wrong_type(): @@ -93,4 +93,4 @@ def 
test_invalid_optimizer_config_wrong_type(): ] with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) diff --git a/tests/configs/test_decision.py b/tests/configs/test_decision.py index b7a580cdb..79a60fb33 100644 --- a/tests/configs/test_decision.py +++ b/tests/configs/test_decision.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from autointent.nodes.schemes import OptimizationConfig +from autointent.nodes.schemes import OptimizationSearchSpaceConfig @pytest.fixture @@ -27,7 +27,7 @@ def valid_decision_config(): def test_valid_decision_config(valid_decision_config): """Test that a valid decision config passes validation.""" - config = OptimizationConfig(valid_decision_config) + config = OptimizationSearchSpaceConfig(valid_decision_config) assert config[0].node_type == "decision" assert config[0].target_metric == "decision_roc_auc" assert isinstance(config[0].search_space, list) @@ -45,7 +45,7 @@ def test_invalid_decision_config_missing_field(): ] with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) def test_invalid_decision_config_wrong_type(): @@ -68,4 +68,4 @@ def test_invalid_decision_config_wrong_type(): ] with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) diff --git a/tests/configs/test_embedding.py b/tests/configs/test_embedding.py index fdf090348..868ba0901 100644 --- a/tests/configs/test_embedding.py +++ b/tests/configs/test_embedding.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from autointent.nodes import OptimizationConfig +from autointent.nodes import OptimizationSearchSpaceConfig @pytest.fixture @@ -25,7 +25,7 @@ def valid_embedding_config(): def test_valid_embedding_config(valid_embedding_config): """Test that a valid embedding config passes validation.""" - config = 
OptimizationConfig(valid_embedding_config) + config = OptimizationSearchSpaceConfig(valid_embedding_config) assert config[0].node_type == "embedding" assert config[0].target_metric == "retrieval_mrr" assert isinstance(config[0].search_space, list) @@ -50,7 +50,7 @@ def test_invalid_embedding_config_missing_field(): ] with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) def test_invalid_embedding_config_wrong_type(): @@ -70,4 +70,4 @@ def test_invalid_embedding_config_wrong_type(): ] with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) diff --git a/tests/configs/test_full_config.py b/tests/configs/test_full_config.py new file mode 100644 index 000000000..1398aa3de --- /dev/null +++ b/tests/configs/test_full_config.py @@ -0,0 +1,19 @@ +import pytest +from pydantic import ValidationError + +from autointent.nodes.schemes import OptimizationConfig +from tests.conftest import get_search_space + + +def test_validate_full_config(): + config = get_search_space("full_training") + validated_config = OptimizationConfig(**config) + assert isinstance(validated_config, OptimizationConfig) + + +def test_not_valid_reporting(): + config = get_search_space("full_training") + config["logging_config"]["report_to"] = "test" + + with pytest.raises(ValidationError): + OptimizationConfig(**config) diff --git a/tests/configs/test_scoring.py b/tests/configs/test_scoring.py index 9f9c4e440..9877d3add 100644 --- a/tests/configs/test_scoring.py +++ b/tests/configs/test_scoring.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from autointent.nodes import OptimizationConfig +from autointent.nodes import OptimizationSearchSpaceConfig @pytest.fixture @@ -59,7 +59,7 @@ def valid_scoring_config(): def test_valid_scoring_config(valid_scoring_config): """Test that a valid scoring config passes validation.""" - config = 
OptimizationConfig(valid_scoring_config) + config = OptimizationSearchSpaceConfig(valid_scoring_config) assert config[0].node_type == "scoring" assert config[0].target_metric == "scoring_roc_auc" assert isinstance(config[0].search_space, list) @@ -77,7 +77,7 @@ def test_invalid_scoring_config_missing_field(): } with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) def test_invalid_scoring_config_wrong_type(): @@ -96,4 +96,4 @@ def test_invalid_scoring_config_wrong_type(): } with pytest.raises(ValidationError): - OptimizationConfig(invalid_config) + OptimizationSearchSpaceConfig(invalid_config) diff --git a/tests/conftest.py b/tests/conftest.py index fb1ed3a4a..729123f1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,7 +33,7 @@ def dataset_no_oos(): return Dataset.from_json(path) -TaskType = Literal["multiclass", "multilabel", "description", "optuna", "light"] +TaskType = Literal["multiclass", "multilabel", "description", "optuna", "light", "full_training"] def get_search_space_path(task_type: TaskType): diff --git a/tests/pipeline/test_optimization.py b/tests/pipeline/test_optimization.py index 310db986b..4ae6d7e3b 100644 --- a/tests/pipeline/test_optimization.py +++ b/tests/pipeline/test_optimization.py @@ -20,6 +20,12 @@ def test_no_node_separation(dataset_no_oos): pipeline_optimizer.fit(dataset_no_oos, refit_after=False) +def test_full_config(dataset_no_oos): + search_space = get_search_space("full_training") + pipeline_optimizer = Pipeline.from_optimization_config(search_space) + pipeline_optimizer.fit(dataset_no_oos, refit_after=False) + + @pytest.mark.parametrize( "sampler", ["tpe", "random"], diff --git a/tests/test_utils.py b/tests/test_utils.py index 106def342..08f526f21 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,10 @@ import pytest -from autointent.nodes import OptimizationConfig +from autointent.nodes import OptimizationSearchSpaceConfig from 
autointent.utils import load_default_search_space @pytest.mark.parametrize("multilabel", [True, False]) def test_load_default_configs(multilabel): search_space = load_default_search_space(multilabel=multilabel) - OptimizationConfig(search_space).model_dump() + OptimizationSearchSpaceConfig(search_space).model_dump()