11"""CNNScorer class for scoring."""
22
3- from __future__ import annotations
4-
5- import re
63from collections import Counter
7- from typing import Any , Dict , List , Optional , Union
4+ import re
5+ from typing import Any
86
97import numpy as np
108import numpy .typing as npt
9+ from torch import nn
1110import torch
12- from torch import nn , Tensor
13- from torch .utils .data import DataLoader , TensorDataset
11+ from torch .utils .data import TensorDataset , DataLoader
1412
1513from autointent import Context
1614from autointent ._callbacks import REPORTERS_NAMES
@@ -33,8 +33,8 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        report_to: REPORTERS_NAMES | None = None, # type: ignore[no-any-return]
-        **cnn_kwargs: Dict[str, Any],
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
+        **cnn_kwargs: Any,
     ) -> None:
         self.max_seq_length = max_seq_length
         self.num_train_epochs = num_train_epochs
@@ -43,10 +43,10 @@ def __init__(
         self.seed = seed
         self.report_to = report_to
         self.cnn_config = cnn_kwargs
-
+
         # Will be initialized during fit()
-        self._model: Optional[TextCNN] = None
-        self._vocab: Optional[Dict[str, int]] = None
+        self._model: TextCNN | None = None
+        self._vocab: dict[str, int] | None = None
         self._padding_idx = 0
         self._unk_token = "<UNK>"  # noqa: S105
         self._pad_token = "<PAD>"  # noqa: S105
@@ -62,8 +62,8 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **cnn_kwargs: Dict[str, Any],
-    ) -> "CNNScorer":
+        **cnn_kwargs: Any,
+    ) -> CNNScorer:
         return cls(
             num_train_epochs=num_train_epochs,
             batch_size=batch_size,
@@ -73,23 +73,25 @@ def from_context(
             **cnn_kwargs,
         )
 
-    def fit(self, utterances: List[str], labels: ListOfLabels) -> None:
+    def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
         self._validate_task(labels)
-        self._multilabel = isinstance(labels[0], (list, np.ndarray))
+        self._multilabel = isinstance(labels[0], list | np.ndarray)
         self._n_classes = len(labels[0]) if self._multilabel else len(set(labels))
-
+
         # Build vocabulary and tokenize
         self._build_vocab(utterances)
-
+
         # Convert text to padded indices
         x = self._text_to_indices(utterances)
         x_tensor = torch.tensor(x, dtype=torch.long)
-        y_tensor = torch.tensor(labels, dtype=torch.long if not self._multilabel else torch.float)
-
+        y_tensor = torch.tensor(
+            labels, dtype=torch.long if not self._multilabel else torch.float
+        )
+
         # Initialize model
         if self._vocab is None:
-            raise RuntimeError("Vocabulary not built")
-
+            raise ValueError("Vocabulary not built")
+
         self._model = TextCNN(
             vocab_size=len(self._vocab),
             n_classes=self._n_classes,
@@ -98,70 +100,67 @@ def fit(self, utterances: List[str], labels: ListOfLabels) -> None:
             num_filters=self.cnn_config.get("num_filters", 100),
             dropout=self.cnn_config.get("dropout", 0.1),
             padding_idx=self._padding_idx,
-            pretrained_embs=self.cnn_config.get("pretrained_embs", None)
+            pretrained_embs=self.cnn_config.get("pretrained_embs", None),
         )
-
+
         # Training
         self._train_model(x_tensor, y_tensor)
 
-    def predict(self, utterances: List[str]) -> npt.NDArray[Any]:
+    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
         if self._model is None:
-            raise RuntimeError("Model not trained. Call fit() first.")
-
+            raise ValueError("Model not trained. Call fit() first.")
+
         x = self._text_to_indices(utterances)
         x_tensor = torch.tensor(x, dtype=torch.long)
-
+
         self._model.eval()
-        all_probs: List[npt.NDArray[Any]] = []
-
+        all_probs: list[npt.NDArray[Any]] = []
+
         with torch.no_grad():
             for i in range(0, len(x_tensor), self.batch_size):
-                batch_x = x_tensor[i: i + self.batch_size]
+                batch_x = x_tensor[i : i + self.batch_size]
                 outputs = self._model(batch_x)
                 if self._multilabel:
                     probs = torch.sigmoid(outputs).cpu().numpy()
                 else:
                     probs = torch.softmax(outputs, dim=1).cpu().numpy()
                 all_probs.append(probs)
-
+
         return np.concatenate(all_probs, axis=0) if all_probs else np.array([])
 
-    def _build_vocab(self, utterances: List[str]) -> None:
+    def _build_vocab(self, utterances: list[str]) -> None:
         """Build vocabulary from training utterances."""
-        word_counts: Dict[str, int] = Counter()
+        word_counts: Counter[str] = Counter()
         for utterance in utterances:
             words = re.findall(r"\w+", utterance.lower())
             word_counts.update(words)
-
+
         # Create vocabulary with special tokens
-        self._vocab = {
-            self._pad_token: 0,
-            self._unk_token: 1
-        }
-
+        self._vocab = {self._pad_token: 0, self._unk_token: 1}
+
         # Add words to vocabulary
         if self._vocab is None:
-            raise RuntimeError("Vocabulary not initialized")
-
+            raise ValueError("Vocabulary not initialized")
+
         for word, _ in word_counts.most_common():
             if word not in self._vocab:
                 self._vocab[word] = len(self._vocab)
-
+
         self._unk_idx = 1
         self._padding_idx = 0
 
-    def _text_to_indices(self, utterances: List[str]) -> List[List[int]]:
+    def _text_to_indices(self, utterances: list[str]) -> list[list[int]]:
         """Convert utterances to padded sequences of word indices."""
         if self._vocab is None:
-            raise RuntimeError("Vocabulary not built")
-
-        sequences: List[List[int]] = []
+            raise ValueError("Vocabulary not built")
+
+        sequences: list[list[int]] = []
         for utterance in utterances:
             words = re.findall(r"\w+", utterance.lower())
             # Convert words to indices, using UNK for unknown words
-            seq = [self._vocab.get(word, self._unk_idx) for word in words]  # type: ignore
+            seq = [self._vocab.get(word, self._unk_idx) for word in words]  # type: ignore[union-attr]
             # Truncate if too long
-            seq = seq[:self.max_seq_length]
+            seq = seq[: self.max_seq_length]
             # Pad if too short
             seq = seq + [self._padding_idx] * (self.max_seq_length - len(seq))
             sequences.append(seq)
@@ -171,20 +170,18 @@ def clear_cache(self) -> None:
         self._model = None
         torch.cuda.empty_cache()
 
-    def _train_model(self, x: Tensor, y: Tensor) -> None:
+    def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
         if self._model is None:
-            raise RuntimeError("Model not initialized")
-
+            raise ValueError("Model not initialized")
+
         dataset = TensorDataset(x, y)
-        dataloader = DataLoader(
-            dataset,
-            batch_size=self.batch_size,
-            shuffle=True
+        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
+
+        criterion = (
+            nn.CrossEntropyLoss() if not self._multilabel else nn.BCEWithLogitsLoss()
         )
-
-        criterion = nn.CrossEntropyLoss() if not self._multilabel else nn.BCEWithLogitsLoss()
         optimizer = torch.optim.Adam(self._model.parameters(), lr=self.learning_rate)
-
+
         self._model.train()
         for _ in range(self.num_train_epochs):
             for batch_x, batch_y in dataloader:
@@ -193,5 +190,6 @@ def _train_model(self, x: Tensor, y: Tensor) -> None:
                 loss = criterion(outputs, batch_y)
                 loss.backward()
                 optimizer.step()
-
+
         self._model.eval()
+
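Note for reviewers: `TextCNN` is constructed in `fit()` but defined elsewhere in the module, outside this diff. For orientation, here is a minimal Kim-style sketch consistent with the constructor arguments used above; the `embed_dim` and `filter_sizes` parameter names and all defaults are assumptions, not the actual autointent implementation:

```python
from __future__ import annotations

import torch
from torch import nn


class TextCNN(nn.Module):
    """Sketch of a Kim-style text CNN, not the actual autointent class."""

    def __init__(
        self,
        vocab_size: int,
        n_classes: int,
        embed_dim: int = 128,  # assumed name and default
        filter_sizes: tuple[int, ...] = (3, 4, 5),  # assumed name and default
        num_filters: int = 100,
        dropout: float = 0.1,
        padding_idx: int = 0,
        pretrained_embs: torch.Tensor | None = None,
    ) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        if pretrained_embs is not None:
            # Assumed to be a (vocab_size, embed_dim) float tensor
            self.embedding.weight.data.copy_(pretrained_embs)
        # One 1-D convolution per filter size, applied along the sequence axis
        self.convs = nn.ModuleList(
            nn.Conv1d(embed_dim, num_filters, kernel_size=k) for k in filter_sizes
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(filter_sizes), n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) -> (batch, embed_dim, seq_len), the layout Conv1d expects
        emb = self.embedding(x).transpose(1, 2)
        # Max-pool each feature map over time, then concatenate across filter sizes
        pooled = [conv(emb).relu().amax(dim=2) for conv in self.convs]
        # Raw logits; the scorer applies softmax/sigmoid itself in predict()
        return self.fc(self.dropout(torch.cat(pooled, dim=1)))
```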
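And a usage sketch of the scorer's public surface as it stands after this commit; the import path is hypothetical, and only parameters visible in `__init__` above are passed:

```python
from autointent.modules.scoring import CNNScorer  # hypothetical import path

scorer = CNNScorer(num_train_epochs=3, batch_size=8, learning_rate=5e-5, seed=0)

# Multiclass: plain integer labels. Multilabel would pass lists of 0/1 instead,
# which fit() detects via isinstance(labels[0], list | np.ndarray).
utterances = ["play some jazz", "what is the weather", "turn the volume up"]
labels = [0, 1, 2]

scorer.fit(utterances, labels)

# Softmax probabilities, shape (2, n_classes)
probs = scorer.predict(["could you play rock music", "weather for tomorrow"])

scorer.clear_cache()  # drops the trained model and frees CUDA memory
```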