Commit 3f80b52

riapush, github-actions[bot], and Samoed authored

Lora scorer (#170)

* added lora scorer
* fix ruff
* Update __init__.py
* updated after mr #165
* Update pyproject.toml
* fixed requested changes
* fixed ruff failing
* fixed remarks
* Update optimizer_config.schema.json
* added test
* ruff fix
* convert labels to float
* Update autointent/modules/scoring/_lora/lora.py
  Co-authored-by: Roman Solomatin <[email protected]>
* Update autointent/modules/scoring/_lora/lora.py
  Co-authored-by: Roman Solomatin <[email protected]>
* change model_config name, added trust_remote_code
* Update lora.py
* inherited lora from bert
* fix ruff
* fix search space
* Update lora.py
* Update lora.py
* added dump check
* Update test_lora.py
* Update test_lora.py
* added docstring
* fix ruff
* Update test_lora.py
* Update test_lora.py

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Roman Solomatin <[email protected]>
1 parent bdca370 commit 3f80b52

File tree

10 files changed: +282 -13 lines

.gitignore
autointent/modules/__init__.py
autointent/modules/scoring/__init__.py
autointent/modules/scoring/_bert.py
autointent/modules/scoring/_lora/__init__.py
autointent/modules/scoring/_lora/lora.py
pyproject.toml
tests/assets/configs/multiclass.yaml
tests/assets/configs/multilabel.yaml
tests/modules/scoring/test_lora.py

.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -179,3 +179,4 @@ tests_logs
 tests/logs
 runs/
 vector_db*
+/wandb
```

autointent/modules/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -13,6 +13,7 @@
 from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
 from .regex import SimpleRegex
 from .scoring import (
+    BERTLoRAScorer,
     BertScorer,
     DescriptionScorer,
     DNNCScorer,
@@ -46,6 +47,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
         SklearnScorer,
         MLKnnScorer,
         BertScorer,
+        BERTLoRAScorer
     ]
 )
 
```
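
The second hunk adds the scorer to the list passed to `_create_modules_dict`, whose signature appears in the hunk header. A plausible sketch of what such a helper does, assuming it keys each class by its `name` attribute (`"lora"` for the new scorer); the actual implementation is not part of this diff:

```python
from typing import TypeVar

T = TypeVar("T")


def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
    # Hypothetical: map each module's `name` class attribute to the class,
    # so YAML search-space entries can reference modules by string name.
    return {module.name: module for module in modules}  # type: ignore[attr-defined]
```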

autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -3,10 +3,12 @@
 from ._dnnc import DNNCScorer
 from ._knn import KNNScorer, RerankScorer
 from ._linear import LinearScorer
+from ._lora import BERTLoRAScorer
 from ._mlknn import MLKnnScorer
 from ._sklearn import SklearnScorer
 
 __all__ = [
+    "BERTLoRAScorer",
     "BertScorer",
     "DNNCScorer",
     "DescriptionScorer",
```

autointent/modules/scoring/_bert.py

Lines changed: 16 additions & 13 deletions

```diff
@@ -72,30 +72,33 @@ def from_context(
     def get_embedder_config(self) -> dict[str, Any]:
         return self.classification_model_config.model_dump()
 
-    def fit(
-        self,
-        utterances: list[str],
-        labels: ListOfLabels,
-    ) -> None:
-        if hasattr(self, "_model"):
-            self.clear_cache()
-        self._validate_task(labels)
-
-        model_name = self.classification_model_config.model_name
-        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-
+    def __initialize_model(self) -> None:
         label2id = {i: i for i in range(self._n_classes)}
         id2label = {i: i for i in range(self._n_classes)}
 
         self._model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
+            self.classification_model_config.model_name,
             trust_remote_code=self.classification_model_config.trust_remote_code,
             num_labels=self._n_classes,
             label2id=label2id,
             id2label=id2label,
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
         )
 
+
+    def fit(
+        self,
+        utterances: list[str],
+        labels: ListOfLabels,
+    ) -> None:
+        if hasattr(self, "_model"):
+            self.clear_cache()
+        self._validate_task(labels)
+
+        self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name)
+
+        self.__initialize_model()
+
         use_cpu = self.classification_model_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
```
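
The refactor above extracts model construction from `fit` into an `__initialize_model` hook, with the intent (per the commit history: "inherited lora from bert") that subclasses replace only the model setup. One Python subtlety applies here: double-underscore method names are mangled per class, so `BertScorer.fit` calls `_BertScorer__initialize_model`, which a subclass method also spelled `__initialize_model` does not override. A minimal, self-contained sketch of that behavior (not autointent code):

```python
class Base:
    def run(self) -> str:
        # Inside Base, this call is compiled as self._Base__hook().
        return self.__hook()

    def __hook(self) -> str:
        return "base"


class Child(Base):
    # Mangled to _Child__hook, so it does NOT override Base.__hook.
    def __hook(self) -> str:
        return "child"


print(Child().run())  # prints "base"
```

A single-underscore name such as `_initialize_model` is the conventional choice when overriding from a subclass is intended.
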
autointent/modules/scoring/_lora/__init__.py

Lines changed: 3 additions & 0 deletions (new file)

```python
from .lora import BERTLoRAScorer

__all__ = ["BERTLoRAScorer"]
```

autointent/modules/scoring/_lora/lora.py

Lines changed: 117 additions & 0 deletions (new file)

```python
"""BertScorer class for transformer-based classification with LoRA."""

from typing import Any

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

from autointent import Context
from autointent._callbacks import REPORTERS_NAMES
from autointent.configs import HFModelConfig
from autointent.modules.scoring._bert import BertScorer


class BERTLoRAScorer(BertScorer):
    """BERTLoRAScorer class for transformer-based classification with LoRA (Low-Rank Adaptation).

    Args:
        classification_model_config: Config of the base transformer model (HFModelConfig, str, or dict)
        num_train_epochs: Number of training epochs (default: 3)
        batch_size: Batch size for training (default: 8)
        learning_rate: Learning rate for training (default: 5e-5)
        seed: Random seed for reproducibility (default: 0)
        report_to: Reporting tool for training logs
        **lora_kwargs: Arguments for `LoraConfig <https://huggingface.co/docs/peft/package_reference/lora#peft.LoraConfig>`_

    Example:
    --------
    .. testcode::

        from autointent.modules import BERTLoRAScorer

        # Initialize scorer with LoRA configuration
        scorer = BERTLoRAScorer(
            classification_model_config="bert-base-uncased",
            num_train_epochs=3,
            batch_size=8,
            learning_rate=5e-5,
            seed=42,
            r=8,  # LoRA rank
            lora_alpha=16,  # LoRA alpha
        )

        # Training data
        utterances = ["This is great!", "I didn't like it", "Awesome product", "Poor quality"]
        labels = [1, 0, 1, 0]  # Binary classification

        # Fit the model
        scorer.fit(utterances, labels)

        # Make predictions
        test_utterances = ["Good product", "Not worth it"]
        probabilities = scorer.predict(test_utterances)
        print(probabilities)

    .. testoutput::

        [[0.89 0.11]
         [0.23 0.77]]
    """

    name = "lora"
    supports_multiclass = True
    supports_multilabel = True
    _model: Any
    _tokenizer: Any

    def __init__(
        self,
        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
        num_train_epochs: int = 3,
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        seed: int = 0,
        report_to: REPORTERS_NAMES | None = None,  # type: ignore[valid-type]
        **lora_kwargs: dict[str, Any],
    ) -> None:
        super().__init__(
            classification_model_config=classification_model_config,
            num_train_epochs=num_train_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            seed=seed,
            report_to=report_to,
        )
        self._lora_config = LoraConfig(**lora_kwargs)  # type: ignore[arg-type]

    @classmethod
    def from_context(
        cls,
        context: Context,
        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
        num_train_epochs: int = 3,
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        seed: int = 0,
        **lora_kwargs: dict[str, Any],
    ) -> "BERTLoRAScorer":
        if classification_model_config is None:
            classification_model_config = context.resolve_embedder()
        return cls(
            classification_model_config=classification_model_config,
            num_train_epochs=num_train_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            seed=seed,
            report_to=context.logging_config.report_to,
            **lora_kwargs,
        )

    def __initialize_model(self) -> None:
        self._model = AutoModelForSequenceClassification.from_pretrained(
            self.classification_model_config.model_name,
            num_labels=self._n_classes,
            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
            trust_remote_code=self.classification_model_config.trust_remote_code,
        )
        self._model = get_peft_model(self._model, self._lora_config)
```
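
Everything collected in `**lora_kwargs` is forwarded to peft's `LoraConfig`, and `get_peft_model` wraps the base classifier so that, typically, only the low-rank adapter matrices (and the classification head) receive gradients. A standalone sketch of that wrapping outside autointent; the model name and hyperparameters here are illustrative, not the module's defaults:

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

# Illustrative base model and LoRA hyperparameters.
base = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny", num_labels=2
)
config = LoraConfig(
    task_type=TaskType.SEQ_CLS,  # sequence classification
    r=8,               # rank of the low-rank update matrices
    lora_alpha=16,     # scaling applied to the adapter output
    lora_dropout=0.1,  # dropout on the adapter path
)
model = get_peft_model(base, config)

# Prints a small trainable fraction, since the frozen base weights
# far outnumber the adapter and classifier parameters.
model.print_trainable_parameters()
```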

pyproject.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -45,6 +45,7 @@ dependencies = [
     "xxhash (>=3.5.0,<4.0.0)",
     "python-dotenv (>=1.0.1,<2.0.0)",
     "transformers[torch] (>=4.49.0,<5.0.0)",
+    "peft (>= 0.10.0, <1.0.0)",
     "codecarbon (==2.6)",
 ]
 
```

tests/assets/configs/multiclass.yaml

Lines changed: 8 additions & 0 deletions

```diff
@@ -35,6 +35,14 @@
       batch_size: [8, 16]
       learning_rate: [5.0e-5]
       seed: [0]
+    - module_name: lora
+      classification_model_config:
+        - model_name: avsolatorio/GIST-small-Embedding-v0
+      num_train_epochs: [1]
+      batch_size: [8, 16]
+      learning_rate: [5.0e-5]
+      seed: [0]
+      lora_alpha: [16]
 - node_type: decision
   target_metric: decision_accuracy
   search_space:
```
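
Each bracketed list enumerates candidate values for one hyperparameter, so this `lora` entry contributes two trial configurations (differing only in `batch_size`). A hedged sketch of how such an entry could expand into a grid, assuming a plain Cartesian product; the actual optimizer may enumerate or sample differently:

```python
from itertools import product

# The lora entry from this config, minus the model config.
search_space = {
    "num_train_epochs": [1],
    "batch_size": [8, 16],
    "learning_rate": [5.0e-5],
    "seed": [0],
    "lora_alpha": [16],
}

# Cartesian product over the per-parameter candidate lists.
grid = [dict(zip(search_space, values)) for values in product(*search_space.values())]
print(len(grid))  # 2
```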

tests/assets/configs/multilabel.yaml

Lines changed: 8 additions & 0 deletions

```diff
@@ -31,6 +31,14 @@
       batch_size: [8]
       learning_rate: [5.0e-5]
       seed: [0]
+    - module_name: lora
+      classification_model_config:
+        - model_name: avsolatorio/GIST-small-Embedding-v0
+      num_train_epochs: [1]
+      batch_size: [8]
+      learning_rate: [5.0e-5]
+      seed: [0]
+      lora_alpha: [16]
 - node_type: decision
   target_metric: decision_accuracy
   search_space:
```
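
The multilabel variant exercises the `problem_type="multi_label_classification"` branch of `__initialize_model`, which makes the Hugging Face model train with a per-class binary cross-entropy loss (this is also why the commit history mentions converting labels to float: `BCEWithLogitsLoss` requires float targets). A minimal sketch of the difference at scoring time, with made-up logits:

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])

# Single-label: softmax, probabilities sum to 1 across classes.
print(torch.softmax(logits, dim=-1))

# Multi-label: independent sigmoid per class; rows need not sum to 1,
# matching the test below that only checks row sums for multiclass.
print(torch.sigmoid(logits))
```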

tests/modules/scoring/test_lora.py

Lines changed: 124 additions & 0 deletions (new file)

```python
import shutil
import tempfile
from pathlib import Path

import numpy as np
import pytest

from autointent.context.data_handler import DataHandler
from autointent.modules import BERTLoRAScorer


def test_lora_scorer_dump_load(dataset):
    """Test that BERTLoRAScorer can be saved and loaded while preserving predictions."""
    data_handler = DataHandler(dataset)

    # Create and train scorer
    scorer_original = BERTLoRAScorer(
        classification_model_config="prajjwal1/bert-tiny",
        num_train_epochs=1,
        batch_size=8
    )
    scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

    # Test data
    test_data = [
        "why is there a hold on my account",
        "why is my bank account frozen",
    ]

    # Get predictions before saving
    predictions_before = scorer_original.predict(test_data)

    # Create temp directory and save model
    temp_dir_path = Path(tempfile.mkdtemp(prefix="lora_scorer_test_"))
    try:
        # Save the model
        scorer_original.dump(str(temp_dir_path))

        # Create a new scorer and load saved model
        scorer_loaded = BERTLoRAScorer(
            classification_model_config="prajjwal1/bert-tiny",
            num_train_epochs=1,
            batch_size=8
        )
        scorer_loaded.load(str(temp_dir_path))

        # Verify model and tokenizer are loaded
        assert hasattr(scorer_loaded, "_model")
        assert scorer_loaded._model is not None
        assert hasattr(scorer_loaded, "_tokenizer")
        assert scorer_loaded._tokenizer is not None

        # Get predictions after loading
        predictions_after = scorer_loaded.predict(test_data)

        # Verify predictions match
        assert predictions_before.shape == predictions_after.shape
        np.testing.assert_allclose(predictions_before, predictions_after, atol=1e-6)

    finally:
        # Clean up
        shutil.rmtree(temp_dir_path, ignore_errors=True)  # workaround for windows permission error


def test_lora_prediction(dataset):
    """Test that the lora model can fit and make predictions."""
    data_handler = DataHandler(dataset)

    scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

    scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

    test_data = [
        "why is there a hold on my american saving bank account",
        "i am not sure why my account is blocked",
        "why is there a hold on my capital one checking account",
        "i think my account is blocked but i do not know the reason",
        "can you tell me why is my bank account frozen",
    ]

    predictions = scorer.predict(test_data)

    # Verify prediction shape
    assert predictions.shape[0] == len(test_data)
    assert predictions.shape[1] == len(set(data_handler.train_labels(0)))

    # Verify predictions are probabilities
    assert 0.0 <= np.min(predictions) <= np.max(predictions) <= 1.0

    # Verify probabilities sum to 1 for multiclass
    if not scorer._multilabel:
        for pred_row in predictions:
            np.testing.assert_almost_equal(np.sum(pred_row), 1.0, decimal=5)

    # Test metadata function if available
    if hasattr(scorer, "predict_with_metadata"):
        predictions, metadata = scorer.predict_with_metadata(test_data)
        assert len(predictions) == len(test_data)
        assert metadata is None


def test_lora_cache_clearing(dataset):
    """Test that the lora model properly handles cache clearing."""
    data_handler = DataHandler(dataset)

    scorer = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)

    scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

    test_data = ["test text"]

    # Should work before clearing cache
    scorer.predict(test_data)

    # Clear the cache
    scorer.clear_cache()

    # Verify model and tokenizer are removed
    assert not hasattr(scorer, "_model") or scorer._model is None
    assert not hasattr(scorer, "_tokenizer") or scorer._tokenizer is None

    # Should raise exception after clearing cache
    with pytest.raises(RuntimeError):
        scorer.predict(test_data)
```
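
These tests depend on a `dataset` fixture presumably provided by the repository's conftest. To run just this module, the usual pytest selection applies; a sketch via the Python API:

```python
import pytest

# Equivalent to `pytest tests/modules/scoring/test_lora.py -v` on the CLI.
pytest.main(["tests/modules/scoring/test_lora.py", "-v"])
```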
