
Commit 54098c0

nikiduki, github-actions[bot], voorhs, Darinochka, Samoed authored
PTuningScorer (#178)
* Initial commit of PTuningScorer module
* Added peft (>=0.10.0, <0.15.0) in dependencies
* Implement fit/predict PTuningScorer
* Added PTuningScorer in __init__ file
* Update optimizer_config.schema.json
* Minor fixes
* PGH00
* Refactor clear_cache in fit method
* Refactor typing ignore + remove unnecessary
* Fix fit method status check
* Added test for PTuningScorer
* Fix mypy typing
* Update and fix peft version dependencies
* Fix mypy typing
* Added test in multiclass.yaml, multilabel.yaml
* Update docstrings
* Fix mypy typing
* Added trust_remote_code
* make proper rst reference
* Added test for dump load
* feat: added crossencoder (#181)
* feat: added crossencoder
* refactor
* feat: added arg similarity
* Update optimizer_config.schema.json
* feat: added tests
* feat: added errors
* fix: scoring test
* fix: description vectors error
* fix: description vectors error
* fix: lint
* fix: test
* add node validators (#177)
* add node validators
* add comments
* Update optimizer_config.schema.json
* rename bert model
* lint
* fixes
* fix test

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>

* fix: unit tests
* feat: added test for description
* feat: delete encoder_type from the class args
* feat: update assets
* feat: update assets
* fix: fixed test
* Update optimizer_config.schema.json

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: voorhs <[email protected]>

* Added fixed seed to test reproduction
* Pull LoraScorer and Bert Refactor
* Refactor PTuningScorer
* Refactor test for ptuning
* Fix typing
* Fix multilabel multiclass tests
* Fix typing

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
Co-authored-by: Darinochka <[email protected]>
Co-authored-by: Roman Solomatin <[email protected]>
1 parent 3f80b52 commit 54098c0

File tree

12 files changed: +283 / -15 lines changed


autointent/_callbacks/tensorboard.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ def __init__(self) -> None:
         Raises an ImportError if neither are installed.
         """
         try:
-            from torch.utils.tensorboard import SummaryWriter  # type: ignore[attr-defined]
+            from torch.utils.tensorboard import SummaryWriter

             self.writer = SummaryWriter
         except ImportError:

autointent/modules/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@
     KNNScorer,
     LinearScorer,
     MLKnnScorer,
+    PTuningScorer,
     RerankScorer,
     SklearnScorer,
 )
@@ -47,7 +48,8 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
         SklearnScorer,
         MLKnnScorer,
         BertScorer,
-        BERTLoRAScorer
+        BERTLoRAScorer,
+        PTuningScorer,
     ]
 )

autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 from ._linear import LinearScorer
 from ._lora import BERTLoRAScorer
 from ._mlknn import MLKnnScorer
+from ._ptuning import PTuningScorer
 from ._sklearn import SklearnScorer

 __all__ = [
@@ -15,6 +16,7 @@
     "KNNScorer",
     "LinearScorer",
     "MLKnnScorer",
+    "PTuningScorer",
     "RerankScorer",
     "SklearnScorer",
 ]

autointent/modules/scoring/_bert.py

Lines changed: 0 additions & 1 deletion
@@ -85,7 +85,6 @@ def __initialize_model(self) -> None:
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
         )

-
     def fit(
         self,
         utterances: list[str],

autointent/modules/scoring/_lora/lora.py

Lines changed: 3 additions & 3 deletions
@@ -81,8 +81,8 @@ def __init__(
             learning_rate=learning_rate,
             seed=seed,
             report_to=report_to,
-            )
-            self._lora_config = LoraConfig(**lora_kwargs)  # type: ignore[arg-type]
+        )
+        self._lora_config = LoraConfig(**lora_kwargs)  # type: ignore[arg-type]

     @classmethod
     def from_context(
@@ -113,5 +113,5 @@ def __initialize_model(self) -> None:
             num_labels=self._n_classes,
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
             trust_remote_code=self.classification_model_config.trust_remote_code,
-            )
+        )
         self._model = get_peft_model(self._model, self._lora_config)
autointent/modules/scoring/_ptuning/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .ptuning import PTuningScorer
+
+__all__ = ["PTuningScorer"]
autointent/modules/scoring/_ptuning/ptuning.py

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
+"""PTuningScorer class for ptuning-based classification."""
+
+from typing import Any
+
+import torch
+from peft import PromptEncoderConfig, get_peft_model
+from transformers import (
+    AutoModelForSequenceClassification,
+)
+
+from autointent import Context
+from autointent._callbacks import REPORTERS_NAMES
+from autointent.configs import HFModelConfig
+from autointent.modules.scoring._bert import BertScorer
+
+
+class PTuningScorer(BertScorer):
+    """PEFT P-tuning scorer.
+
+    Args:
+        classification_model_config: Config of the base transformer model (HFModelConfig, str, or dict)
+        num_train_epochs: Number of training epochs
+        batch_size: Batch size for training
+        learning_rate: Learning rate for training
+        seed: Random seed for reproducibility
+        report_to: Reporting tool for training logs
+        **ptuning_kwargs: Arguments for `PromptEncoderConfig <https://huggingface.co/docs/peft/package_reference/p_tuning#peft.PromptEncoderConfig>`_
+
+    Example:
+    --------
+    .. testcode::
+
+        from autointent.modules import PTuningScorer
+        scorer = PTuningScorer(
+            classification_model_config="prajjwal1/bert-tiny",
+            num_train_epochs=3,
+            batch_size=8,
+            task_type="SEQ_CLS",
+            num_virtual_tokens=10,
+            seed=42
+        )
+        utterances = ["hello", "goodbye", "allo", "sayonara"]
+        labels = [0, 1, 0, 1]
+        scorer.fit(utterances, labels)
+        test_utterances = ["hi", "bye"]
+        probabilities = scorer.predict(test_utterances)
+        print(probabilities)
+
+    .. testoutput::
+
+        [[0.49925193 0.50074804]
+         [0.4944601  0.5055399 ]]
+
+    """
+
+    name = "ptuning"
+    supports_multiclass = True
+    supports_multilabel = True
+    _model: Any
+    _tokenizer: Any
+
+    def __init__(
+        self,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        num_train_epochs: int = 3,
+        batch_size: int = 8,
+        learning_rate: float = 5e-5,
+        seed: int = 0,
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
+        **ptuning_kwargs: dict[str, Any],
+    ) -> None:
+        super().__init__(
+            classification_model_config=classification_model_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            report_to=report_to,
+        )
+        self._ptuning_config = PromptEncoderConfig(**ptuning_kwargs)  # type: ignore[arg-type]
+        torch.manual_seed(seed)
+
+    @classmethod
+    def from_context(
+        cls,
+        context: Context,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
+        num_train_epochs: int = 3,
+        batch_size: int = 8,
+        learning_rate: float = 5e-5,
+        seed: int = 0,
+        **ptuning_kwargs: dict[str, Any],
+    ) -> "PTuningScorer":
+        """Create a PTuningScorer instance using a Context object.
+
+        Args:
+            context: Context containing configurations and utilities
+            classification_model_config: Config of the base model, or None to use the best embedder
+            num_train_epochs: Number of training epochs
+            batch_size: Batch size for training
+            learning_rate: Learning rate for training
+            seed: Random seed for reproducibility
+            **ptuning_kwargs: Arguments for PromptEncoderConfig
+        """
+        if classification_model_config is None:
+            classification_model_config = context.resolve_embedder()
+
+        report_to = context.logging_config.report_to
+
+        return cls(
+            classification_model_config=classification_model_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            report_to=report_to,
+            **ptuning_kwargs,
+        )
+
+    def _initialize_model(self) -> None:
+        """Initialize the model with P-tuning configuration."""
+        model_name = self.classification_model_config.model_name
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=self._n_classes,
+            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
+            trust_remote_code=self.classification_model_config.trust_remote_code,
+            return_dict=True,
+        )
+        self._model = get_peft_model(self._model, self._ptuning_config)
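
For context on the new _initialize_model: PEFT's get_peft_model wraps the base classifier so that only the prompt encoder's virtual-token parameters (plus the classification head) remain trainable, while the backbone stays frozen. A minimal standalone sketch of that wrapping, using only the public peft/transformers APIs already shown in this diff (this is not code from the commit):

# Standalone sketch, not part of this commit: how p-tuning wraps a classifier.
from peft import PromptEncoderConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

base = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny",  # same tiny model used in the docstring example
    num_labels=2,
    return_dict=True,
)
config = PromptEncoderConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
model = get_peft_model(base, config)
model.print_trainable_parameters()  # reports the small trainable fraction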

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ dependencies = [
     "xxhash (>=3.5.0,<4.0.0)",
     "python-dotenv (>=1.0.1,<2.0.0)",
     "transformers[torch] (>=4.49.0,<5.0.0)",
-    "peft (>= 0.10.0, <1.0.0)",
+    "peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)",
     "codecarbon (==2.6)",
 ]

tests/assets/configs/multiclass.yaml

Lines changed: 6 additions & 0 deletions
@@ -43,6 +43,12 @@
         learning_rate: [5.0e-5]
         seed: [0]
         lora_alpha: [16]
+      - module_name: ptuning
+        classification_model_config: ["prajjwal1/bert-tiny"]
+        num_train_epochs: [1]
+        batch_size: [8, 16]
+        task_type: ["SEQ_CLS"]
+        num_virtual_tokens: [10, 20]
   - node_type: decision
     target_metric: decision_accuracy
     search_space:
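
Each list value in the ptuning entry above is a hyperparameter grid, so this block describes 2 × 2 = 4 trial configurations (batch_size × num_virtual_tokens; the single-valued fields contribute one option each). A hypothetical sketch of that expansion, not autointent's actual optimizer code:

# Hypothetical illustration of search-space grid expansion.
from itertools import product

grid = {"batch_size": [8, 16], "num_virtual_tokens": [10, 20]}
trials = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(trials))  # 4
print(trials[0])    # {'batch_size': 8, 'num_virtual_tokens': 10}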

tests/assets/configs/multilabel.yaml

Lines changed: 6 additions & 0 deletions
@@ -31,6 +31,12 @@
         batch_size: [8]
         learning_rate: [5.0e-5]
         seed: [0]
+      - module_name: ptuning
+        classification_model_config: ["prajjwal1/bert-tiny"]
+        num_train_epochs: [1]
+        batch_size: [8]
+        task_type: ["SEQ_CLS"]
+        num_virtual_tokens: [10, 20]
   - module_name: lora
     classification_model_config:
       - model_name: avsolatorio/GIST-small-Embedding-v0
