Skip to content

Commit 2155c09

Browse files
committed
fix ruff
1 parent f60f167 commit 2155c09

File tree

2 files changed

+6
-17
lines changed

2 files changed

+6
-17
lines changed

autointent/modules/scoring/_bert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def from_context(
7171

7272
def get_embedder_config(self) -> dict[str, Any]:
7373
return self.classification_model_config.model_dump()
74-
75-
def __initialize_model(self):
74+
75+
def __initialize_model(self) -> None:
7676
label2id = {i: i for i in range(self._n_classes)}
7777
id2label = {i: i for i in range(self._n_classes)}
7878

autointent/modules/scoring/_lora/lora.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,9 @@
11
"""BertScorer class for transformer-based classification with LoRA."""
22

3-
import tempfile
43
from typing import Any
54

6-
import numpy as np
7-
import numpy.typing as npt
8-
import torch
9-
from datasets import Dataset
105
from peft import LoraConfig, get_peft_model
11-
from transformers import (
12-
AutoModelForSequenceClassification,
13-
AutoTokenizer,
14-
DataCollatorWithPadding,
15-
Trainer,
16-
TrainingArguments,
17-
)
6+
from transformers import AutoModelForSequenceClassification
187

198
from autointent import Context
209
from autointent._callbacks import REPORTERS_NAMES
@@ -39,7 +28,7 @@ def __init__(
3928
report_to: REPORTERS_NAMES | None = None, # type: ignore[no-any-return]
4029
**lora_kwargs: dict[str, Any],
4130
) -> None:
42-
super(BERTLoRAScorer, self).__init__(
31+
super().__init__(
4332
classification_model_config=classification_model_config,
4433
num_train_epochs=num_train_epochs,
4534
batch_size=batch_size,
@@ -71,8 +60,8 @@ def from_context(
7160
report_to=context.logging_config.report_to,
7261
**lora_kwargs,
7362
)
74-
75-
def __initialize_model(self, ):
63+
64+
def __initialize_model(self) -> None:
7665
self._model = AutoModelForSequenceClassification.from_pretrained(
7766
self.classification_model_config.model_name,
7867
num_labels=self._n_classes,

0 commit comments

Comments
 (0)