diff --git a/autointent/modules/__init__.py b/autointent/modules/__init__.py
index 212d886b1..b8ebdf3da 100644
--- a/autointent/modules/__init__.py
+++ b/autointent/modules/__init__.py
@@ -54,4 +54,25 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
     )
 
 
-__all__ = []  # type: ignore[var-annotated]
+__all__ = [
+    "AdaptiveDecision",
+    "ArgmaxDecision",
+    "BaseDecision",
+    "BaseEmbedding",
+    "BaseModule",
+    "BaseRegex",
+    "BaseScorer",
+    "DNNCScorer",
+    "DescriptionScorer",
+    "JinoosDecision",
+    "KNNScorer",
+    "LinearScorer",
+    "LogregAimedEmbedding",
+    "MLKnnScorer",
+    "RerankScorer",
+    "RetrievalAimedEmbedding",
+    "SimpleRegex",
+    "SklearnScorer",
+    "ThresholdDecision",
+    "TunableDecision",
+]
diff --git a/autointent/modules/scoring/_bert.py b/autointent/modules/scoring/_bert.py
index 5fd075ebc..9cbe50c40 100644
--- a/autointent/modules/scoring/_bert.py
+++ b/autointent/modules/scoring/_bert.py
@@ -31,14 +31,14 @@ class BertScorer(BaseScorer):
 
     def __init__(
         self,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
     ) -> None:
-        self.model_config = HFModelConfig.from_search_config(model_config)
+        self.classification_model_config = HFModelConfig.from_search_config(classification_model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
@@ -49,19 +49,19 @@ def __init__(
     def from_context(
         cls,
         context: Context,
-        model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
     ) -> "BertScorer":
-        if model_config is None:
-            model_config = context.resolve_embedder()
+        if classification_model_config is None:
+            classification_model_config = context.resolve_embedder()
 
         report_to = context.logging_config.report_to
 
         return cls(
-            model_config=model_config,
+            classification_model_config=classification_model_config,
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
             learning_rate=learning_rate,
@@ -70,7 +70,7 @@ def from_context(
         )
 
     def get_embedder_config(self) -> dict[str, Any]:
-        return self.model_config.model_dump()
+        return self.classification_model_config.model_dump()
 
     def fit(
         self,
@@ -81,7 +81,7 @@ def fit(
         self.clear_cache()
         self._validate_task(labels)
 
-        model_name = self.model_config.model_name
+        model_name = self.classification_model_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
 
         label2id = {i: i for i in range(self._n_classes)}
@@ -95,11 +95,11 @@ def fit(
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
         )
 
-        use_cpu = self.model_config.device == "cpu"
+        use_cpu = self.classification_model_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.model_config.tokenizer_config.model_dump()
+                examples["text"], return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump()
             )
 
         dataset = Dataset.from_dict({"text": utterances, "labels": labels})
@@ -148,7 +148,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = self._tokenizer(
+                batch, return_tensors="pt", **self.classification_model_config.tokenizer_config.model_dump()
+            )
             inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
diff --git a/autointent/modules/scoring/_linear.py b/autointent/modules/scoring/_linear.py
index 06e04c4dd..be74ada89 100644
--- a/autointent/modules/scoring/_linear.py
+++ b/autointent/modules/scoring/_linear.py
@@ -4,6 +4,7 @@
 
 import numpy as np
 import numpy.typing as npt
+from pydantic import PositiveInt
 from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
 from sklearn.multioutput import MultiOutputClassifier
 
@@ -22,7 +23,6 @@ class LinearScorer(BaseScorer):
     Args:
         embedder_config: Config of the embedder model
         cv: Number of cross-validation folds, defaults to 3
-        n_jobs: Number of parallel jobs for cross-validation, defaults to None
         seed: Random seed for reproducibility, defaults to 0
 
     Example:
@@ -72,18 +72,21 @@ def __init__(
     def from_context(
         cls,
         context: Context,
+        cv: PositiveInt = 3,
         embedder_config: EmbedderConfig | str | None = None,
     ) -> "LinearScorer":
         """Create a LinearScorer instance using a Context object.
 
         Args:
             context: Context containing configurations and utilities
+            cv: Number of cross-validation folds, defaults to 3
             embedder_config: Config of the embedder, or None to use the best embedder
         """
         if embedder_config is None:
             embedder_config = context.resolve_embedder()
 
         return cls(
+            cv=cv,
             embedder_config=embedder_config,
         )
 
diff --git a/autointent/nodes/_node_optimizer.py b/autointent/nodes/_node_optimizer.py
index 9ff6caaec..8d1d4872f 100644
--- a/autointent/nodes/_node_optimizer.py
+++ b/autointent/nodes/_node_optimizer.py
@@ -11,7 +11,6 @@
 import optuna
 import torch
 from optuna.trial import Trial
-from pydantic import BaseModel, Field
 from typing_extensions import assert_never
 
 from autointent import Dataset
@@ -19,25 +18,7 @@
 from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode
 from autointent.nodes.emissions_tracker import EmissionsTracker
 from autointent.nodes.info import NODES_INFO
-
-
-class ParamSpaceInt(BaseModel):
-    """Integer parameter search space configuration."""
-
-    low: int = Field(..., description="Lower boundary of the search space.")
-    high: int = Field(..., description="Upper boundary of the search space.")
-    step: int = Field(1, description="Step size for the search space.")
-    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
-
-
-class ParamSpaceFloat(BaseModel):
-    """Float parameter search space configuration."""
-
-    low: float = Field(..., description="Lower boundary of the search space.")
-    high: float = Field(..., description="Upper boundary of the search space.")
-    step: float | None = Field(None, description="Step size for the search space (if applicable).")
-    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
-
+from autointent.schemas.node_validation import ParamSpaceFloat, ParamSpaceInt, SearchSpaceConfig
 
 logger = logging.getLogger(__name__)
 
@@ -277,7 +258,8 @@ def validate_nodes_with_dataset(self, dataset: Dataset, mode: SearchSpaceValidat
 
     def validate_search_space(self, search_space: list[dict[str, Any]]) -> None:
         """Check if search space is configured correctly."""
-        for module_search_space in search_space:
+        validated_search_space = SearchSpaceConfig(search_space).model_dump()
+        for module_search_space in validated_search_space:
             module_search_space_no_optuna, module_name = self._reformat_search_space(deepcopy(module_search_space))
 
             for params_combination in it.product(*module_search_space_no_optuna.values()):
diff --git a/autointent/schemas/node_validation.py b/autointent/schemas/node_validation.py
new file mode 100644
index 000000000..ca118ecf6
--- /dev/null
+++ b/autointent/schemas/node_validation.py
@@ -0,0 +1,366 @@
+"""Schemas for validating the optimization search space."""
+
+import inspect
+from collections.abc import Iterator
+from typing import Annotated, Any, Literal, TypeAlias, Union, get_args, get_origin, get_type_hints
+
+from pydantic import BaseModel, ConfigDict, Field, PositiveInt, RootModel, ValidationError, model_validator
+
+from autointent.custom_types import NodeType
+from autointent.modules import BaseModule
+from autointent.nodes.info import DecisionNodeInfo, EmbeddingNodeInfo, RegexNodeInfo, ScoringNodeInfo
+
+
+class ParamSpaceInt(BaseModel):
+    """Integer parameter search space configuration."""
+
+    low: int = Field(..., description="Lower boundary of the search space.")
+    high: int = Field(..., description="Upper boundary of the search space.")
+    step: int = Field(1, description="Step size for the search space.")
+    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
+
+
+class ParamSpaceFloat(BaseModel):
+    """Float parameter search space configuration."""
+
+    low: float = Field(..., description="Lower boundary of the search space.")
+    high: float = Field(..., description="Upper boundary of the search space.")
+    step: float | None = Field(None, description="Step size for the search space (if applicable).")
+    log: bool = Field(False, description="Indicates whether to use a logarithmic scale.")
+
+
+def unwrap_annotated(tp: type) -> type:
+    """Unwrap the Annotated type to get the actual type.
+
+    :param tp: Type to unwrap
+    :return: Unwrapped type
+    """
+    # Check if the type is an Annotated type using get_origin
+    # Annotated[int, "some metadata"] would have origin as Annotated
+    # If it is Annotated, extract the first argument which is the actual type
+    # Otherwise return the original type unchanged
+    return get_args(tp)[0] if get_origin(tp) is Annotated else tp
+
+
+def type_matches(target: type, tp: type) -> bool:
+    """Recursively check if the target type is present in the given type.
+
+    This function handles union types by unwrapping Annotated types where necessary.
+
+    :param target: Target type
+    :param tp: Given type
+    :return: If the target type is present in the given type
+    """
+    # Get the origin of the type to determine if it's a generic type
+    # For example, Union, List, Dict, etc.
+    origin = get_origin(tp)
+
+    # If the type is a Union (e.g., int | str or Union[int, str])
+    if origin is Union:
+        # Check if any of the union's arguments match the target type
+        # Recursively call type_matches for each argument in the union
+        return any(type_matches(target, arg) for arg in get_args(tp))
+
+    # For non-Union types, unwrap any Annotated wrapper and compare with the target type
+    # This handles cases like Annotated[int, "some description"] matching with int
+    return unwrap_annotated(tp) is target
+
+
+def get_optuna_class(param_type: type) -> type[ParamSpaceInt | ParamSpaceFloat] | None:
+    """Get the Optuna class for the given parameter type.
+
+    If the (possibly annotated or union) type includes int or float, this function
+    returns the corresponding search space class.
+
+    :param param_type: Parameter type (could be a union, annotated type, or container)
+    :return: ParamSpaceInt if the type matches int, ParamSpaceFloat if it matches float, else None.
+    """
+    # Check if the parameter type matches or includes int
+    if type_matches(int, param_type):
+        return ParamSpaceInt
+    # Check if the parameter type matches or includes float
+    if type_matches(float, param_type):
+        return ParamSpaceFloat
+    # Return None if neither int nor float types match
+    return None
+
+
+def generate_models_and_union_type_for_classes(
+    classes: list[type[BaseModule]],
+) -> tuple[type[BaseModel], dict[str, type[BaseModel]]]:
+    """Dynamically generates Pydantic models for class constructors and creates a union type.
+
+    This function takes a list of module classes and creates Pydantic models that represent
+    their initialization parameters. It also creates a union type of all these models.
+
+    Args:
+        classes: A list of BaseModule subclasses to generate models for
+
+    Returns:
+        A tuple containing:
+        - A union type of all generated models
+        - A dictionary mapping module names to their generated model classes
+    """
+    # Dictionary to store the generated models, keyed by module name
+    models: dict[str, type[BaseModel]] = {}
+
+    # Iterate through each module class
+    for cls in classes:
+        # Get the signature of the from_context method to extract parameters
+        init_signature = inspect.signature(cls.from_context)
+        # Get the global namespace for resolving variables in type hints
+        globalns = getattr(cls.from_context, "__globals__", {})
+        # Get type hints with forward references resolved and extra info preserved
+        type_hints = get_type_hints(cls.from_context, globalns, None, include_extras=True)
+
+        # Check if the method accepts arbitrary keyword arguments (**kwargs)
+        has_kwarg_arg = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in init_signature.parameters.values())
+
+        # Initialize fields dictionary with common fields for all models
+        fields = {
+            # Module name field with a Literal type restricting it to this specific class name
+            "module_name": (Literal[cls.name], Field(...)),
+            # Optional field for number of trials in hyperparameter optimization
+            "n_trials": (PositiveInt | None, Field(None, description="Number of trials")),
+            # Config field to control extra fields behavior based on kwargs presence
+            "model_config": (ConfigDict, ConfigDict(extra="allow" if has_kwarg_arg else "forbid")),
+        }
+
+        # Process each parameter from the method signature
+        for param_name, param in init_signature.parameters.items():
+            # Skip self, cls, context parameters and **kwargs
+            if param_name in ("self", "cls", "context") or param.kind == inspect.Parameter.VAR_KEYWORD:
+                continue
+
+            # Get the parameter's type annotation, defaulting to Any if not specified
+            param_type: TypeAlias = type_hints.get(param_name, Any)  # type: ignore[valid-type] # noqa: PYI042
+
+            # Create a Field with default value if provided, otherwise make it required
+            field = Field(default=[param.default]) if param.default is not inspect.Parameter.empty else Field(...)
+
+            # Check if this parameter should have an Optuna search space
+            search_type = get_optuna_class(param_type)
+
+            if search_type is None:
+                # Regular parameter: use a list of the parameter's type
+                fields[param_name] = (list[param_type], field)
+            else:
+                # Parameter eligible for optimization: allow either list of values or search space
+                fields[param_name] = (list[param_type] | search_type, field)
+
+        # Generate a name for the model class
+        model_name = f"{cls.__name__}InitModel"
+
+        # Dynamically create a Pydantic model class for this module
+        models[cls.name] = type(
+            model_name,
+            (BaseModel,),  # Inherit from BaseModel
+            {
+                # Set type annotations for all fields
+                "__annotations__": {k: v[0] for k, v in fields.items()},
+                # Set field objects for all fields
+                **{k: v[1] for k, v in fields.items()},
+            },
+        )
+
+    # Return a union type of all models and the dictionary of models
+    return Union[tuple(models.values())], models  # type: ignore[return-value] # noqa: UP007
+
+
+DecisionSearchSpaceType, DecisionNodesBaseModels = generate_models_and_union_type_for_classes(
+    list(DecisionNodeInfo.modules_available.values())
+)
+DecisionMetrics = Literal[tuple(DecisionNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]
+
+
+class DecisionNodeValidator(BaseModel):
+    """Search space configuration for the Decision node."""
+
+    node_type: NodeType = NodeType.decision
+    target_metric: DecisionMetrics  # type: ignore[valid-type]
+    metrics: list[DecisionMetrics] | None = None  # type: ignore[valid-type]
+    search_space: list[DecisionSearchSpaceType]  # type: ignore[valid-type]
+
+
+EmbeddingSearchSpaceType, EmbeddingBaseModels = generate_models_and_union_type_for_classes(
+    list(EmbeddingNodeInfo.modules_available.values())
+)
+EmbeddingMetrics: TypeAlias = Literal[tuple(EmbeddingNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]
+
+
+class EmbeddingNodeValidator(BaseModel):
+    """Search space configuration for the Embedding node."""
+
+    node_type: NodeType = NodeType.embedding
+    target_metric: EmbeddingMetrics
+    metrics: list[EmbeddingMetrics] | None = None
+    search_space: list[EmbeddingSearchSpaceType]  # type: ignore[valid-type]
+
+
+ScoringSearchSpaceType, ScoringNodesBaseModels = generate_models_and_union_type_for_classes(
+    list(ScoringNodeInfo.modules_available.values())
+)
+ScoringMetrics: TypeAlias = Literal[tuple(ScoringNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]
+
+
+class ScoringNodeValidator(BaseModel):
+    """Search space configuration for the Scoring node."""
+
+    node_type: NodeType = NodeType.scoring
+    target_metric: ScoringMetrics
+    metrics: list[ScoringMetrics] | None = None
+    search_space: list[ScoringSearchSpaceType]  # type: ignore[valid-type]
+
+
+RegexpSearchSpaceType, RegexNodesBaseModels = generate_models_and_union_type_for_classes(
+    list(RegexNodeInfo.modules_available.values())
+)
+RegexpMetrics: TypeAlias = Literal[tuple(RegexNodeInfo.metrics_available.keys())]  # type: ignore[valid-type]
+
+
+class RegexNodeValidator(BaseModel):
+    """Search space configuration for the Regexp node."""
+
+    node_type: NodeType = NodeType.regex
+    target_metric: RegexpMetrics
+    metrics: list[RegexpMetrics] | None = None
+    search_space: list[RegexpSearchSpaceType]  # type: ignore[valid-type]
+
+
+NodeValidatorType: TypeAlias = (
+    EmbeddingNodeValidator | ScoringNodeValidator | DecisionNodeValidator | RegexNodeValidator
+)
+SearchSpaceType: TypeAlias = (
+    DecisionSearchSpaceType | EmbeddingSearchSpaceType | ScoringSearchSpaceType | RegexpSearchSpaceType  # type: ignore[valid-type]
+)
+
+
+class SearchSpaceConfig(RootModel[list[SearchSpaceType]]):
+    """Search space configuration."""
+
+    def __iter__(
+        self,
+    ) -> Iterator[SearchSpaceType]:
+        """Iterate over the root."""
+        return iter(self.root)
+
+    def __getitem__(self, item: int) -> SearchSpaceType:
+        """Get an item directly from the root.
+
+        :param item: Index
+
+        :return: Item
+        """
+        return self.root[item]
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_nodes(cls, data: list[Any]) -> list[Any]:  # noqa: C901
+        """Validate the search space configuration.
+
+        Args:
+            data: List of search space configurations.
+
+        Returns:
+            List of validated search space configurations.
+        """
+        error_message = ""
+        for i, item in enumerate(data):
+            if isinstance(item, BaseModel):
+                continue
+            if not isinstance(item, dict):
+                msg = "Each search space configuration must be a dictionary."
+                raise TypeError(msg)
+            node_name = item.get("module_name")
+            if node_name is None:
+                error_message += f"Search space configuration at index {i} is missing 'module_name'.\n"
+                continue
+
+            if node_name in DecisionNodesBaseModels:
+                node_class = DecisionNodesBaseModels[node_name]
+            elif node_name in EmbeddingBaseModels:
+                node_class = EmbeddingBaseModels[node_name]
+            elif node_name in ScoringNodesBaseModels:
+                node_class = ScoringNodesBaseModels[node_name]
+            elif node_name in RegexNodesBaseModels:
+                node_class = RegexNodesBaseModels[node_name]
+            else:
+                error_message += f"Unknown module name '{node_name}' at index {i}.\n"
+                break
+            try:
+                node_class(**item)
+            except ValidationError as e:
+                error_message += f"Search space configuration at index {i} {node_name} is invalid: {e}\n"
+                continue
+        if len(error_message) > 0:
+            raise TypeError(error_message)
+        return data
+
+
+class OptimizationSearchSpaceConfig(RootModel[list[NodeValidatorType]]):
+    """Optimizer configuration."""
+
+    def __iter__(
+        self,
+    ) -> Iterator[NodeValidatorType]:
+        """Iterate over the root."""
+        return iter(self.root)
+
+    def __getitem__(self, item: int) -> NodeValidatorType:
+        """Get an item directly from the root.
+
+        :param item: Index
+
+        :return: Item
+        """
+        return self.root[item]
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_nodes(cls, data: list[Any]) -> list[Any]:  # noqa: PLR0912,C901
+        """Validate the search space configuration.
+
+        Args:
+            data: List of search space configurations.
+
+        Returns:
+            List of validated search space configurations.
+        """
+        error_message = ""
+        for i, item in enumerate(data):
+            if isinstance(item, BaseModel):
+                continue
+            if not isinstance(item, dict):
+                msg = "Each search space configuration must be a dictionary."
+                raise TypeError(msg)
+            if "node_type" not in item:
+                msg = "Each search space configuration must have a 'node_type' key."
+                raise TypeError(msg)
+            if not isinstance(item.get("search_space"), list):
+                msg = "Each search space configuration must have a 'search_space' key of type list."
+                raise TypeError(msg)
+            for search_space in item["search_space"]:
+                node_name = search_space.get("module_name")
+                if node_name is None:
+                    error_message += f"Search space configuration at index {i} is missing 'module_name'.\n"
+                    continue
+                if item["node_type"] == NodeType.decision.value:
+                    node_class = DecisionNodesBaseModels[node_name]
+                elif item["node_type"] == NodeType.embedding.value:
+                    node_class = EmbeddingBaseModels[node_name]
+                elif item["node_type"] == NodeType.scoring.value:
+                    node_class = ScoringNodesBaseModels[node_name]
+                elif item["node_type"] == NodeType.regex.value:
+                    node_class = RegexNodesBaseModels[node_name]
+                else:
+                    error_message += f"Unknown node type '{item['node_type']}' at index {i}.\n"
+                    break
+
+                try:
+                    node_class(**search_space)
+                except ValidationError as e:
+                    error_message += f"Search space configuration at index {i} {node_name} is invalid: {e}\n"
+                    continue
+        if len(error_message) > 0:
+            raise TypeError(error_message)
+        return data
diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml
index c21eb779a..a8c883b40 100644
--- a/tests/assets/configs/multiclass.yaml
+++ b/tests/assets/configs/multiclass.yaml
@@ -29,7 +29,7 @@
         clf_name: [RandomForestClassifier]
         n_estimators: [5, 10]
       - module_name: bert
-        model_config:
+        classification_model_config:
           - model_name: avsolatorio/GIST-small-Embedding-v0
         num_train_epochs: [1]
         batch_size: [8, 16]
diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml
index f867c6109..a5702eb54 100644
--- a/tests/assets/configs/multilabel.yaml
+++ b/tests/assets/configs/multilabel.yaml
@@ -25,7 +25,7 @@
         clf_name: [RandomForestClassifier]
         n_estimators: [5, 10]
       - module_name: bert
-        model_config:
+        classification_model_config:
          - model_name: avsolatorio/GIST-small-Embedding-v0
         num_train_epochs: [1]
         batch_size: [8]
diff --git a/tests/configs/test_combined_config.py b/tests/configs/test_combined_config.py
index 41dc5bc7a..81312c7f0 100644
--- a/tests/configs/test_combined_config.py
+++ b/tests/configs/test_combined_config.py
@@ -74,8 +74,7 @@ def test_invalid_optimizer_config_missing_field():
 
 def test_invalid_optimizer_config_wrong_type():
     """Test that an invalid field type raises ValidationError."""
-    invalid_config = [
-        {
+    invalid_config = {
         "node_type": "scoring",
         "target_metric": "scoring_roc_auc",
         "search_space": [
             {
@@ -87,7 +86,6 @@ def test_invalid_optimizer_config_wrong_type():
             }
         ],
     }
-    ]
 
     with pytest.raises(TypeError):
         NodeOptimizer(**invalid_config)
diff --git a/tests/modules/scoring/test_bert.py b/tests/modules/scoring/test_bert.py
index 3ef319703..03da17c78 100644
--- a/tests/modules/scoring/test_bert.py
+++ b/tests/modules/scoring/test_bert.py
@@ -9,7 +9,7 @@ def test_bert_prediction(dataset):
     """Test that the transformer model can fit and make predictions."""
     data_handler = DataHandler(dataset)
 
-    scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
 
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
 
@@ -46,7 +46,7 @@ def test_bert_cache_clearing(dataset):
     """Test that the transformer model properly handles cache clearing."""
     data_handler = DataHandler(dataset)
 
-    scorer = BertScorer(model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
+    scorer = BertScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
 
     scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
diff --git a/tests/nodes/test_decision.py b/tests/nodes/test_decision.py
index 8bc65f820..826ca9949 100644
--- a/tests/nodes/test_decision.py
+++ b/tests/nodes/test_decision.py
@@ -19,7 +19,7 @@ def test_decision_multiclass(scoring_optimizer_multiclass):
         "node_type": "decision",
         "search_space": [
             {"module_name": "threshold", "thresh": [0.5]},
-            {"module_name": "tunable", "n_trials": [None, 3]},
+            {"module_name": "tunable", "n_trials": 3},
             {
                 "module_name": "argmax",
             },
@@ -58,7 +58,7 @@ def test_decision_multilabel(scoring_optimizer_multilabel):
         "node_type": "decision",
         "search_space": [
            {"module_name": "threshold", "thresh": [0.5]},
-            {"module_name": "tunable", "n_trials": [None, 3]},
+            {"module_name": "tunable", "n_trials": 3},
            {"module_name": "adaptive"},
         ],
     }
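
Reviewer note (not part of the patch): a minimal sketch of how the new SearchSpaceConfig guard behaves, using module names taken from the test search spaces above; the printed message text is illustrative only.

    from autointent.schemas.node_validation import SearchSpaceConfig

    # A well-formed per-module search space, as passed to NodeOptimizer.validate_search_space.
    search_space = [
        {"module_name": "threshold", "thresh": [0.5]},
        {"module_name": "tunable", "n_trials": 3},
    ]
    validated = SearchSpaceConfig(search_space).model_dump()  # passes; model_dump() returns plain dicts

    # An entry without "module_name" is collected into the error report and re-raised as TypeError.
    try:
        SearchSpaceConfig([{"thresh": [0.5]}])
    except TypeError as exc:
        print(exc)  # e.g. "Search space configuration at index 0 is missing 'module_name'."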