diff --git a/autointent/_optimization_config.py b/autointent/_optimization_config.py
index 6c55ee148..b314aa4bf 100644
--- a/autointent/_optimization_config.py
+++ b/autointent/_optimization_config.py
@@ -2,7 +2,7 @@
 
 from pydantic import BaseModel, PositiveInt
 
-from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
+from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, LoggingConfig
 from .custom_types import SamplerType
 
 
@@ -25,6 +25,8 @@ class OptimizationConfig(BaseModel):
 
     cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()
 
+    transformer_config: HFModelConfig = HFModelConfig()
+
     sampler: SamplerType = "brute"
     """See tutorial on optuna and presets."""
 
diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py
index a141352ba..8966a4fbc 100644
--- a/autointent/_pipeline/_pipeline.py
+++ b/autointent/_pipeline/_pipeline.py
@@ -14,6 +14,7 @@
     CrossEncoderConfig,
     DataConfig,
     EmbedderConfig,
+    HFModelConfig,
     InferenceNodeConfig,
     LoggingConfig,
 )
@@ -67,10 +68,13 @@ def __init__(
             self.embedder_config = EmbedderConfig()
             self.cross_encoder_config = CrossEncoderConfig()
             self.data_config = DataConfig()
+            self.transformer_config = HFModelConfig()
         elif not isinstance(nodes[0], InferenceNode):
             assert_never(nodes)
 
-    def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig) -> None:
+    def set_config(
+        self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig
+    ) -> None:
         """Set the configuration for the pipeline.
 
         Args:
@@ -84,6 +88,8 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
             self.cross_encoder_config = config
         elif isinstance(config, DataConfig):
             self.data_config = config
+        elif isinstance(config, HFModelConfig):
+            self.transformer_config = config
         else:
             assert_never(config)
 
@@ -133,6 +139,7 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
         pipeline.set_config(optimization_config.data_config)
         pipeline.set_config(optimization_config.embedder_config)
         pipeline.set_config(optimization_config.cross_encoder_config)
+        pipeline.set_config(optimization_config.transformer_config)
         return pipeline
 
     def _fit(self, context: Context, sampler: SamplerType) -> None:
@@ -198,6 +205,7 @@ def fit(
         context.configure_logging(self.logging_config)
         context.configure_transformer(self.embedder_config)
         context.configure_transformer(self.cross_encoder_config)
+        context.configure_transformer(self.transformer_config)
 
         self.validate_modules(dataset, mode=incompatible_search_space)
 
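
With these two changes, transformer-level model settings become a first-class pipeline config alongside the embedder and cross-encoder ones: OptimizationConfig gains a transformer_config field, Pipeline.set_config accepts an HFModelConfig, and Pipeline.fit forwards it to the context. A minimal usage sketch (the HFModelConfig field names are taken from the schema change further down; constructing the pipeline itself is outside this diff):

    from autointent.configs import HFModelConfig

    transformer_config = HFModelConfig(
        model_name="prajjwal1/bert-tiny",  # schema default
        batch_size=32,
        device=None,
        trust_remote_code=False,
    )

    # `pipeline` stands in for an already constructed Pipeline instance
    pipeline.set_config(transformer_config)

    # fit() then forwards it the same way as the embedder and cross-encoder configs:
    #     context.configure_transformer(self.transformer_config)

The same block is exposed as the top-level transformer_config key in the regenerated docs/optimizer_config.schema.json below.
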
diff --git a/autointent/modules/scoring/_ptuning/ptuning.py b/autointent/modules/scoring/_ptuning/ptuning.py
index 417cb138f..c0733134a 100644
--- a/autointent/modules/scoring/_ptuning/ptuning.py
+++ b/autointent/modules/scoring/_ptuning/ptuning.py
@@ -1,9 +1,10 @@
 """PTuningScorer class for ptuning-based classification."""
 
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
-from peft import PromptEncoderConfig, get_peft_model
+from peft import PromptEncoderConfig, PromptEncoderReparameterizationType, TaskType, get_peft_model
+from pydantic import PositiveInt
 
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
@@ -53,14 +54,19 @@ class PTuningScorer(BertScorer):
 
     name = "ptuning"
 
-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 3,
-        batch_size: int = 8,
+        num_train_epochs: PositiveInt = 3,
+        batch_size: PositiveInt = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
+        encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
+        num_virtual_tokens: PositiveInt = 10,
+        encoder_dropout: float = 0.1,
+        encoder_hidden_size: PositiveInt = 128,
+        encoder_num_layers: PositiveInt = 2,
         **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(
@@ -71,17 +77,30 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)
+        self._ptuning_config = PromptEncoderConfig(
+            task_type=TaskType.SEQ_CLS,
+            encoder_reparameterization_type=PromptEncoderReparameterizationType(encoder_reparameterization_type),
+            num_virtual_tokens=num_virtual_tokens,
+            encoder_dropout=encoder_dropout,
+            encoder_hidden_size=encoder_hidden_size,
+            encoder_num_layers=encoder_num_layers,
+            **ptuning_kwargs,
+        )
 
     @classmethod
-    def from_context(
+    def from_context(  # noqa: PLR0913
         cls,
         context: Context,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 3,
-        batch_size: int = 8,
+        num_train_epochs: PositiveInt = 3,
+        batch_size: PositiveInt = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
+        encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
+        num_virtual_tokens: PositiveInt = 10,
+        encoder_dropout: float = 0.1,
+        encoder_hidden_size: PositiveInt = 128,
+        encoder_num_layers: PositiveInt = 2,
         **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> "PTuningScorer":
         """Create a PTuningScorer instance using a Context object.
@@ -93,6 +112,11 @@ def from_context(
             batch_size: Batch size for training
             learning_rate: Learning rate for training
             seed: Random seed for reproducibility
+            encoder_reparameterization_type: Reparametrization type for the prompt encoder
+            num_virtual_tokens: Number of virtual tokens for the prompt encoder
+            encoder_dropout: Dropout for the prompt encoder
+            encoder_hidden_size: Hidden size for the prompt encoder
+            encoder_num_layers: Number of layers for the prompt encoder
             **ptuning_kwargs: Arguments for PromptEncoderConfig
         """
         if classification_model_config is None:
@@ -107,6 +131,11 @@ def from_context(
             learning_rate=learning_rate,
             seed=seed,
             report_to=report_to,
+            encoder_reparameterization_type=encoder_reparameterization_type,
+            num_virtual_tokens=num_virtual_tokens,
+            encoder_dropout=encoder_dropout,
+            encoder_hidden_size=encoder_hidden_size,
+            encoder_num_layers=encoder_num_layers,
             **ptuning_kwargs,
         )
 
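
The five new keyword arguments map one-to-one onto peft's PromptEncoderConfig fields, so the prompt-encoder shape can now be tuned explicitly instead of being passed through **ptuning_kwargs. A sketch of direct instantiation with the defaults spelled out (the public import path is an assumption; the signature itself comes from this diff):

    from autointent.modules.scoring import PTuningScorer  # assumed re-export of _ptuning.ptuning

    scorer = PTuningScorer(
        classification_model_config="prajjwal1/bert-tiny",  # HFModelConfig | str | dict | None
        num_train_epochs=3,
        batch_size=8,
        learning_rate=5e-5,
        encoder_reparameterization_type="LSTM",  # or "MLP"
        num_virtual_tokens=10,
        encoder_dropout=0.1,
        encoder_hidden_size=128,
        encoder_num_layers=2,
    )
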
"description": "Whether to trust the remote code when loading the model.", + "title": "Trust Remote Code", + "type": "boolean" + } + }, + "title": "HFModelConfig", + "type": "object" + }, "LoggingConfig": { "additionalProperties": false, "description": "Configuration for the logging.", @@ -442,6 +484,20 @@ "train_head": false } }, + "transformer_config": { + "$ref": "#/$defs/HFModelConfig", + "default": { + "model_name": "prajjwal1/bert-tiny", + "batch_size": 32, + "device": null, + "tokenizer_config": { + "max_length": null, + "padding": true, + "truncation": true + }, + "trust_remote_code": false + } + }, "sampler": { "default": "brute", "enum": [