Commit a04587b

Fix/minor fixes (#206)
* hf model config
* add ptuning args validation
* Update optimizer_config.schema.json
* bugfix
* bugfix
* Update optimizer_config.schema.json

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 2702e41 commit a04587b

File tree

4 files changed: +106 -11 lines changed

autointent/_optimization_config.py
autointent/_pipeline/_pipeline.py
autointent/modules/scoring/_ptuning/ptuning.py
docs/optimizer_config.schema.json

autointent/_optimization_config.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -2,7 +2,7 @@

 from pydantic import BaseModel, PositiveInt

-from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
+from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, LoggingConfig
 from .custom_types import SamplerType


@@ -25,6 +25,8 @@ class OptimizationConfig(BaseModel):

     cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()

+    transformer_config: HFModelConfig = HFModelConfig()
+
     sampler: SamplerType = "brute"
     """See tutorial on optuna and presets."""
```

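The new transformer_config field sits alongside the existing embedder and cross-encoder configs. A minimal sketch of constructing the value it holds, assuming only the import path shown in this diff (a full OptimizationConfig is not built here because its remaining fields are outside this change):

```python
# Sketch only: mirrors the HFModelConfig defaults from the schema added in this
# commit; the import path follows the "from .configs import ..." line above.
from autointent.configs import HFModelConfig

transformer_config = HFModelConfig(
    model_name="prajjwal1/bert-tiny",
    batch_size=32,
    trust_remote_code=False,
)
```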
autointent/_pipeline/_pipeline.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -14,6 +14,7 @@
     CrossEncoderConfig,
     DataConfig,
     EmbedderConfig,
+    HFModelConfig,
     InferenceNodeConfig,
     LoggingConfig,
 )
@@ -67,10 +68,13 @@ def __init__(
             self.embedder_config = EmbedderConfig()
             self.cross_encoder_config = CrossEncoderConfig()
             self.data_config = DataConfig()
+            self.transformer_config = HFModelConfig()
         elif not isinstance(nodes[0], InferenceNode):
             assert_never(nodes)

-    def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig) -> None:
+    def set_config(
+        self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig | DataConfig | HFModelConfig
+    ) -> None:
         """Set the configuration for the pipeline.

         Args:
@@ -84,6 +88,8 @@ def set_config(self, config: LoggingConfig | EmbedderConfig | CrossEncoderConfig
             self.cross_encoder_config = config
         elif isinstance(config, DataConfig):
             self.data_config = config
+        elif isinstance(config, HFModelConfig):
+            self.transformer_config = config
         else:
             assert_never(config)

@@ -133,6 +139,7 @@ def from_optimization_config(cls, config: dict[str, Any] | Path | str | Optimiza
         pipeline.set_config(optimization_config.data_config)
         pipeline.set_config(optimization_config.embedder_config)
         pipeline.set_config(optimization_config.cross_encoder_config)
+        pipeline.set_config(optimization_config.transformer_config)
         return pipeline

     def _fit(self, context: Context, sampler: SamplerType) -> None:
@@ -198,6 +205,7 @@ def fit(
         context.configure_logging(self.logging_config)
         context.configure_transformer(self.embedder_config)
         context.configure_transformer(self.cross_encoder_config)
+        context.configure_transformer(self.transformer_config)

         self.validate_modules(dataset, mode=incompatible_search_space)
```

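set_config dispatches on the runtime type of the config and falls through to assert_never, so the type checker flags any member of the union that lacks a branch, which is how the missing HFModelConfig handling surfaces. A self-contained sketch of that pattern with simplified stand-in classes (not the real autointent types):

```python
# Standalone illustration of the isinstance-plus-assert_never dispatch used by
# Pipeline.set_config; the config classes are stand-ins, not autointent's own.
from dataclasses import dataclass
from typing import assert_never  # Python 3.11+; use typing_extensions otherwise


@dataclass
class EmbedderConfig:
    model_name: str = "some-embedding-model"


@dataclass
class HFModelConfig:
    model_name: str = "prajjwal1/bert-tiny"


class Pipeline:
    def __init__(self) -> None:
        self.embedder_config = EmbedderConfig()
        self.transformer_config = HFModelConfig()

    def set_config(self, config: EmbedderConfig | HFModelConfig) -> None:
        if isinstance(config, EmbedderConfig):
            self.embedder_config = config
        elif isinstance(config, HFModelConfig):
            # the kind of branch this commit adds: route transformer settings
            # to their own attribute instead of hitting assert_never
            self.transformer_config = config
        else:
            assert_never(config)


pipeline = Pipeline()
pipeline.set_config(HFModelConfig(model_name="prajjwal1/bert-tiny"))
```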
autointent/modules/scoring/_ptuning/ptuning.py

Lines changed: 38 additions & 9 deletions
```diff
@@ -1,9 +1,10 @@
 """PTuningScorer class for ptuning-based classification."""

 from pathlib import Path
-from typing import Any
+from typing import Any, Literal

-from peft import PromptEncoderConfig, get_peft_model
+from peft import PromptEncoderConfig, PromptEncoderReparameterizationType, TaskType, get_peft_model
+from pydantic import PositiveInt

 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
@@ -53,14 +54,19 @@ class PTuningScorer(BertScorer):

     name = "ptuning"

-    def __init__(
+    def __init__(  # noqa: PLR0913
         self,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 3,
-        batch_size: int = 8,
+        num_train_epochs: PositiveInt = 3,
+        batch_size: PositiveInt = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
+        encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
+        num_virtual_tokens: PositiveInt = 10,
+        encoder_dropout: float = 0.1,
+        encoder_hidden_size: PositiveInt = 128,
+        encoder_num_layers: PositiveInt = 2,
         **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> None:
         super().__init__(
@@ -71,17 +77,30 @@ def __init__(
             seed=seed,
             report_to=report_to,
         )
-        self._ptuning_config = PromptEncoderConfig(task_type="SEQ_CLS", **ptuning_kwargs)
+        self._ptuning_config = PromptEncoderConfig(
+            task_type=TaskType.SEQ_CLS,
+            encoder_reparameterization_type=PromptEncoderReparameterizationType(encoder_reparameterization_type),
+            num_virtual_tokens=num_virtual_tokens,
+            encoder_dropout=encoder_dropout,
+            encoder_hidden_size=encoder_hidden_size,
+            encoder_num_layers=encoder_num_layers,
+            **ptuning_kwargs,
+        )

     @classmethod
-    def from_context(
+    def from_context(  # noqa: PLR0913
         cls,
         context: Context,
         classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 3,
-        batch_size: int = 8,
+        num_train_epochs: PositiveInt = 3,
+        batch_size: PositiveInt = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
+        encoder_reparameterization_type: Literal["MLP", "LSTM"] = "LSTM",
+        num_virtual_tokens: PositiveInt = 10,
+        encoder_dropout: float = 0.1,
+        encoder_hidden_size: PositiveInt = 128,
+        encoder_num_layers: PositiveInt = 2,
         **ptuning_kwargs: Any,  # noqa: ANN401
     ) -> "PTuningScorer":
         """Create a PTuningScorer instance using a Context object.
@@ -93,6 +112,11 @@ def from_context(
             batch_size: Batch size for training
             learning_rate: Learning rate for training
             seed: Random seed for reproducibility
+            encoder_reparameterization_type: Reparametrization type for the prompt encoder
+            num_virtual_tokens: Number of virtual tokens for the prompt encoder
+            encoder_dropout: Dropout for the prompt encoder
+            encoder_hidden_size: Hidden size for the prompt encoder
+            encoder_num_layers: Number of layers for the prompt encoder
             **ptuning_kwargs: Arguments for PromptEncoderConfig
         """
         if classification_model_config is None:
@@ -107,6 +131,11 @@ def from_context(
             learning_rate=learning_rate,
             seed=seed,
             report_to=report_to,
+            encoder_reparameterization_type=encoder_reparameterization_type,
+            num_virtual_tokens=num_virtual_tokens,
+            encoder_dropout=encoder_dropout,
+            encoder_hidden_size=encoder_hidden_size,
+            encoder_num_layers=encoder_num_layers,
             **ptuning_kwargs,
         )
```

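The new constructor arguments map one-to-one onto peft's PromptEncoderConfig, as the __init__ body above shows. A minimal sketch of the equivalent peft call using the defaults introduced here; the commented scorer construction is hypothetical, since the public import path for PTuningScorer is not shown in this diff:

```python
from peft import PromptEncoderConfig, PromptEncoderReparameterizationType, TaskType

# Equivalent of the config built inside PTuningScorer.__init__ with the new defaults.
ptuning_config = PromptEncoderConfig(
    task_type=TaskType.SEQ_CLS,
    encoder_reparameterization_type=PromptEncoderReparameterizationType("LSTM"),
    num_virtual_tokens=10,
    encoder_dropout=0.1,
    encoder_hidden_size=128,
    encoder_num_layers=2,
)

# Hypothetical usage of the scorer itself; the import path is assumed from the
# module location (autointent/modules/scoring/_ptuning/ptuning.py) and may differ.
# from autointent.modules.scoring import PTuningScorer
# scorer = PTuningScorer(
#     classification_model_config="prajjwal1/bert-tiny",
#     num_train_epochs=3,
#     batch_size=8,
#     encoder_reparameterization_type="MLP",
#     num_virtual_tokens=20,
# )
```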
docs/optimizer_config.schema.json

Lines changed: 56 additions & 0 deletions
```diff
@@ -253,6 +253,48 @@
            "title": "EmbedderConfig",
            "type": "object"
        },
+        "HFModelConfig": {
+            "additionalProperties": false,
+            "properties": {
+                "model_name": {
+                    "default": "prajjwal1/bert-tiny",
+                    "description": "Name of the hugging face repository with transformer model.",
+                    "title": "Model Name",
+                    "type": "string"
+                },
+                "batch_size": {
+                    "default": 32,
+                    "description": "Batch size for model inference.",
+                    "exclusiveMinimum": 0,
+                    "title": "Batch Size",
+                    "type": "integer"
+                },
+                "device": {
+                    "anyOf": [
+                        {
+                            "type": "string"
+                        },
+                        {
+                            "type": "null"
+                        }
+                    ],
+                    "default": null,
+                    "description": "Torch notation for CPU or CUDA.",
+                    "title": "Device"
+                },
+                "tokenizer_config": {
+                    "$ref": "#/$defs/TokenizerConfig"
+                },
+                "trust_remote_code": {
+                    "default": false,
+                    "description": "Whether to trust the remote code when loading the model.",
+                    "title": "Trust Remote Code",
+                    "type": "boolean"
+                }
+            },
+            "title": "HFModelConfig",
+            "type": "object"
+        },
        "LoggingConfig": {
            "additionalProperties": false,
            "description": "Configuration for the logging.",
@@ -442,6 +484,20 @@
                "train_head": false
            }
        },
+        "transformer_config": {
+            "$ref": "#/$defs/HFModelConfig",
+            "default": {
+                "model_name": "prajjwal1/bert-tiny",
+                "batch_size": 32,
+                "device": null,
+                "tokenizer_config": {
+                    "max_length": null,
+                    "padding": true,
+                    "truncation": true
+                },
+                "trust_remote_code": false
+            }
+        },
        "sampler": {
            "default": "brute",
            "enum": [
```

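Against the updated schema, an optimizer config can now include a transformer_config block whose shape matches the HFModelConfig definition above. A hedged sketch of just that fragment as a Python dict; the remaining sections of a real optimizer config are omitted and may be required:

```python
# Fragment of an optimizer config covering only the new block plus the sampler;
# everything else a complete config needs is intentionally left out here.
optimizer_config_fragment = {
    "transformer_config": {
        "model_name": "prajjwal1/bert-tiny",
        "batch_size": 32,
        "device": None,
        "tokenizer_config": {"max_length": None, "padding": True, "truncation": True},
        "trust_remote_code": False,
    },
    "sampler": "brute",
}

# Per the pipeline diff above, a complete config of this shape can be passed to
# Pipeline.from_optimization_config, which forwards transformer_config via set_config.
```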