
Commit f60f167

inherited lora from bert
1 parent 5bf5de0 commit f60f167

File tree: 2 files changed, 35 additions and 118 deletions

autointent/modules/scoring/_bert.py (16 additions, 13 deletions)

@@ -71,6 +71,20 @@ def from_context(
 
     def get_embedder_config(self) -> dict[str, Any]:
         return self.classification_model_config.model_dump()
+
+    def __initialize_model(self):
+        label2id = {i: i for i in range(self._n_classes)}
+        id2label = {i: i for i in range(self._n_classes)}
+
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            self.classification_model_config.model_name,
+            trust_remote_code=self.classification_model_config.trust_remote_code,
+            num_labels=self._n_classes,
+            label2id=label2id,
+            id2label=id2label,
+            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
+        )
 
     def fit(
         self,
@@ -81,20 +95,9 @@ def fit(
             self.clear_cache()
         self._validate_task(labels)
 
-        model_name = self.classification_model_config.model_name
-        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.classification_model_config.model_name)
 
-        label2id = {i: i for i in range(self._n_classes)}
-        id2label = {i: i for i in range(self._n_classes)}
-
-        self._model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
-            trust_remote_code=self.classification_model_config.trust_remote_code,
-            num_labels=self._n_classes,
-            label2id=label2id,
-            id2label=id2label,
-            problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
-        )
+        self.__initialize_model()
 
         use_cpu = self.classification_model_config.device == "cpu"
 
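For context, the change above moves model construction out of fit() and behind an initialization hook, so a subclass can swap in its own model setup while reusing the rest of the training logic. The sketch below illustrates that hook pattern with toy classes; it is not the repository's code, and it uses a single-underscore hook name (_initialize_model) because Python name-mangles double-underscore attributes per class, so an identically named __initialize_model in a subclass would not override the parent's version.

# Minimal sketch of the "initialization hook" pattern, with toy classes
# standing in for BertScorer/BERTLoRAScorer (names below are illustrative).
class BaseScorer:
    def __init__(self, model_name: str, n_classes: int) -> None:
        self.model_name = model_name
        self.n_classes = n_classes

    def _initialize_model(self) -> None:
        # Base behaviour: pretend to build a plain classification model.
        self.model = f"SequenceClassifier({self.model_name}, labels={self.n_classes})"

    def fit(self, utterances: list[str], labels: list[int]) -> None:
        # fit() stays in the base class; only model construction is overridable.
        self._initialize_model()
        print(f"training {self.model} on {len(utterances)} examples")


class LoRAScorer(BaseScorer):
    def _initialize_model(self) -> None:
        # Subclass reuses the base construction, then wraps the model.
        super()._initialize_model()
        self.model = f"PeftModel({self.model})"


LoRAScorer("bert-base-uncased", n_classes=3).fit(["hi", "bye"], [0, 1])
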
autointent/modules/scoring/_lora/lora.py (19 additions, 105 deletions)

@@ -19,11 +19,10 @@
 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
 from autointent.configs import HFModelConfig
-from autointent.custom_types import ListOfLabels
-from autointent.modules.base import BaseScorer
+from autointent.modules.scoring._bert import BertScorer
 
 
-class BERTLoRAScorer(BaseScorer):
+class BERTLoRAScorer(BertScorer):
     name = "lora"
     supports_multiclass = True
     supports_multilabel = True
@@ -32,137 +31,52 @@ class BERTLoRAScorer(BaseScorer):
 
     def __init__(
         self,
-        transformer_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
         **lora_kwargs: dict[str, Any],
     ) -> None:
-        self.transformer_config = HFModelConfig.from_search_config(transformer_config)
-        self.num_train_epochs = num_train_epochs
-        self.batch_size = batch_size
-        self.learning_rate = learning_rate
-        self.seed = seed
-        self.report_to = report_to
+        super(BERTLoRAScorer, self).__init__(
+            classification_model_config=classification_model_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            report_to=report_to,  # type: ignore[no-any-return]
+        )
         self._lora_config = LoraConfig(**lora_kwargs)
 
     @classmethod
     def from_context(
         cls,
         context: Context,
-        transformer_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         **lora_kwargs: dict[str, Any],
     ) -> "BERTLoRAScorer":
-        if transformer_config is None:
-            transformer_config = context.resolve_embedder()
+        if classification_model_config is None:
+            classification_model_config = context.resolve_embedder()
         return cls(
-            transformer_config=transformer_config,
+            classification_model_config=classification_model_config,
            num_train_epochs=num_train_epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            seed=seed,
            report_to=context.logging_config.report_to,
            **lora_kwargs,
        )
-
-    def get_embedder_config(self) -> dict[str, Any]:
-        return self.transformer_config.model_dump()
-
-    def fit(
-        self,
-        utterances: list[str],
-        labels: ListOfLabels,
-    ) -> None:
-        if hasattr(self, "_model"):
-            self.clear_cache()
-
-        self._validate_task(labels)
-
-        model_name = self.transformer_config.model_name
-        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    def __initialize_model(self, ):
         self._model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
+            self.classification_model_config.model_name,
             num_labels=self._n_classes,
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
-            trust_remote_code=self.transformer_config.trust_remote_code,
+            trust_remote_code=self.classification_model_config.trust_remote_code,
         )
         self._model = get_peft_model(self._model, self._lora_config)
-
-        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
-        self._model = self._model.to(device)
-
-        use_cpu = self.transformer_config.device == "cpu"
-
-        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
-            return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump()
-            )
-
-        dataset = Dataset.from_dict({"text": utterances, "labels": labels})
-        if self._multilabel:
-            dataset = dataset.map(
-                lambda example: {"label": torch.tensor(example["labels"], dtype=torch.float)}, remove_columns=["labels"]
-            )
-            dataset = dataset.rename_column("label", "labels")
-        tokenized_dataset = dataset.map(tokenize_function, batched=True)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = TrainingArguments(
-                output_dir=tmp_dir,
-                num_train_epochs=self.num_train_epochs,
-                per_device_train_batch_size=self.batch_size,
-                learning_rate=self.learning_rate,
-                seed=self.seed,
-                save_strategy="no",
-                logging_strategy="steps",
-                logging_steps=10,
-                report_to=self.report_to,
-                use_cpu=use_cpu,
-            )
-
-            trainer = Trainer(
-                model=self._model,
-                args=training_args,
-                train_dataset=tokenized_dataset,
-                tokenizer=self._tokenizer,
-                data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
-            )
-
-            trainer.train()
-
-        self._model.eval()
-
-    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
-        if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
-            msg = "Model is not trained. Call fit() first."
-            raise RuntimeError(msg)
-
-        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
-        self._model = self._model.to(device)
-
-        all_predictions = []
-        for i in range(0, len(utterances), self.batch_size):
-            batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump())
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-            with torch.no_grad():
-                outputs = self._model(**inputs)
-                logits = outputs.logits
-            if self._multilabel:
-                batch_predictions = torch.sigmoid(logits).cpu().numpy()
-            else:
-                batch_predictions = torch.softmax(logits, dim=1).cpu().numpy()
-            all_predictions.append(batch_predictions)
-        return np.vstack(all_predictions) if all_predictions else np.array([])
-
-    def clear_cache(self) -> None:
-        if hasattr(self, "_model"):
-            del self._model
-        if hasattr(self, "_tokenizer"):
-            del self._tokenizer
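After this change, the only LoRA-specific work left in the module is wrapping the freshly loaded classifier with a PEFT adapter before the inherited training loop runs. The standalone sketch below shows that wrapping step on its own, outside autointent's classes; the checkpoint name and LoRA hyperparameters are illustrative assumptions, not values taken from this repository.

# Standalone sketch: load a sequence classifier and wrap it with a LoRA adapter,
# the step BERTLoRAScorer adds on top of the inherited Bert logic.
# The checkpoint and hyperparameters are illustrative, not the repo's defaults.
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny",          # small checkpoint, chosen only to keep the example light
    num_labels=3,
    problem_type="single_label_classification",
)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,     # sequence-classification task, keeps the head trainable
    r=8,                            # rank of the low-rank update matrices
    lora_alpha=16,                  # scaling factor applied to the update
    lora_dropout=0.1,
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts

Training then proceeds with the usual transformers Trainer: only the adapter weights and the classification head receive gradients, which is what keeps LoRA fine-tuning cheap compared with updating the full encoder.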
