Commit 9bec389

committed: fixed requested changes
1 parent d61f743 commit 9bec389

2 files changed: +23 -32 lines changed

autointent/modules/scoring/_lora/lora.py

Lines changed: 22 additions & 31 deletions
@@ -17,6 +17,7 @@
 )

 from autointent import Context
+from autointent._callbacks import REPORTERS_NAMES
 from autointent.configs import HFModelConfig
 from autointent.custom_types import ListOfLabels
 from autointent.modules.base import BaseScorer
@@ -26,7 +27,6 @@ class BERTLoRAScorer(BaseScorer):
     name = "lora"
     supports_multiclass = True
     supports_multilabel = True
-    _multilabel: bool
     _model: Any
     _tokenizer: Any

@@ -37,14 +37,15 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **lora_kwargs: Any,
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore
+        **lora_kwargs: Any,  # noqa: ANN401
     ) -> None:
         self.model_config = HFModelConfig.from_search_config(model_config)
         self.num_train_epochs = num_train_epochs
         self.batch_size = batch_size
         self.learning_rate = learning_rate
         self.seed = seed
-        self._multilabel = False
+        self.report_to = report_to
         self._lora_config = LoraConfig(**lora_kwargs)

     @classmethod
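
The new report_to parameter lets callers choose the experiment tracker instead of the previously hardcoded "wandb". A minimal usage sketch, assuming the class is importable from autointent.modules.scoring, that "wandb" is a valid REPORTERS_NAMES value, and that a plain model name is accepted by HFModelConfig.from_search_config; r and lora_alpha are forwarded to LoraConfig via **lora_kwargs:

from autointent.modules.scoring import BERTLoRAScorer  # import path assumed

scorer = BERTLoRAScorer(
    model_config="bert-base-uncased",   # assumed to be accepted as a search config
    report_to="wandb",                  # assumed member of REPORTERS_NAMES
    r=8,                                # LoRA rank, passed through to LoraConfig
    lora_alpha=16,                      # LoRA scaling, passed through to LoraConfig
)
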
@@ -56,7 +57,7 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **lora_kwargs: Any,
+        **lora_kwargs: Any,  # noqa: ANN401
     ) -> "BERTLoRAScorer":
         if model_config is None:
             model_config = context.resolve_embedder()
@@ -66,17 +67,13 @@ def from_context(
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
+            report_to=context.logging_config.report_to,
             **lora_kwargs,
         )

     def get_embedder_config(self) -> dict[str, Any]:
         return self.model_config.model_dump()

-    def _validate_task(self, labels: ListOfLabels) -> None:
-        """Validate the task and set _multilabel flag."""
-        super()._validate_task(labels)
-        self._multilabel = isinstance(labels[0], list)
-
     def fit(
         self,
         utterances: list[str],
@@ -87,20 +84,12 @@ def fit(

         self._validate_task(labels)

-        if self._multilabel:
-            labels_array = np.array(labels)
-            num_labels = labels_array.shape[1]
-        else:
-            num_labels = len(set(labels))
-
         model_name = self.model_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
-
-        # Apply LoRA to the model
+        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
         self._model = get_peft_model(self._model, self._lora_config)

-        use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu"
+        use_cpu = self.model_config.device == "cpu"

         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             return self._tokenizer(  # type: ignore[no-any-return]
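
This fit() hunk drops the local num_labels computation and relies on self._n_classes (while predict() still uses self._multilabel), both presumably populated by the base class's _validate_task. A hedged illustration of the values these attributes would need to take for the two supported label layouts (hypothetical, not shown in this diff):

# Hypothetical label layouts and the attributes the base class would derive from them.
multiclass_labels = [0, 2, 1, 2]             # -> _n_classes == 3, _multilabel is False
multilabel_labels = [[1, 0, 1], [0, 1, 0]]   # -> _n_classes == 3, _multilabel is True
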
@@ -120,7 +109,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
             save_strategy="no",
             logging_strategy="steps",
             logging_steps=10,
-            report_to="wandb",
+            report_to=self.report_to,
             use_cpu=use_cpu,
         )

@@ -141,17 +130,19 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)

-        inputs = self._tokenizer(
-            utterances, padding=True, truncation=True, max_length=self.model_config.tokenizer_config.max_length, return_tensors="pt"
-        )
-
-        with torch.no_grad():
-            outputs = self._model(**inputs)
-            logits = outputs.logits
-
-        if self._multilabel:
-            return torch.sigmoid(logits).numpy()
-        return torch.softmax(logits, dim=1).numpy()
+        all_predictions = []
+        for i in range(0, len(utterances), self.batch_size):
+            batch = utterances[i : i + self.batch_size]
+            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            with torch.no_grad():
+                outputs = self._model(**inputs)
+                logits = outputs.logits
+            if self._multilabel:
+                batch_predictions = torch.sigmoid(logits).numpy()
+            else:
+                batch_predictions = torch.softmax(logits, dim=1).numpy()
+            all_predictions.append(batch_predictions)
+        return np.vstack(all_predictions) if all_predictions else np.array([])

     def clear_cache(self) -> None:
         if hasattr(self, "_model"):
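
The rewritten predict() streams utterances through the model in self.batch_size chunks instead of tokenizing everything at once, which bounds peak memory for long input lists. A standalone sketch of the same pattern, with a plain Hugging Face tokenizer and model standing in for the scorer's internals; the model name and num_labels are illustrative, not part of autointent's API:

import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def batched_scores(
    texts: list[str],
    model_name: str = "bert-base-uncased",  # illustrative checkpoint
    batch_size: int = 8,
    multilabel: bool = False,
) -> np.ndarray:
    """Score texts in fixed-size batches; sigmoid for multilabel, softmax otherwise."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)
    model.eval()

    chunks = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.sigmoid(logits) if multilabel else torch.softmax(logits, dim=1)
        chunks.append(probs.numpy())

    # Mirror the diff's empty-input handling: return an empty array rather than
    # letting np.vstack fail on an empty list.
    return np.vstack(chunks) if chunks else np.array([])
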

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ dependencies = [
     "xxhash (>=3.5.0,<4.0.0)",
     "python-dotenv (>=1.0.1,<2.0.0)",
     "transformers[torch] (>=4.49.0,<5.0.0)",
-    "peft (==0.10.0)",
+    "peft (>= 0.10.0, <1.0.0)",
 ]

 [project.urls]
