@@ -37,7 +37,7 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        report_to: REPORTERS_NAMES | None = None,  # type: ignore # noqa: PGH003
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
         **lora_kwargs: Any,  # noqa: ANN401
     ) -> None:
         self.model_config = HFModelConfig.from_search_config(model_config)
@@ -53,7 +53,7 @@ def from_context(
         cls,
         context: Context,
         model_config: HFModelConfig | str | dict[str, Any] | None = None,
-        num_train_epochs: int = 10,
+        num_train_epochs: int = 3,
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
@@ -67,7 +67,7 @@ def from_context(
             batch_size=batch_size,
             learning_rate=learning_rate,
             seed=seed,
-            report_to=context.logging_config.report_to
+            report_to=context.logging_config.report_to,
             **lora_kwargs,
         )

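The constructor arguments forwarded here (`num_train_epochs`, `batch_size`, `learning_rate`, `seed`, `report_to`) mirror options of the same name in `transformers.TrainingArguments`. A minimal sketch of such a setup, under the assumption that the trainer is configured this way inside `fit()` (values and output path are illustrative, not from this repo):

```python
from transformers import TrainingArguments

# Illustrative mapping of the constructor arguments onto TrainingArguments;
# the exact field names used by this class may differ.
training_args = TrainingArguments(
    output_dir="outputs",             # placeholder path
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=5e-5,
    seed=0,
    report_to="none",                 # or e.g. "tensorboard" / "wandb"
    use_cpu=True,                     # keeps the sketch runnable without a GPU
)
```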
@@ -86,9 +86,16 @@ def fit(
 
         model_name = self.model_config.model_name
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self._model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=self._n_classes)
+        self._model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=self._n_classes,
+            problem_type="multi_label_classification" if self._multilabel else "single_label_classification"
+        )
         self._model = get_peft_model(self._model, self._lora_config)
 
+        device = torch.device(self.model_config.device)
+        self._model = self._model.to(device)
+
         use_cpu = self.model_config.device == "cpu"
 
         def tokenize_function(examples: dict[str, Any]) -> dict[str, Any]:
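For readers less familiar with the transformers/peft APIs used above, a self-contained sketch of the same model-setup pattern (checkpoint name, label count, and LoRA hyperparameters are illustrative, not taken from this project). `problem_type` is what selects the loss inside `AutoModelForSequenceClassification`: BCE-with-logits for multi-label, cross-entropy for single-label.

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

model_name = "bert-base-uncased"   # illustrative checkpoint
n_classes, multilabel = 4, True    # illustrative label setup

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=n_classes,
    # "multi_label_classification" -> BCEWithLogitsLoss, "single_label_classification" -> CrossEntropyLoss
    problem_type="multi_label_classification" if multilabel else "single_label_classification",
)

# Wrap the backbone with LoRA adapters; these hyperparameters are placeholders.
lora_config = LoraConfig(task_type="SEQ_CLS", r=8, lora_alpha=16, lora_dropout=0.1)
model = get_peft_model(model, lora_config)

# Move the wrapped model to the target device before training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
```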
@@ -129,18 +136,22 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         if not hasattr(self, "_model") or not hasattr(self, "_tokenizer"):
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
+
+        device = torch.device(self.model_config.device)
+        self._model = self._model.to(device)
 
         all_predictions = []
         for i in range(0, len(utterances), self.batch_size):
             batch = utterances[i : i + self.batch_size]
             inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            inputs = {k: v.to(device) for k, v in inputs.items()}
             with torch.no_grad():
                 outputs = self._model(**inputs)
             logits = outputs.logits
             if self._multilabel:
-                batch_predictions = torch.sigmoid(logits).numpy()
+                batch_predictions = torch.sigmoid(logits).cpu().numpy()
             else:
-                batch_predictions = torch.softmax(logits, dim=1).numpy()
+                batch_predictions = torch.softmax(logits, dim=1).cpu().numpy()
             all_predictions.append(batch_predictions)
         return np.vstack(all_predictions) if all_predictions else np.array([])

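A standalone sketch of the same device-consistent batched inference pattern (checkpoint and tokenizer settings are illustrative): inputs are moved to the model's device, and logits are brought back to the CPU before `.numpy()`, since calling `.numpy()` on a CUDA tensor raises an error.

```python
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "bert-base-uncased"   # illustrative checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(device)
model.eval()

utterances = ["book a table for two", "cancel my reservation"]
batch_size = 8

all_predictions = []
for i in range(0, len(utterances), batch_size):
    batch = utterances[i : i + batch_size]
    inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # keep inputs on the model's device
    with torch.no_grad():
        logits = model(**inputs).logits
    # .cpu() is required before .numpy() when the model runs on GPU
    all_predictions.append(torch.softmax(logits, dim=1).cpu().numpy())

predictions = np.vstack(all_predictions) if all_predictions else np.array([])
```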