11"""CNNScorer class for scoring."""
22
3- from collections import Counter
3+ from __future__ import annotations
4+
45import re
5- from typing import Any
6+ from collections import Counter
7+ from typing import Any , Dict , List , Optional , Union
68
79import numpy as np
810import numpy .typing as npt
9- from torch import nn
1011import torch
11- from torch .utils .data import TensorDataset , DataLoader
12+ from torch import nn , Tensor
13+ from torch .utils .data import DataLoader , TensorDataset
1214
1315from autointent import Context
1416from autointent ._callbacks import REPORTERS_NAMES
@@ -21,8 +23,6 @@ class CNNScorer(BaseScorer):
2123 """Convolutional Neural Network (CNN) scorer for intent classification."""
2224
2325 name = "cnn"
24- _n_classes : int
25- _multilabel : bool
2626 supports_multilabel = True
2727 supports_multiclass = True
2828
@@ -33,8 +33,8 @@ def __init__(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        report_to: REPORTERS_NAMES | None = None,
-        **cnn_kwargs: dict[str, Any],
+        report_to: REPORTERS_NAMES | None = None,  # type: ignore[no-any-return]
+        **cnn_kwargs: Dict[str, Any],
     ) -> None:
         self.max_seq_length = max_seq_length
         self.num_train_epochs = num_train_epochs
@@ -45,11 +45,14 @@ def __init__(
         self.cnn_config = cnn_kwargs

         # Will be initialized during fit()
-        self._model = None
-        self._vocab = None
+        self._model: Optional[TextCNN] = None
+        self._vocab: Optional[Dict[str, int]] = None
         self._padding_idx = 0
         self._unk_token = "<UNK>"  # noqa: S105
         self._pad_token = "<PAD>"  # noqa: S105
+        self._unk_idx = 1
+        self._n_classes: int = 0
+        self._multilabel: bool = False

     @classmethod
     def from_context(
@@ -59,7 +62,7 @@ def from_context(
         batch_size: int = 8,
         learning_rate: float = 5e-5,
         seed: int = 0,
-        **cnn_kwargs: dict[str, Any],
+        **cnn_kwargs: Dict[str, Any],
     ) -> "CNNScorer":
         return cls(
             num_train_epochs=num_train_epochs,
@@ -70,22 +73,23 @@ def from_context(
             **cnn_kwargs,
         )

-    def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None:
-        if clear_cache:
-            self.clear_cache()
-
+    def fit(self, utterances: List[str], labels: ListOfLabels) -> None:
         self._validate_task(labels)
-        self._multilabel = isinstance(labels[0], list | np.ndarray)
+        self._multilabel = isinstance(labels[0], (list, np.ndarray))
+        self._n_classes = len(labels[0]) if self._multilabel else len(set(labels))

         # Build vocabulary and tokenize
         self._build_vocab(utterances)

         # Convert text to padded indices
         x = self._text_to_indices(utterances)
-        x = torch.tensor(x, dtype=torch.long)
-        y = torch.tensor(labels, dtype=torch.long)
+        x_tensor = torch.tensor(x, dtype=torch.long)
+        y_tensor = torch.tensor(labels, dtype=torch.long if not self._multilabel else torch.float)

         # Initialize model
+        if self._vocab is None:
+            raise RuntimeError("Vocabulary not built")
+
         self._model = TextCNN(
             vocab_size=len(self._vocab),
             n_classes=self._n_classes,
@@ -98,22 +102,21 @@ def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None:
         )

         # Training
-        self._train_model(x, y)
+        self._train_model(x_tensor, y_tensor)

-    def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
+    def predict(self, utterances: List[str]) -> npt.NDArray[Any]:
         if self._model is None:
-            error_msg = "Model not trained. Call fit() first."
-            raise RuntimeError(error_msg)
+            raise RuntimeError("Model not trained. Call fit() first.")

         x = self._text_to_indices(utterances)
-        x = torch.tensor(x, dtype=torch.long)
+        x_tensor = torch.tensor(x, dtype=torch.long)

         self._model.eval()
-        all_probs = []
+        all_probs: List[npt.NDArray[Any]] = []

         with torch.no_grad():
-            for i in range(0, len(x), self.batch_size):
-                batch_x = x[i:i + self.batch_size]
+            for i in range(0, len(x_tensor), self.batch_size):
+                batch_x = x_tensor[i:i + self.batch_size]
                 outputs = self._model(batch_x)
                 if self._multilabel:
                     probs = torch.sigmoid(outputs).cpu().numpy()
@@ -123,9 +126,9 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:

         return np.concatenate(all_probs, axis=0) if all_probs else np.array([])

-    def _build_vocab(self, utterances: list[str]) -> None:
+    def _build_vocab(self, utterances: List[str]) -> None:
         """Build vocabulary from training utterances."""
-        word_counts = Counter()
+        word_counts: Dict[str, int] = Counter()
         for utterance in utterances:
             words = re.findall(r"\w+", utterance.lower())
             word_counts.update(words)
@@ -137,20 +140,26 @@ def _build_vocab(self, utterances: list[str]) -> None:
         }

         # Add words to vocabulary
+        if self._vocab is None:
+            raise RuntimeError("Vocabulary not initialized")
+
         for word, _ in word_counts.most_common():
             if word not in self._vocab:
                 self._vocab[word] = len(self._vocab)

         self._unk_idx = 1
         self._padding_idx = 0

-    def _text_to_indices(self, utterances: list[str]) -> list[list[int]]:
+    def _text_to_indices(self, utterances: List[str]) -> List[List[int]]:
         """Convert utterances to padded sequences of word indices."""
-        sequences = []
+        if self._vocab is None:
+            raise RuntimeError("Vocabulary not built")
+
+        sequences: List[List[int]] = []
         for utterance in utterances:
             words = re.findall(r"\w+", utterance.lower())
             # Convert words to indices, using UNK for unknown words
-            seq = [self._vocab.get(word, self._unk_idx) for word in words]
+            seq = [self._vocab.get(word, self._unk_idx) for word in words]  # type: ignore
             # Truncate if too long
             seq = seq[:self.max_seq_length]
             # Pad if too short
@@ -162,7 +171,10 @@ def clear_cache(self) -> None:
         self._model = None
         torch.cuda.empty_cache()

-    def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
+    def _train_model(self, x: Tensor, y: Tensor) -> None:
+        if self._model is None:
+            raise RuntimeError("Model not initialized")
+
         dataset = TensorDataset(x, y)
         dataloader = DataLoader(
             dataset,
@@ -182,4 +194,4 @@ def _train_model(self, x: torch.Tensor, y: torch.Tensor) -> None:
                 loss.backward()
                 optimizer.step()

-        self._model.eval()
+        self._model.eval()
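
For quick reference, a minimal usage sketch of the scorer after this change; the import path, the toy utterances, and the labels below are assumptions, not part of the commit:

# Hypothetical smoke test of the updated fit/predict flow (import path assumed).
from autointent.modules.scoring import CNNScorer

scorer = CNNScorer(max_seq_length=32, num_train_epochs=2, batch_size=4)
scorer.fit(
    ["turn on the lights", "play some jazz", "switch the lights off"],
    [0, 1, 0],  # multiclass labels; multilabel would pass one-hot lists instead
)
probs = scorer.predict(["dim the lights"])  # class probabilities, shape (1, n_classes)
scorer.clear_cache()  # drops the trained model and frees CUDA memory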