)

from autointent import Context
+from autointent._callbacks import REPORTERS_NAMES
from autointent.configs import HFModelConfig
from autointent.custom_types import ListOfLabels
from autointent.modules.base import BaseScorer
@@ -26,7 +27,6 @@ class BERTLoRAScorer(BaseScorer):
    name = "lora"
    supports_multiclass = True
    supports_multilabel = True
-    _multilabel: bool
    _model: Any
    _tokenizer: Any

@@ -37,14 +37,15 @@ def __init__(
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        seed: int = 0,
-        **lora_kwargs: Any,
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore
+        **lora_kwargs: Any,  # noqa: ANN401
    ) -> None:
        self.model_config = HFModelConfig.from_search_config(model_config)
        self.num_train_epochs = num_train_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.seed = seed
-        self._multilabel = False
+        self.report_to = report_to
        self._lora_config = LoraConfig(**lora_kwargs)

    @classmethod
@@ -56,7 +57,7 @@ def from_context(
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        seed: int = 0,
-        **lora_kwargs: Any,
+        **lora_kwargs: Any,  # noqa: ANN401
    ) -> "BERTLoRAScorer":
        if model_config is None:
            model_config = context.resolve_embedder()
@@ -66,17 +67,13 @@ def from_context(
            batch_size=batch_size,
            learning_rate=learning_rate,
            seed=seed,
+            report_to=context.logging_config.report_to,
            **lora_kwargs,
        )

    def get_embedder_config(self) -> dict[str, Any]:
        return self.model_config.model_dump()

-    def _validate_task(self, labels: ListOfLabels) -> None:
-        """Validate the task and set _multilabel flag."""
-        super()._validate_task(labels)
-        self._multilabel = isinstance(labels[0], list)
-
    def fit(
        self,
        utterances: list[str],
@@ -87,20 +84,12 @@ def fit(

        self._validate_task(labels)

-        if self._multilabel:
-            labels_array = np.array(labels)
-            num_labels = labels_array.shape[1]
-        else:
-            num_labels = len(set(labels))
-
        model_name = self.model_config.model_name
        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
-
-        # Apply LoRA to the model
+        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
        self._model = get_peft_model(self._model, self._lora_config)

-        use_cpu = hasattr(self.model_config, "device") and self.model_config.device == "cpu"
+        use_cpu = self.model_config.device == "cpu"

        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
            return self._tokenizer(  # type: ignore[no-any-return]
@@ -120,7 +109,7 @@ def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
            save_strategy="no",
            logging_strategy="steps",
            logging_steps=10,
-            report_to="wandb",
+            report_to=self.report_to,
            use_cpu=use_cpu,
        )

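Note: the `report_to` value is forwarded directly to Hugging Face `TrainingArguments`, which accepts reporter names such as "wandb" or "tensorboard", a list of them, or "none" to disable all logging integrations. A minimal sketch outside this diff (the output directory name is made up):

from transformers import TrainingArguments

# "none" disables all logging integrations; a list restricts them to specific backends.
TrainingArguments(output_dir="out", report_to="none")
TrainingArguments(output_dir="out", report_to=["tensorboard"])
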
@@ -141,17 +130,19 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
            msg = "Model is not trained. Call fit() first."
            raise RuntimeError(msg)

-        inputs = self._tokenizer(
-            utterances, padding=True, truncation=True, max_length=self.model_config.tokenizer_config.max_length, return_tensors="pt"
-        )
-
-        with torch.no_grad():
-            outputs = self._model(**inputs)
-            logits = outputs.logits
-
-        if self._multilabel:
-            return torch.sigmoid(logits).numpy()
-        return torch.softmax(logits, dim=1).numpy()
+        all_predictions = []
+        for i in range(0, len(utterances), self.batch_size):
+            batch = utterances[i : i + self.batch_size]
+            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            with torch.no_grad():
+                outputs = self._model(**inputs)
+                logits = outputs.logits
+            if self._multilabel:
+                batch_predictions = torch.sigmoid(logits).numpy()
+            else:
+                batch_predictions = torch.softmax(logits, dim=1).numpy()
+            all_predictions.append(batch_predictions)
+        return np.vstack(all_predictions) if all_predictions else np.array([])

    def clear_cache(self) -> None:
        if hasattr(self, "_model"):
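For context, a minimal usage sketch of the scorer after this change; the import path, model name, label data, and LoRA kwargs below are assumptions, only the constructor arguments visible in the diff are taken from it:

from autointent.modules.scoring import BERTLoRAScorer  # import path is an assumption

scorer = BERTLoRAScorer(
    model_config="bert-base-uncased",  # hypothetical model; resolved via HFModelConfig.from_search_config
    num_train_epochs=1,
    batch_size=8,
    report_to=None,   # replaces the previously hard-coded "wandb"
    r=8,              # forwarded to peft.LoraConfig via **lora_kwargs
    lora_alpha=16,
)
scorer.fit(["book a flight", "play some jazz"], [0, 1])
probs = scorer.predict(["turn on the lights"])  # class probabilities, computed in batches of `batch_size`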