diff --git a/autointent/_dump_tools.py b/autointent/_dump_tools.py
index fbd982f5a..645e2e2c0 100644
--- a/autointent/_dump_tools.py
+++ b/autointent/_dump_tools.py
@@ -7,8 +7,16 @@
 import joblib
 import numpy as np
 import numpy.typing as npt
+from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
+from transformers import (  # type: ignore[attr-defined]
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
 
 from autointent import Embedder, Ranker, VectorIndex
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
@@ -34,6 +42,7 @@ class Dumper:
     pydantic_models: str = "pydantic"
     hf_models = "hf_models"
     hf_tokenizers = "hf_tokenizers"
+    ptuning_models = "ptuning_models"
 
     @staticmethod
     def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
@@ -52,6 +61,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
             path / Dumper.pydantic_models,
             path / Dumper.hf_models,
             path / Dumper.hf_tokenizers,
+            path / Dumper.ptuning_models,
         ]
         for subdir in subdirectories:
             subdir.mkdir(parents=True, exist_ok=exists_ok)
@@ -101,25 +111,38 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping pydantic model {key}: {e}"
                     logging.exception(msg)
-            elif (key == "_model" or "model" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PeftModel):
+                # dumping peft models is a nightmare...
+                # this might break with new versions of peft
+                try:
+                    if val._is_prompt_learning:  # noqa: SLF001
+                        # strategy to save prompt learning models: save prompt encoder and bert classifier separately
+                        model_path = path / Dumper.ptuning_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        val.save_pretrained(str(model_path / "peft"))
+                        val.base_model.save_pretrained(model_path / "base_model")  # type: ignore[attr-defined]
+                    else:
+                        # strategy to save lora models: merge adapters and save as usual hugging face model
+                        model_path = path / Dumper.hf_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        merged_model: PreTrainedModel = val.merge_and_unload()
+                        merged_model.save_pretrained(model_path)  # type: ignore[attr-defined]
+                except Exception as e:
+                    msg = f"Error dumping PeftModel {key}: {e}"
+                    logger.exception(msg)
+            elif isinstance(val, PreTrainedModel):
                 model_path = path / Dumper.hf_models / key
                 model_path.mkdir(parents=True, exist_ok=True)
                 try:
