11"""CNNScorer class for scoring."""
22
3- from collections import Counter
43import re
4+ from collections import Counter
55from typing import Any
66
77import numpy as np
88import numpy .typing as npt
9- from torch import nn
109import torch
11- from torch .utils .data import TensorDataset , DataLoader
10+ from torch import nn , Tensor
11+ from torch .utils .data import DataLoader , TensorDataset
1212
1313from autointent import Context
1414from autointent ._callbacks import REPORTERS_NAMES
@@ -61,7 +61,7 @@ def from_context(
6161 learning_rate : float = 5e-5 ,
6262 seed : int = 0 ,
6363 ** cnn_kwargs : dict [str , Any ],
64- ) -> CNNScorer :
64+ ) -> " CNNScorer" :
6565 return cls (
6666 num_train_epochs = num_train_epochs ,
6767 batch_size = batch_size ,
@@ -88,7 +88,8 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
8888
8989 # Initialize model
9090 if self ._vocab is None :
91- raise ValueError ("Vocabulary not built" )
91+ msg = "Vocabulary not built"
92+ raise ValueError (msg )
9293
9394 self ._model = TextCNN (
9495 vocab_size = len (self ._vocab ),
@@ -106,7 +107,8 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
106107
107108 def predict (self , utterances : list [str ]) -> npt .NDArray [Any ]:
108109 if self ._model is None :
109- raise ValueError ("Model not trained. Call fit() first." )
110+ msg = "Model not trained. Call fit() first."
111+ raise ValueError (msg )
110112
111113 x = self ._text_to_indices (utterances )
112114 x_tensor = torch .tensor (x , dtype = torch .long )
@@ -138,7 +140,8 @@ def _build_vocab(self, utterances: list[str]) -> None:
138140
139141 # Add words to vocabulary
140142 if self ._vocab is None :
141- raise ValueError ("Vocabulary not initialized" )
143+ msg = "Vocabulary not initialized"
144+ raise ValueError (msg )
142145
143146 for word , _ in word_counts .most_common ():
144147 if word not in self ._vocab :
@@ -150,7 +153,8 @@ def _build_vocab(self, utterances: list[str]) -> None:
150153 def _text_to_indices (self , utterances : list [str ]) -> list [list [int ]]:
151154 """Convert utterances to padded sequences of word indices."""
152155 if self ._vocab is None :
153- raise ValueError ("Vocabulary not built" )
156+ msg = "Vocabulary not built"
157+ raise ValueError (msg )
154158
155159 sequences : list [list [int ]] = []
156160 for utterance in utterances :
@@ -170,7 +174,8 @@ def clear_cache(self) -> None:
170174
171175 def _train_model (self , x : torch .Tensor , y : torch .Tensor ) -> None :
172176 if self ._model is None :
173- raise ValueError ("Model not initialized" )
177+ msg = "Model not initialized"
178+ raise ValueError (msg )
174179
175180 dataset = TensorDataset (x , y )
176181 dataloader = DataLoader (dataset , batch_size = self .batch_size , shuffle = True )
@@ -190,4 +195,3 @@ def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
190195 optimizer .step ()
191196
192197 self ._model .eval ()
193-
0 commit comments