1- """CNNScorer class for scoring."""
2-
3- import re
4- from collections import Counter
51from typing import Any
62
73import numpy as np
84import numpy .typing as npt
95import torch
106from torch import nn
11- from torch .utils . data import DataLoader , TensorDataset
7+ from torch .optim import Adam
128
139from autointent import Context
1410from autointent ._callbacks import REPORTERS_NAMES
1915
2016
2117class CNNScorer (BaseScorer ):
22- """Convolutional Neural Network ( CNN) scorer for intent classification."""
18+ """Scorer based on CNN model for text classification."""
2319
2420 name = "cnn"
25- supports_multilabel = True
2621 supports_multiclass = True
22+ supports_multilabel = True
2723
2824 def __init__ (
2925 self ,
30- num_train_epochs : int = 3 ,
31- learning_rate : float = 5e-5 ,
32- seed : int = 0 ,
33- report_to : REPORTERS_NAMES | None = None , # type: ignore[valid-type]
3426 embed_dim : int = 128 ,
35- kernel_sizes : list [int ] = [3 , 4 , 5 ], # noqa: B006
27+ kernel_sizes : list [int ] = [3 , 4 , 5 ],
3628 num_filters : int = 100 ,
3729 dropout : float = 0.1 ,
38- batch_size : int = 8 ,
3930 cnn_config : CNNConfig | str | dict [str , Any ] | None = None ,
31+ num_train_epochs : int = 3 ,
32+ batch_size : int = 8 ,
33+ learning_rate : float = 5e-5 ,
34+ seed : int = 0 ,
35+ report_to : REPORTERS_NAMES | None = None , # type: ignore # noqa: PGH003
4036 ) -> None :
41- self .num_train_epochs = num_train_epochs
42- self .learning_rate = learning_rate
43- self .seed = seed
44- self .report_to = report_to
37+ """Initialize the CNN scorer."""
4538 self .embed_dim = embed_dim
4639 self .kernel_sizes = kernel_sizes
4740 self .num_filters = num_filters
4841 self .dropout = dropout
4942 self .cnn_config = CNNConfig .from_search_config (cnn_config )
50-
51- # Will be initialized during fit()
52- self ._model : TextCNN | None = None
53- self ._vocab : dict [str , int ] | None = None
54- self ._unk_token = "<UNK>" # noqa: S105
55- self ._pad_token = "<PAD>" # noqa: S105
56- self ._n_classes : int = 0
57- self ._multilabel : bool = False
58- self ._pad_idx = self .cnn_config .padding_idx
59- self ._unk_idx = self .cnn_config .unknown_idx
60- self .batch_size = batch_size
61- self .max_seq_length = self .cnn_config .max_seq_length
43+ self .num_train_epochs = num_train_epochs
44+ self .batch_size = batch_size or self .cnn_config .batch_size
45+ self .learning_rate = learning_rate
46+ self .seed = seed
47+ self .report_to = report_to
48+ self ._artifact = None
49+ self ._device = self .cnn_config .device or ("cuda" if torch .cuda .is_available () else "cpu" )
6250
6351 @classmethod
6452 def from_context (
6553 cls ,
6654 context : Context ,
55+ embed_dim : int = 128 ,
56+ kernel_sizes : list [int ] = [3 , 4 , 5 ],
57+ num_filters : int = 100 ,
58+ dropout : float = 0.1 ,
59+ cnn_config : CNNConfig | str | dict [str , Any ] | None = None ,
6760 num_train_epochs : int = 3 ,
6861 batch_size : int = 8 ,
6962 learning_rate : float = 5e-5 ,
7063 seed : int = 0 ,
71- embed_dim : int = 128 ,
72- kernel_sizes : list [int ] = [3 , 4 , 5 ], # noqa: B006
73- num_filters : int = 100 ,
74- dropout : float = 0.1 ,
75- cnn_config : CNNConfig | str | dict [str , Any ] | None = None
7664 ) -> "CNNScorer" :
65+ """Create a CNNScorer from context."""
66+ report_to = context .logging_config .report_to
67+
7768 return cls (
78- num_train_epochs = num_train_epochs ,
79- batch_size = batch_size ,
80- learning_rate = learning_rate ,
81- seed = seed ,
82- report_to = context .logging_config .report_to ,
8369 embed_dim = embed_dim ,
8470 kernel_sizes = kernel_sizes ,
8571 num_filters = num_filters ,
8672 dropout = dropout ,
87- cnn_config = cnn_config
88- )
89-
90- def get_implicit_initialization_params (self ) -> dict [str , Any ]:
91- return {"cnn_config" : self .cnn_config .model_dump ()}
92-
93- def fit (self , utterances : list [str ], labels : ListOfLabels ) -> None :
94- self ._validate_task (labels )
95- self ._multilabel = isinstance (labels [0 ], (list , np .ndarray )) # noqa: UP038
96-
97- # Build vocabulary and tokenize
98- self ._build_vocab (utterances )
99-
100- # Convert text to padded indices
101- x = self ._text_to_indices (utterances )
102- x_tensor = torch .tensor (x , dtype = torch .long )
103- y_tensor = torch .tensor (
104- labels , dtype = torch .long if not self ._multilabel else torch .float
73+ cnn_config = cnn_config ,
74+ num_train_epochs = num_train_epochs ,
75+ batch_size = batch_size ,
76+ learning_rate = learning_rate ,
77+ seed = seed ,
78+ report_to = report_to ,
10579 )
10680
107- # Initialize model
108- if self ._vocab is None :
109- msg = "Vocabulary not built"
110- raise ValueError (msg )
111-
81+ def get_embedder_config (self ) -> dict [str , Any ]:
82+ """Get the configuration of the embedder."""
83+ config = self .cnn_config .model_dump ()
84+ config .update ({
85+ "embed_dim" : self .embed_dim ,
86+ "kernel_sizes" : self .kernel_sizes ,
87+ "num_filters" : self .num_filters ,
88+ "dropout" : self .dropout ,
89+ })
90+ return config
91+
92+ def __initialize_model (self , vocab_size : int ) -> None :
93+ """Initialize the CNN model."""
11294 self ._model = TextCNN (
113- vocab_size = len ( self . _vocab ) ,
95+ vocab_size = vocab_size ,
11496 n_classes = self ._n_classes ,
11597 embed_dim = self .embed_dim ,
11698 kernel_sizes = self .kernel_sizes ,
11799 num_filters = self .num_filters ,
118100 dropout = self .dropout ,
119- padding_idx = self ._pad_idx
101+ padding_idx = self .cnn_config .padding_idx ,
102+ pretrained_embs = None ,
120103 )
104+ self ._model .to (self .device )
105+
106+ def fit (
107+ self ,
108+ utterances : list [str ],
109+ labels : ListOfLabels ,
110+ ) -> None :
111+ """Fit the model to the given data."""
112+ if hasattr (self , "_model" ):
113+ self .clear_cache ()
114+ self ._validate_task (labels )
115+ self ._create_vocab (utterances )
116+ self .__initialize_model (len (self ._vocab ))
117+ x = self ._texts_to_sequences (utterances )
118+ y = torch .tensor (labels , dtype = torch .float ) if self ._multilabel else torch .tensor (labels , dtype = torch .long )
119+ self ._train_model (x , y )
120+
121+ def _create_vocab (self , utterances : list [str ]) -> None :
122+ """Create vocabulary from utterances."""
123+ unique_words = set ()
124+ for text in utterances :
125+ for word in text .lower ().split ():
126+ unique_words .add (word )
127+
128+ self ._vocab = {"<PAD>" : 0 , "<UNK>" : 1 }
129+ for i , word in enumerate (unique_words ):
130+ self ._vocab [word ] = i + 2
131+
132+ def _texts_to_sequences (self , texts : list [str ]) -> torch .Tensor :
133+ """Convert texts to sequences using the vocabulary."""
134+ sequences = [[self ._vocab .get (word , self ._vocab ["<UNK>" ]) for word in text .lower ().split ()] for text in texts ]
135+
136+ max_len = min (max (len (seq ) for seq in sequences ), self .cnn_config .max_seq_length )
137+ padded_sequences = [
138+ seq [:max_len ] if len (seq ) > max_len else seq + [self ._vocab ["<PAD>" ]] * (max_len - len (seq ))
139+ for seq in sequences
140+ ]
141+
142+ return torch .tensor (padded_sequences , dtype = torch .long )
143+
144+ def _train_model (self , x : torch .Tensor , y : torch .Tensor ) -> None :
145+ """Train the model."""
146+ self ._model .train ()
147+ optimizer = Adam (self ._model .parameters (), lr = self .learning_rate )
148+
149+ criterion = nn .BCEWithLogitsLoss () if self ._multilabel else nn .CrossEntropyLoss ()
150+
151+ x = x .to (self ._device )
152+ y = y .to (self ._device )
153+
154+ dataset = torch .utils .data .TensorDataset (x , y )
155+ dataloader = torch .utils .data .DataLoader (dataset , batch_size = self .batch_size , shuffle = True )
156+
157+ torch .manual_seed (self .seed )
158+
159+ for _epoch in range (self .num_train_epochs ):
160+ total_loss = 0
161+ for batch_x , batch_y in dataloader :
162+ optimizer .zero_grad ()
163+ outputs = self ._model (batch_x )
164+ loss = criterion (outputs , batch_y )
165+ loss .backward ()
166+ optimizer .step ()
167+ total_loss += loss .item ()
121168
122- # Training
123- self ._train_model (x_tensor , y_tensor )
169+ self ._model .eval ()
124170
125171 def predict (self , utterances : list [str ]) -> npt .NDArray [Any ]:
126- if self ._model is None :
127- msg = "Model not trained. Call fit() first."
128- raise ValueError (msg )
172+ """Predict probabilities for utterances."""
173+ if not hasattr (self , "_model" ) or not hasattr (self , "_vocab" ):
174+ msg = "Model is not trained. Call fit() first."
175+ raise RuntimeError (msg )
129176
130- x = self ._text_to_indices (utterances )
131- x_tensor = torch . tensor ( x , dtype = torch . long )
177+ x = self ._texts_to_sequences (utterances )
178+ x = x . to ( self . device )
132179
133180 self ._model .eval ()
134- all_probs : list [ npt . NDArray [ Any ]] = []
181+ all_predictions = []
135182
136183 with torch .no_grad ():
137- for i in range (0 , len (x_tensor ), self .batch_size ):
138- batch_x = x_tensor [i : i + self .batch_size ]
184+ for i in range (0 , len (x ), self .batch_size ):
185+ batch_x = x [i : i + self .batch_size ]
139186 outputs = self ._model (batch_x )
187+
140188 if self ._multilabel :
141- probs = torch .sigmoid (outputs ).cpu ().numpy ()
189+ batch_predictions = torch .sigmoid (outputs ).cpu ().numpy ()
142190 else :
143- probs = torch .softmax (outputs , dim = 1 ).cpu ().numpy ()
144- all_probs .append (probs )
145-
146- return np .concatenate (all_probs , axis = 0 ) if all_probs else np .array ([])
147-
148- def _build_vocab (self , utterances : list [str ]) -> None :
149- """Build vocabulary from training utterances."""
150- word_counts : Counter [str ] = Counter ()
151- for utterance in utterances :
152- words = re .findall (r"\w+" , utterance .lower ())
153- word_counts .update (words )
154-
155- # Create vocabulary with special tokens
156- self ._vocab = {self ._pad_token : self ._pad_idx , self ._unk_token : self ._unk_idx }
157-
158- # Convert Counter to list of (word, count) tuples sorted by frequency
159- sorted_words = word_counts .most_common ()
160- for word , _ in sorted_words :
161- if word not in self ._vocab :
162- self ._vocab [word ] = len (self ._vocab )
163-
164- def _text_to_indices (self , utterances : list [str ]) -> list [list [int ]]:
165- """Convert utterances to padded sequences of word indices."""
166- if self ._vocab is None :
167- msg = "Vocabulary not built"
168- raise ValueError (msg )
169-
170- sequences : list [list [int ]] = []
171- for utterance in utterances :
172- words = re .findall (r"\w+" , utterance .lower ())
173- # Convert words to indices, using UNK for unknown words
174- seq = [self ._vocab .get (word , self ._unk_idx ) for word in words ]
175- # Truncate if too long
176- seq = seq [: self .max_seq_length ]
177- # Pad if too short
178- seq = seq + [self ._pad_idx ] * (self .max_seq_length - len (seq ))
179- sequences .append (seq )
180- return sequences
191+ batch_predictions = torch .softmax (outputs , dim = 1 ).cpu ().numpy ()
181192
182- def clear_cache (self ) -> None :
183- self ._model = None
184- torch .cuda .empty_cache ()
193+ all_predictions .append (batch_predictions )
185194
186- def _train_model (self , x : torch .Tensor , y : torch .Tensor ) -> None :
187- if self ._model is None :
188- msg = "Model not initialized"
189- raise ValueError (msg )
195+ return np .vstack (all_predictions ) if all_predictions else np .array ([])
190196
191- dataset = TensorDataset (x , y )
192- dataloader = DataLoader (dataset , batch_size = self .batch_size , shuffle = True )
197+ def clear_cache (self ) -> None :
198+ """Clear model cache."""
199+ if hasattr (self , "_model" ):
200+ del self ._model
193201
194- criterion = (
195- nn . CrossEntropyLoss () if not self . _multilabel else nn . BCEWithLogitsLoss ()
196- )
197- optimizer = torch . optim . Adam ( self ._model . parameters (), lr = self . learning_rate )
202+ @ property
203+ def device ( self ) -> str :
204+ """Get device used for model computations."""
205+ return self ._device
198206
199- self ._model .train ()
200- for _ in range (self .num_train_epochs ):
201- for batch_x , batch_y in dataloader :
202- optimizer .zero_grad ()
203- outputs = self ._model (batch_x )
204- loss = criterion (outputs , batch_y )
205- loss .backward ()
206- optimizer .step ()
207+ @device .setter
208+ def device (self , value : str ) -> None :
209+ """Set device for model computations."""
210+ self ._device = value
207211
208- self ._model .eval ()
212+ def get_implicit_initialization_params (self ) -> dict [str , Any ]:
213+ """Return default params used in ``__init__`` method."""
214+ return {"cnn_config" : self .cnn_config .model_dump ()}
0 commit comments