autointent/_optimization_config.py (4 changes: 3 additions & 1 deletion)
@@ -2,7 +2,7 @@

from pydantic import BaseModel, PositiveInt

from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, LoggingConfig
from .custom_types import SamplerType


@@ -25,6 +25,8 @@ class OptimizationConfig(BaseModel):

cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()

transformer_config: HFModelConfig = HFModelConfig()

sampler: SamplerType = "brute"
"""See tutorial on optuna and presets."""

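This hunk threads a shared HF transformer config through `OptimizationConfig`, alongside the existing embedder and cross-encoder configs. A minimal sketch of setting it, assuming `OptimizationConfig` is re-exported from the package root and the remaining fields keep their defaults:

```python
from autointent import OptimizationConfig          # import path assumed
from autointent.configs import HFModelConfig

config = OptimizationConfig(
    transformer_config=HFModelConfig(
        model_name="prajjwal1/bert-tiny",  # default per the schema update below
        batch_size=32,
        device=None,                       # or e.g. "cuda:0"
    ),
    sampler="brute",
)
```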
autointent/_pipeline/_pipeline.py (10 changes: 9 additions & 1 deletion)
@@ -14,6 +14,7 @@
CrossEncoderConfig,
DataConfig,
EmbedderConfig,
HFModelConfig,
InferenceNodeConfig,
LoggingConfig,
)
@@ -67,10 +68,13 @@ def __init__(
self.embedder_config = EmbedderConfig()
self.cross_encoder_config = CrossEncoderConfig()
self.data_config = DataConfig()
self.transformer_config = HFModelConfig()
elif not isinstance(nodes[0], InferenceNode):
assert_never(nodes)

def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig) -> None:
def set_config(
self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig
) -> None:
"""Set the configuration for the pipeline.

Args:
@@ -84,6 +88,8 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
self.cross_encoder_config = config
elif isinstance(config, DataConfig):
self.data_config = config
elif isinstance(config, HFModelConfig):
self.transformer_config = config
else:
assert_never(config)

@@ -133,6 +139,7 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
pipeline.set_config(optimization_config.data_config)
pipeline.set_config(optimization_config.embedder_config)
pipeline.set_config(optimization_config.cross_encoder_config)
pipeline.set_config(optimization_config.transformer_config)
return pipeline

def _fit(self, context: Context, sampler: SamplerType) -> None:
@@ -198,6 +205,7 @@ def fit(
context.configure_logging(self.logging_config)
context.configure_transformer(self.embedder_config)
context.configure_transformer(self.cross_encoder_config)
context.configure_transformer(self.transformer_config)

self.validate_modules(dataset, mode=incompatible_search_space)

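The pipeline now accepts `HFModelConfig` through `set_config` and forwards it to `context.configure_transformer` during `fit`, mirroring how the embedder and cross-encoder configs are handled. A hedged sketch of overriding it on an existing pipeline; the `Pipeline` import path is assumed and construction of the pipeline itself is elided:

```python
from autointent._pipeline import Pipeline  # import path assumed; a public alias may exist
from autointent.configs import HFModelConfig


def with_custom_transformer(pipeline: Pipeline) -> Pipeline:
    """Override the fine-tuning transformer, exercising the new set_config branch."""
    pipeline.set_config(
        HFModelConfig(model_name="prajjwal1/bert-tiny", batch_size=16, device="cpu")
    )
    return pipeline
```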
autointent/modules/scoring/_ptuning/ptuning.py (47 changes: 38 additions & 9 deletions)
@@ -1,9 +1,10 @@
"""PTuningScorer class for ptuning-based classification."""

from pathlib import Path
from typing import Any
from typing import Any, Literal

from peft import PromptEncoderConfig, get_peft_model
from peft import PromptEncoderConfig, PromptEncoderReparameterizationType, TaskType, get_peft_model
from pydantic import PositiveInt

from autointent import Context
from autointent._callbacks import REPORTERS_NAMES
@@ -53,14 +54,19 @@ class PTuningScorer(BertScorer):

name = "ptuning"

def __init__(
def __init__( # noqa: PLR0913
self,
classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
num_train_epochs: int = 3,
batch_size: int = 8,
num_train_epochs: PositiveInt = 3,
batch_size: PositiveInt = 8,
learning_rate: float = 5e-5,
seed: int = 0,
report_to: REPORTERS_NAMES | None = None, # type: ignore[valid-type]
encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
num_virtual_tokens: PositiveInt = 10,
encoder_dropout: float = 0.1,
encoder_hidden_size: PositiveInt = 128,
encoder_num_layers: PositiveInt = 2,
**ptuning_kwargs: Any, # noqa: ANN401
) -> None:
super().__init__(
@@ -71,17 +77,30 @@ def __init__(
seed=seed,
report_to=report_to,
)
self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)
self._ptuning_config = PromptEncoderConfig(
task_type=TaskType.SEQ_CLS,
encoder_reparameterization_type=PromptEncoderReparameterizationType(encoder_reparameterization_type),
num_virtual_tokens=num_virtual_tokens,
encoder_dropout=encoder_dropout,
encoder_hidden_size=encoder_hidden_size,
encoder_num_layers=encoder_num_layers,
**ptuning_kwargs,
)

@classmethod
def from_context(
def from_context( # noqa: PLR0913
cls,
context: Context,
classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
num_train_epochs: int = 3,
batch_size: int = 8,
num_train_epochs: PositiveInt = 3,
batch_size: PositiveInt = 8,
learning_rate: float = 5e-5,
seed: int = 0,
encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
num_virtual_tokens: PositiveInt = 10,
encoder_dropout: float = 0.1,
encoder_hidden_size: PositiveInt = 128,
encoder_num_layers: PositiveInt = 2,
**ptuning_kwargs: Any, # noqa: ANN401
) -> "PTuningScorer":
"""Create a PTuningScorer instance using a Context object.
@@ -93,6 +112,11 @@ def from_context(
batch_size: Batch size for training
learning_rate: Learning rate for training
seed: Random seed for reproducibility
encoder_reparameterization_type: Reparametrization type for the prompt encoder
num_virtual_tokens: Number of virtual tokens for the prompt encoder
encoder_dropout: Dropout for the prompt encoder
encoder_hidden_size: Hidden size for the prompt encoder
encoder_num_layers: Number of layers for the prompt encoder
**ptuning_kwargs: Arguments for PromptEncoderConfig
"""
if classification_model_config is None:
@@ -107,6 +131,11 @@ def from_context(
learning_rate=learning_rate,
seed=seed,
report_to=report_to,
encoder_reparameterization_type=encoder_reparameterization_type,
num_virtual_tokens=num_virtual_tokens,
encoder_dropout=encoder_dropout,
encoder_hidden_size=encoder_hidden_size,
encoder_num_layers=encoder_num_layers,
**ptuning_kwargs,
)

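The prompt-encoder hyperparameters are now explicit keyword arguments on `PTuningScorer` rather than being passed only through `**ptuning_kwargs`. An illustrative instantiation based on the signature above; the public import path is assumed and the values are examples, not tuned defaults:

```python
from autointent.modules.scoring import PTuningScorer  # import path assumed

scorer = PTuningScorer(
    classification_model_config="prajjwal1/bert-tiny",  # str accepted per the signature above
    num_train_epochs=3,
    batch_size=8,
    learning_rate=5e-5,
    encoder_reparameterization_type="MLP",  # "LSTM" is the default
    num_virtual_tokens=20,
    encoder_dropout=0.1,
    encoder_hidden_size=256,
    encoder_num_layers=2,
)
```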
docs/optimizer_config.schema.json (56 changes: 56 additions & 0 deletions)
@@ -253,6 +253,48 @@
"title": "EmbedderConfig",
"type": "object"
},
"HFModelConfig": {
"additionalProperties": false,
"properties": {
"model_name": {
"default": "prajjwal1/bert-tiny",
"description": "Name of the hugging face repository with transformer model.",
"title": "Model Name",
"type": "string"
},
"batch_size": {
"default": 32,
"description": "Batch size for model inference.",
"exclusiveMinimum": 0,
"title": "Batch Size",
"type": "integer"
},
"device": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"description": "Torch notation for CPU or CUDA.",
"title": "Device"
},
"tokenizer_config": {
"$ref": "#/$defs/TokenizerConfig"
},
"trust_remote_code": {
"default": false,
"description": "Whether to trust the remote code when loading the model.",
"title": "Trust Remote Code",
"type": "boolean"
}
},
"title": "HFModelConfig",
"type": "object"
},
"LoggingConfig": {
"additionalProperties": false,
"description": "Configuration for the logging.",
@@ -442,6 +484,20 @@
"train_head": false
}
},
"transformer_config": {
"$ref": "#/$defs/HFModelConfig",
"default": {
"model_name": "prajjwal1/bert-tiny",
"batch_size": 32,
"device": null,
"tokenizer_config": {
"max_length": null,
"padding": true,
"truncation": true
},
"trust_remote_code": false
}
},
"sampler": {
"default": "brute",
"enum": [
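The schema now documents `HFModelConfig` and the top-level `transformer_config` key, so dict- or file-based optimizer configs can set it directly. A dict sketch using only keys taken from the definitions above; whether any other top-level keys are required is assumed to follow their schema defaults:

```python
optimization_config = {
    "transformer_config": {
        "model_name": "prajjwal1/bert-tiny",
        "batch_size": 32,
        "device": None,
        "trust_remote_code": False,
    },
    "sampler": "brute",
}
# Could be passed to Pipeline.from_optimization_config(optimization_config),
# assuming the remaining fields fall back to their documented defaults.
```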