Commit 3aecbbc

fixed remarks
1 parent 0d5aaa5 commit 3aecbbc

File tree

  • autointent/modules/scoring/_lora

1 file changed: +17 −6 lines


autointent/modules/scoring/_lora/lora.py

Lines changed: 17 additions & 6 deletions
@@ -37,7 +37,7 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
         **lora_kwargs: Any,  # noqa: ANN401
     ) -> None:
         self.model_config = HFModelConfig.from_search_config(model_config)
@@ -53,7 +53,7 @@ def from_context(
         cls,
         context: Context,
         model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 10,
+        num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
@@ -67,7 +67,7 @@ def from_context(
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
-            report_to=context.logging_config.report_to
+            report_to=context.logging_config.report_to,
             **lora_kwargs,
         )
 
@@ -86,9 +86,16 @@ def fit(
 
         model_name = self.model_config.model_name
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=self._n_classes,
+            problem_type="multi_label_classification" if self._multilabel else "single_label_classification"
+        )
         self._model = get_peft_model(self._model, self._lora_config)
 
+        device = torch.device(self.model_config.device)
+        self._model = self._model.to(device)
+
         use_cpu = self.model_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
@@ -129,18 +136,22 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
+
+        device = torch.device(self.model_config.device)
+        self._model = self._model.to(device)
 
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
             inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
             logits = outputs.logits
             if self._multilabel:
-                batch_predictions = torch.sigmoid(logits).numpy()
+                batch_predictions = torch.sigmoid(logits).cpu().numpy()
             else:
-                batch_predictions = torch.softmax(logits, dim=1).numpy()
+                batch_predictions = torch.softmax(logits, dim=1).cpu().numpy()
             all_predictions.append(batch_predictions)
         return np.vstack(all_predictions) if all_predictions else np.array([])
 
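Note on the .cpu() calls added in predict(): PyTorch cannot convert a tensor that lives on an accelerator directly to NumPy, so outputs must be moved to the CPU before calling .numpy(). A minimal standalone sketch of that pattern (the tensors below are illustrative stand-ins, not code from this repository):

    import torch

    # Stand-in for model logits; on a CUDA machine this tensor lives on the GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logits = torch.randn(4, 3, device=device)

    # Calling .numpy() on a GPU tensor raises a TypeError, so move it to the CPU first.
    probs = torch.softmax(logits, dim=1).cpu().numpy()
    print(probs.shape)  # (4, 3)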