11"""BertScorer class for transformer-based classification with LoRA."""
22
3- import tempfile
43from typing import Any
54
6- import numpy as np
7- import numpy .typing as npt
8- import torch
9- from datasets import Dataset
105from peft import LoraConfig , get_peft_model
11- from transformers import (
12- AutoModelForSequenceClassification ,
13- AutoTokenizer ,
14- DataCollatorWithPadding ,
15- Trainer ,
16- TrainingArguments ,
17- )
6+ from transformers import AutoModelForSequenceClassification
187
198from autointent import Context
209from autointent ._callbacks import REPORTERS_NAMES
@@ -39,7 +28,7 @@ def __init__(
3928 report_to : REPORTERS_NAMES | None = None , # type: ignore[no-any-return]
4029 ** lora_kwargs : dict [str , Any ],
4130 ) -> None :
42- super (BERTLoRAScorer , self ).__init__ (
31+ super ().__init__ (
4332 classification_model_config = classification_model_config ,
4433 num_train_epochs = num_train_epochs ,
4534 batch_size = batch_size ,
@@ -71,8 +60,8 @@ def from_context(
7160 report_to = context .logging_config .report_to ,
7261 ** lora_kwargs ,
7362 )
74-
75- def __initialize_model (self , ) :
63+
64+ def __initialize_model (self ) -> None :
7665 self ._model = AutoModelForSequenceClassification .from_pretrained (
7766 self .classification_model_config .model_name ,
7867 num_labels = self ._n_classes ,