-                    val.save_pretrained(model_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (model_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
+                    val.save_pretrained(model_path)  # type: ignore[attr-defined]
                 except Exception as e:
                     msg = f"Error dumping HF model {key}: {e}"
                     logger.exception(msg)
-            elif (key == "_tokenizer" or "tokenizer" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
                 tokenizer_path = path / Dumper.hf_tokenizers / key
                 tokenizer_path.mkdir(parents=True, exist_ok=True)
                 try:
-                    val.save_pretrained(tokenizer_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (tokenizer_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
+                    val.save_pretrained(tokenizer_path)  # type: ignore[union-attr]
                 except Exception as e:
                     msg = f"Error dumping HF tokenizer {key}: {e}"
                     logger.exception(msg)
@@ -202,29 +225,25 @@ def load(  # noqa: C901, PLR0912, PLR0915
                     msg = f"Error loading Pydantic model from {model_dir}: {e}"
                     logger.exception(msg)
                     continue
+            elif child.name == Dumper.ptuning_models:
+                for model_dir in child.iterdir():
+                    try:
+                        model = AutoModelForSequenceClassification.from_pretrained(model_dir / "base_model")
+                        hf_models[model_dir.name] = PeftModel.from_pretrained(model, model_dir / "peft")
+                    except Exception as e:  # noqa: PERF203
+                        msg = f"Error loading PeftModel {model_dir.name}: {e}"
+                        logger.exception(msg)
             elif child.name == Dumper.hf_models:
                 for model_dir in child.iterdir():
                     try:
-                        with (model_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        model_class = getattr(module, class_info["name"])
-
-                        hf_models[model_dir.name] = model_class.from_pretrained(model_dir)
+                        hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF model {model_dir.name}: {e}"
                         logger.exception(msg)
             elif child.name == Dumper.hf_tokenizers:
                 for tokenizer_dir in child.iterdir():
                     try:
-                        with (tokenizer_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        tokenizer_class = getattr(module, class_info["name"])
-
-                        hf_tokenizers[tokenizer_dir.name] = tokenizer_class.from_pretrained(tokenizer_dir)
+                        hf_tokenizers[tokenizer_dir.name] = AutoTokenizer.from_pretrained(tokenizer_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
                         logger.exception(msg)
diff --git a/autointent/context/_context.py b/autointent/context/_context.py
index f4cc05bca..c3e35f87f 100644
--- a/autointent/context/_context.py
+++ b/autointent/context/_context.py
@@ -7,7 +7,7 @@
 
 from autointent import Dataset
 from autointent._callbacks import CallbackHandler, get_callbacks
-from autointent.configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
+from autointent.configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, LoggingConfig
 
 from .data_handler import DataHandler
 from .optimization_info import OptimizationInfo
@@ -49,7 +49,7 @@ def configure_logging(self, config: LoggingConfig) -> None:
         self.callback_handler = get_callbacks(config.report_to)
         self.optimization_info = OptimizationInfo()
 
-    def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig) -> None:
+    def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig | HFModelConfig) -> None:
         """Configure the vector index client and embedder.
 
         Args:
@@ -59,6 +59,8 @@ def configure_transformer(self, config: EmbedderConfig | CrossEncoderConfig) ->
             self.embedder_config = config
         elif isinstance(config, CrossEncoderConfig):
             self.cross_encoder_config = config
+        elif isinstance(config, HFModelConfig):
+            self.transformer_config = config
 
     def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
         """Set the datasets for training, validation and testing.
@@ -133,31 +135,40 @@ def has_saved_modules(self) -> bool:
     def resolve_embedder(self) -> EmbedderConfig:
         """Resolve the embedder configuration.
 
-        Returns the best embedder configuration or default configuration.
-
-        Raises:
-            RuntimeError: If embedder configuration cannot be resolved.
+        This method returns the configuration with the following priorities:
+        - the best embedder configuration obtained during embedding node optimization
+        - default configuration preset by the user with :py:meth:`Context.configure_transformer`
+        - default configuration preset by AutoIntent in :py:class:`autointent.configs.EmbedderConfig`
         """
         try:
             return self.optimization_info.get_best_embedder()
-        except ValueError as e:
+        except ValueError:
            if hasattr(self, "embedder_config"):
                 return self.embedder_config
-            msg = (
-                "Embedder could't be resolved. Either include embedding node into the "
-                "search space or set default config with Context.configure_transformer."
-            )
-            raise RuntimeError(msg) from e
+            return EmbedderConfig()
 
     def resolve_ranker(self) -> CrossEncoderConfig:
         """Resolve the cross-encoder configuration.
 
-        Returns default config if set.
-
-        Raises:
-            RuntimeError: If cross-encoder configuration cannot be resolved.
+        This method returns the configuration with the following priorities:
+        - default configuration preset by the user with :py:meth:`Context.configure_transformer`
+        - default configuration preset by AutoIntent in :py:class:`autointent.configs.CrossEncoderConfig`
         """
         if hasattr(self, "cross_encoder_config"):
             return self.cross_encoder_config
-        msg = "Cross-encoder could't be resolved. Set default config with Context.configure_transformer."
-        raise RuntimeError(msg)
+        return CrossEncoderConfig()
+
+    def resolve_transformer(self) -> HFModelConfig:
+        """Resolve the transformer configuration.
+
+        This method returns the configuration with the following priorities:
+        - the best transformer configuration obtained during embedding node optimization
+        - default configuration preset by the user with :py:meth:`Context.configure_transformer`
+        - default configuration preset by AutoIntent in :py:class:`autointent.configs.HFModelConfig`
+        """
+        try:
+            return self.optimization_info.get_best_embedder()
+        except ValueError:
+            if hasattr(self, "transformer_config"):
+                return self.transformer_config
+            return HFModelConfig()
diff --git a/autointent/modules/base/_base.py b/autointent/modules/base/_base.py
index 71343e1d6..9cda8b5a3 100644
--- a/autointent/modules/base/_base.py
+++ b/autointent/modules/base/_base.py
@@ -138,9 +138,16 @@ def from_context(cls, context: Context, **kwargs: dict[str, Any]) -> "BaseModule
             Initialized module
         """
 
-    def get_embedder_config(self) -> dict[str, Any] | None:
-        """Get the config of the embedder."""
-        return None
+    @abstractmethod
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        """Return default params used in the ``__init__`` method.
+
+        Some parameters of the module may be inferred from the context rather than passed to ``__init__``.
+        However, they need to be logged for reproducibility when loading from disk.
+
+        Returns:
+            Dictionary of default params
+        """
 
     @staticmethod
     def score_metrics_ho(params: tuple[Any, Any], metrics_dict: dict[str, Any]) -> dict[str, float]:
diff --git a/autointent/modules/base/_decision.py b/autointent/modules/base/_decision.py
index e9487ed31..eec8e9349 100644
--- a/autointent/modules/base/_decision.py
+++ b/autointent/modules/base/_decision.py
@@ -18,6 +18,9 @@
 class BaseDecision(BaseModule, ABC):
     """Base class for decision modules."""
 
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {}
+
     @abstractmethod
     def fit(
         self,
diff --git a/autointent/modules/base/_embedding.py b/autointent/modules/base/_embedding.py
index 376fe5a82..22b845eb5 100644
--- a/autointent/modules/base/_embedding.py
+++ b/autointent/modules/base/_embedding.py
@@ -1,6 +1,7 @@
 """Base class for embedding modules."""
 
 from abc import ABC
+from typing import Any
 
 from autointent import Context
 from autointent.custom_types import ListOfLabels
@@ -10,6 +11,9 @@
 class BaseEmbedding(BaseModule, ABC):
     """Base class for embedding modules."""
 
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {}
+
     def get_train_data(self, context: Context) -> tuple[list[str], ListOfLabels]:
         """Get train data.
 
diff --git a/autointent/modules/base/_regex.py b/autointent/modules/base/_regex.py
index 34656b799..cdd9f389e 100644
--- a/autointent/modules/base/_regex.py
+++ b/autointent/modules/base/_regex.py
@@ -1,9 +1,13 @@
 """Base class for embedding modules."""
 
 from abc import ABC
+from typing import Any
 
 from autointent.modules.base import BaseModule
 
 
 class BaseRegex(BaseModule, ABC):
     """Base class for rule-based modules."""
+
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {}
diff --git a/autointent/modules/scoring/_bert.py b/autointent/modules/scoring/_bert.py
index be29935de..37331c28c 100644
--- a/autointent/modules/scoring/_bert.py
+++ b/autointent/modules/scoring/_bert.py
@@ -26,8 +26,8 @@ class BertScorer(BaseScorer):
     name = "bert"
     supports_multiclass = True
     supports_multilabel = True
-    _model: Any
-    _tokenizer: Any
+    _model: Any  # transformers AutoModel factory returns Any
+    _tokenizer: Any  # transformers AutoTokenizer factory returns Any
 
     def __init__(
         self,
@@ -56,7 +56,7 @@ def from_context(
         seed: int = 0,
     ) -> "BertScorer":
         if classification_model_config is None:
-            classification_model_config = context.resolve_embedder()
+            classification_model_config = context.resolve_transformer()
 
         report_to = context.logging_config.report_to
 
@@ -69,14 +69,14 @@
             report_to=report_to,
         )
 
-    def get_embedder_config(self) -> dict[str, Any]:
-        return self.classification_model_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {"classification_model_config": self.classification_model_config.model_dump()}
 
-    def __initialize_model(self) -> None:
+    def _initialize_model(self) -> Any:  # noqa: ANN401
         label2id = {i: i for i in range(self._n_classes)}
         id2label = {i: i for i in range(self._n_classes)}
 
-        self._model = AutoModelForSequenceClassification.from_pretrained(
+        return AutoModelForSequenceClassification.from_pretrained(
             self.classification_model_config.model_name,
             trust_remote_code=self.classification_model_config.trust_remote_code,
             num_labels=self._n_classes,
@@ -96,7 +96,7 @@ def fit(
 
         self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name)
 
-        self.__initialize_model()
+        self._model = self._initialize_model()
 
         use_cpu = self.classification_model_config.device == "cpu"
 
@@ -126,7 +126,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             save_strategy="no",
             logging_strategy="steps",
             logging_steps=10,
-            report_to=self.report_to,
+            report_to=self.report_to if self.report_to is not None else "none",
             use_cpu=use_cpu,
         )
 
diff --git a/autointent/modules/scoring/_description/description.py b/autointent/modules/scoring/_description/description.py
index 657de93ac..2ef7d4b51 100644
--- a/autointent/modules/scoring/_description/description.py
+++ b/autointent/modules/scoring/_description/description.py
@@ -76,9 +76,9 @@ def from_context(
         Returns:
             Initialized DescriptionScorer instance
         """
-        if embedder_config is None:
+        if embedder_config is None and encoder_type == "bi":
             embedder_config = context.resolve_embedder()
-        if cross_encoder_config is None:
+        if cross_encoder_config is None and encoder_type == "cross":
             cross_encoder_config = context.resolve_ranker()
 
         return cls(
@@ -88,21 +88,13 @@
             encoder_type=encoder_type,
         )
 
-    def get_embedder_config(self) -> dict[str, Any]:
-        """Get the configuration of the embedder.
-
-        Returns:
-            Embedder configuration
-        """
-        return self.embedder_config.model_dump()
-
-    def get_cross_encoder_config(self) -> dict[str, Any]:
-        """Get the configuration of the cross-encoder.
-
-        Returns:
-            Cross-encoder configuration
-        """
-        return self.cross_encoder_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        res = {}
+        if self._encoder_type == "bi":
+            res["embedder_config"] = self.embedder_config.model_dump()
+        else:
+            res["cross_encoder_config"] = self.cross_encoder_config.model_dump()
+        return res
 
     def fit(
         self,
diff --git a/autointent/modules/scoring/_dnnc/dnnc.py b/autointent/modules/scoring/_dnnc/dnnc.py
index 8c389e5a5..48792dbe9 100644
--- a/autointent/modules/scoring/_dnnc/dnnc.py
+++ b/autointent/modules/scoring/_dnnc/dnnc.py
@@ -101,6 +101,12 @@ def from_context(
             cross_encoder_config=cross_encoder_config,
         )
 
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {
+            "embedder_config": self.embedder_config.model_dump(),
+            "cross_encoder_config": self.cross_encoder_config.model_dump(),
+        }
+
     def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         """Fit the scorer by training or loading the vector index.
 
diff --git a/autointent/modules/scoring/_knn/knn.py b/autointent/modules/scoring/_knn/knn.py
index 749bbdc31..bad7a6ef8 100644
--- a/autointent/modules/scoring/_knn/knn.py
+++ b/autointent/modules/scoring/_knn/knn.py
@@ -97,13 +97,8 @@ def from_context(
             weights=weights,
         )
 
-    def get_embedder_config(self) -> dict[str, Any]:
-        """Get the name of the embedder.
-
-        Returns:
-            Embedder name
-        """
-        return self.embedder_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {"embedder_config": self.embedder_config.model_dump()}
 
     def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None:
         """Fit the scorer by training or loading the vector index.
diff --git a/autointent/modules/scoring/_knn/rerank_scorer.py b/autointent/modules/scoring/_knn/rerank_scorer.py
index 656fb09b8..09ecab10b 100644
--- a/autointent/modules/scoring/_knn/rerank_scorer.py
+++ b/autointent/modules/scoring/_knn/rerank_scorer.py
@@ -96,6 +96,12 @@ def from_context(
             cross_encoder_config=cross_encoder_config,
         )
 
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {
+            "embedder_config": self.embedder_config.model_dump(),
+            "cross_encoder_config": self.cross_encoder_config.model_dump(),
+        }
+
     def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         """Fit the RerankScorer with utterances and labels.
 
diff --git a/autointent/modules/scoring/_linear.py b/autointent/modules/scoring/_linear.py
index be74ada89..0e9bef76a 100644
--- a/autointent/modules/scoring/_linear.py
+++ b/autointent/modules/scoring/_linear.py
@@ -90,13 +90,8 @@ def from_context(
             embedder_config=embedder_config,
         )
 
-    def get_embedder_config(self) -> dict[str, Any]:
-        """Get the name of the embedder.
-
-        Returns:
-            Embedder name
-        """
-        return self.embedder_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {"embedder_config": self.embedder_config.model_dump()}
 
     def fit(
         self,
diff --git a/autointent/modules/scoring/_lora/lora.py b/autointent/modules/scoring/_lora/lora.py
index 80c24c8cb..124cc4baf 100644
--- a/autointent/modules/scoring/_lora/lora.py
+++ b/autointent/modules/scoring/_lora/lora.py
@@ -1,12 +1,13 @@
 """BertScorer class for transformer-based classification with LoRA."""
 
+from pathlib import Path
 from typing import Any
 
 from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForSequenceClassification
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
+from autointent._dump_tools import Dumper
 from autointent.configs import HFModelConfig
 from autointent.modules.scoring._bert import BertScorer
 
@@ -59,10 +60,6 @@ class BERTLoRAScorer(BertScorer):
     """
 
     name = "lora"
-    supports_multiclass = True
-    supports_multilabel = True
-    _model: Any
-    _tokenizer: Any
 
     def __init__(
         self,
@@ -72,7 +69,7 @@ def __init__(
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
-        **lora_kwargs: dict[str, Any],
+        **lora_kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(
             classification_model_config=classification_model_config,
@@ -82,7 +79,7 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._lora_config = LoraConfig(**lora_kwargs)  # type: ignore[arg-type]
+        self._lora_config = LoraConfig(**lora_kwargs)
 
     @classmethod
     def from_context(
@@ -93,10 +90,10 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **lora_kwargs: dict[str, Any],
+        **lora_kwargs: Any,  # noqa: ANN401
     ) -> "BERTLoRAScorer":
         if classification_model_config is None:
-            classification_model_config = context.resolve_embedder()
+            classification_model_config = context.resolve_transformer()
         return cls(
             classification_model_config=classification_model_config,
             num_train_epochs=num_train_epochs,
@@ -107,11 +104,9 @@ def from_context(
             **lora_kwargs,
         )
 
-    def __initialize_model(self) -> None:
-        self._model = AutoModelForSequenceClassification.from_pretrained(
-            self.classification_model_config.model_name,
-            num_labels=self._n_classes,
-            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
-            trust_remote_code=self.classification_model_config.trust_remote_code,
-        )
-        self._model = get_peft_model(self._model, self._lora_config)
+    def _initialize_model(self) -> Any:  # noqa: ANN401
+        model = super()._initialize_model()
+        return get_peft_model(model, self._lora_config)
+
+    def dump(self, path: str) -> None:
+        Dumper.dump(self, Path(path), exclude=[LoraConfig])
diff --git a/autointent/modules/scoring/_mlknn/mlknn.py b/autointent/modules/scoring/_mlknn/mlknn.py
index 306453010..3021d9349 100644
--- a/autointent/modules/scoring/_mlknn/mlknn.py
+++ b/autointent/modules/scoring/_mlknn/mlknn.py
@@ -111,13 +111,8 @@ def from_context(
             ignore_first_neighbours=ignore_first_neighbours,
         )
 
-    def get_embedder_config(self) -> dict[str, Any]:
-        """Get the name of the embedder.
-
-        Returns:
-            Embedder name
-        """
-        return self.embedder_config.model_dump()
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {"embedder_config": self.embedder_config.model_dump()}
 
     def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         """Fit the scorer by training or loading the vector index and calculating probabilities.
diff --git a/autointent/modules/scoring/_ptuning/ptuning.py b/autointent/modules/scoring/_ptuning/ptuning.py
index d557dc476..417cb138f 100644
--- a/autointent/modules/scoring/_ptuning/ptuning.py
+++ b/autointent/modules/scoring/_ptuning/ptuning.py
@@ -1,15 +1,13 @@
 """PTuningScorer class for ptuning-based classification."""
 
+from pathlib import Path
 from typing import Any
 
-import torch
 from peft import PromptEncoderConfig, get_peft_model
-from transformers import (
-    AutoModelForSequenceClassification,
-)
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
+from autointent._dump_tools import Dumper
 from autointent.configs import HFModelConfig
 from autointent.modules.scoring._bert import BertScorer
 
@@ -54,10 +52,6 @@ class PTuningScorer(BertScorer):
     """
 
     name = "ptuning"
-    supports_multiclass = True
-    supports_multilabel = True
-    _model: Any
-    _tokenizer: Any
 
     def __init__(
         self,
@@ -67,7 +61,7 @@ def __init__(
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
-        **ptuning_kwargs: dict[str, Any],
+        **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(
             classification_model_config=classification_model_config,
@@ -77,8 +71,7 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._ptuning_config = PromptEncoderConfig(**ptuning_kwargs)  # type: ignore[arg-type]
-        torch.manual_seed(seed)
+        self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)
 
     @classmethod
     def from_context(
@@ -89,7 +82,7 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **ptuning_kwargs: dict[str, Any],
+        **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> "PTuningScorer":
         """Create a PTuningScorer instance using a Context object.
 
@@ -103,7 +96,7 @@
             **ptuning_kwargs: Arguments for PromptEncoderConfig
         """
         if classification_model_config is None:
-            classification_model_config = context.resolve_embedder()
+            classification_model_config = context.resolve_transformer()
 
         report_to = context.logging_config.report_to
 
@@ -117,14 +110,10 @@ def from_context(
             **ptuning_kwargs,
         )
 
-    def _initialize_model(self) -> None:
+    def _initialize_model(self) -> Any:  # noqa: ANN401
         """Initialize the model with P-tuning configuration."""
-        model_name = self.classification_model_config.model_name
-        self._model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            num_labels=self._n_classes,
-            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
-            trust_remote_code=self.classification_model_config.trust_remote_code,
-            return_dict=True,
-        )
-        self._model = get_peft_model(self._model, self._ptuning_config)
+        model = super()._initialize_model()
+        return get_peft_model(model, self._ptuning_config)
+
+    def dump(self, path: str) -> None:
+        Dumper.dump(self, Path(path), exclude=[PromptEncoderConfig])
diff --git a/autointent/modules/scoring/_sklearn/sklearn_scorer.py b/autointent/modules/scoring/_sklearn/sklearn_scorer.py
index 8e5ed51c2..62b340f94 100644
--- a/autointent/modules/scoring/_sklearn/sklearn_scorer.py
+++ b/autointent/modules/scoring/_sklearn/sklearn_scorer.py
@@ -104,6 +104,9 @@ def from_context(
             **clf_args,
         )
 
+    def get_implicit_initialization_params(self) -> dict[str, Any]:
+        return {"embedder_config": self.embedder_config.model_dump()}
+
     def fit(
         self,
         utterances: list[str],
diff --git a/autointent/nodes/_node_optimizer.py b/autointent/nodes/_node_optimizer.py
index a0b8f71b8..c5f91bced 100644
--- a/autointent/nodes/_node_optimizer.py
+++ b/autointent/nodes/_node_optimizer.py
@@ -129,10 +129,7 @@ def objective(
 
             self._logger.debug("Initializing %s module with config: %s", module_name, json.dumps(config))
             module = self.node_info.modules_available[module_name].from_context(context, **config)
-
-            embedder_config = module.get_embedder_config()
-            if embedder_config is not None:
-                config["embedder_config"] = embedder_config
+            config.update(module.get_implicit_initialization_params())
 
             context.callback_handler.start_module(module_name=module_name, num=self._counter, module_kwargs=config)
 
diff --git a/autointent/utils.py b/autointent/utils.py
index f08de3db5..c46d5e609 100644
--- a/autointent/utils.py
+++ b/autointent/utils.py
@@ -9,17 +9,21 @@
 from autointent.custom_types import SearchSpacePreset
 
 
-def load_search_space(path: Path | str) -> list[dict[str, Any]]:
+def load_search_space(path_or_str: Path | str) -> list[dict[str, Any]]:
     """Load hyperparameters search space from file.
 
     Args:
-        path: Path to the search space file.
+        path_or_str: Path to the search space file or string representation of the search space.
 
     Returns:
         List of dictionaries representing the search space.
""" - with Path(path).open() as file: - return yaml.safe_load(file) # type: ignore[no-any-return] + if isinstance(path_or_str, Path): + with path_or_str.open() as file: + return yaml.safe_load(file) # type: ignore[no-any-return] + else: + # string representation of the search space + return yaml.safe_load(path_or_str) # type: ignore[no-any-return] def load_preset(name: SearchSpacePreset) -> dict[str, Any]: diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml index 6828d6505..6e2390b7d 100644 --- a/tests/assets/configs/multiclass.yaml +++ b/tests/assets/configs/multiclass.yaml @@ -48,7 +48,6 @@ classification_model_config: ["prajjwal1/bert-tiny"] num_train_epochs: [1] batch_size: [8, 16] - task_type: ["SEQ_CLS"] num_virtual_tokens: [10, 20] - node_type: decision target_metric: decision_accuracy diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index 3cfe6f63c..e82db8e47 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -36,7 +36,6 @@ classification_model_config: ["prajjwal1/bert-tiny"] num_train_epochs: [1] batch_size: [8] - task_type: ["SEQ_CLS"] num_virtual_tokens: [10, 20] - module_name: lora classification_model_config: diff --git a/tests/modules/scoring/test_ptuning.py b/tests/modules/scoring/test_ptuning.py index d90cc7b13..4d6ff2385 100644 --- a/tests/modules/scoring/test_ptuning.py +++ b/tests/modules/scoring/test_ptuning.py @@ -17,7 +17,6 @@ def test_ptuning_scorer_dump_load(dataset): classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8, - task_type="SEQ_CLS", num_virtual_tokens=10, seed=42, ) @@ -38,7 +37,6 @@ def test_ptuning_scorer_dump_load(dataset): classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8, - task_type="SEQ_CLS", num_virtual_tokens=10, seed=42, ) @@ -66,7 +64,6 @@ def test_ptuning_prediction(dataset): classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8, - task_type="SEQ_CLS", num_virtual_tokens=10, seed=42, ) @@ -106,7 +103,6 @@ def test_ptuning_cache_clearing(dataset): classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8, - task_type="SEQ_CLS", num_virtual_tokens=20, seed=42, )