 from autointent import Context
 from autointent._callbacks import REPORTERS_NAMES
 from autointent.configs import HFModelConfig
-from autointent.custom_types import ListOfLabels
-from autointent.modules.base import BaseScorer
+from autointent.modules.scoring._bert import BertScorer


-class BERTLoRAScorer(BaseScorer):
+class BERTLoRAScorer(BertScorer):
     name = "lora"
     supports_multiclass = True
     supports_multilabel = True
@@ -32,137 +31,52 @@ class BERTLoRAScorer(BaseScorer):
 
     def __init__(
         self,
-        transformer_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
         **lora_kwargs: dict[str, Any],
     ) -> None:
-        self.transformer_config = HFModelConfig.from_search_config(transformer_config)
-        self.num_train_epochs = num_train_epochs
-        self.batch_size = batch_size
-        self.learning_rate = learning_rate
-        self.seed = seed
-        self.report_to = report_to
+        super().__init__(
+            classification_model_config=classification_model_config,
+            num_train_epochs=num_train_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            seed=seed,
+            report_to=report_to,  # type: ignore[no-any-return]
+        )
         self._lora_config = LoraConfig(**lora_kwargs)
 
     @classmethod
     def from_context(
         cls,
         context: Context,
-        transformer_config: HFModelConfig | str | dict[str, Any] | None = None,
+        classification_model_config: HFModelConfig | str | dict[str, Any] | None = None,
         num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
         **lora_kwargs: dict[str, Any],
     ) -> "BERTLoRAScorer":
-        if transformer_config is None:
-            transformer_config = context.resolve_embedder()
+        if classification_model_config is None:
+            classification_model_config = context.resolve_embedder()
         return cls(
-            transformer_config=transformer_config,
+            classification_model_config=classification_model_config,
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
             report_to=context.logging_config.report_to,
             **lora_kwargs,
         )
-
-    def get_embedder_config(self) -> dict[str, Any]:
-        return self.transformer_config.model_dump()
-
-    def fit(
-        self,
-        utterances: list[str],
-        labels: ListOfLabels,
-    ) -> None:
-        if hasattr(self, "_model"):
-            self.clear_cache()
-
-        self._validate_task(labels)
-
-        model_name = self.transformer_config.model_name
-        self._tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    def __initialize_model(self) -> None:
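+        # Build the HF sequence-classification model and apply the configured
+        # LoRA adapter (PEFT) on top of it.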
         self._model = AutoModelForSequenceClassification.from_pretrained(
-            model_name,
+            self.classification_model_config.model_name,
             num_labels=self._n_classes,
             problem_type="multi_label_classification" if self._multilabel else "single_label_classification",
-            trust_remote_code=self.transformer_config.trust_remote_code,
+            trust_remote_code=self.classification_model_config.trust_remote_code,
         )
         self._model = get_peft_model(self._model, self._lora_config)
-
-        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
-        self._model = self._model.to(device)
-
-        use_cpu = self.transformer_config.device == "cpu"
-
-        def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
-            return self._tokenizer(  # type: ignore[no-any-return]
-                examples["text"], return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump()
-            )
-
-        dataset = Dataset.from_dict({"text": utterances, "labels": labels})
-        if self._multilabel:
-            dataset = dataset.map(
-                lambda example: {"label": torch.tensor(example["labels"], dtype=torch.float)}, remove_columns=["labels"]
-            )
-            dataset = dataset.rename_column("label", "labels")
-        tokenized_dataset = dataset.map(tokenize_function, batched=True)
-
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            training_args = TrainingArguments(
-                output_dir=tmp_dir,
-                num_train_epochs=self.num_train_epochs,
-                per_device_train_batch_size=self.batch_size,
-                learning_rate=self.learning_rate,
-                seed=self.seed,
-                save_strategy="no",
-                logging_strategy="steps",
-                logging_steps=10,
-                report_to=self.report_to,
-                use_cpu=use_cpu,
-            )
-
-            trainer = Trainer(
-                model=self._model,
-                args=training_args,
-                train_dataset=tokenized_dataset,
-                tokenizer=self._tokenizer,
-                data_collator=DataCollatorWithPadding(tokenizer=self._tokenizer),
-            )
-
-            trainer.train()
-
-        self._model.eval()
-
-    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
-        if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
-            msg = "Model is not trained. Call fit() first."
-            raise RuntimeError(msg)
-
-        device = torch.device(self.transformer_config.device if self.transformer_config.device else "cpu")
-        self._model = self._model.to(device)
-
-        all_predictions = []
-        for i in range(0, len(utterances), self.batch_size):
-            batch = utterances[i : i + self.batch_size]
-            inputs = self._tokenizer(batch, return_tensors="pt", **self.transformer_config.tokenizer_config.model_dump())
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-            with torch.no_grad():
-                outputs = self._model(**inputs)
-            logits = outputs.logits
-            if self._multilabel:
-                batch_predictions = torch.sigmoid(logits).cpu().numpy()
-            else:
-                batch_predictions = torch.softmax(logits, dim=1).cpu().numpy()
-            all_predictions.append(batch_predictions)
-        return np.vstack(all_predictions) if all_predictions else np.array([])
-
-    def clear_cache(self) -> None:
-        if hasattr(self, "_model"):
-            del self._model
-        if hasattr(self, "_tokenizer"):
-            del self._tokenizer
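
For reference, a minimal usage sketch of the refactored scorer. This is a sketch only: it assumes the fit()/predict() interface removed from this file is now inherited from BertScorer, that **lora_kwargs are forwarded to peft.LoraConfig (so options such as r and lora_alpha are accepted), and the model name below is just an illustrative placeholder.

scorer = BERTLoRAScorer(
    classification_model_config="prajjwal1/bert-tiny",  # illustrative model name, not prescribed by this change
    num_train_epochs=1,
    batch_size=8,
    r=8,            # LoRA rank, forwarded via **lora_kwargs to LoraConfig
    lora_alpha=16,  # LoRA scaling factor, forwarded via **lora_kwargs to LoraConfig
)
scorer.fit(["book a table for two", "play some jazz"], [0, 1])   # assumed to be provided by the BertScorer base class
probabilities = scorer.predict(["reserve a table for tonight"])  # assumed to be provided by the BertScorer base class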