1212
1313from autointent import Context
1414from autointent ._callbacks import REPORTERS_NAMES
15+ from autointent .configs import CNNConfig
1516from autointent .custom_types import ListOfLabels
1617from autointent .modules .base import BaseScorer
1718from autointent .modules .scoring ._cnn .textcnn import TextCNN
@@ -26,37 +27,37 @@ class CNNScorer(BaseScorer):
2627
2728 def __init__ (
2829 self ,
29- max_seq_length : int = 50 ,
3030 num_train_epochs : int = 3 ,
31- batch_size : int = 8 ,
3231 learning_rate : float = 5e-5 ,
3332 seed : int = 0 ,
3433 report_to : REPORTERS_NAMES | None = None , # type: ignore[valid-type]
3534 embed_dim : int = 128 ,
3635 kernel_sizes : list [int ] = [3 , 4 , 5 ], # noqa: B006
3736 num_filters : int = 100 ,
38- dropout : float = 0.1
37+ dropout : float = 0.1 ,
38+ cnn_config : CNNConfig | str | dict [str , Any ] | None = None ,
3939 ) -> None :
40- self .max_seq_length = max_seq_length
4140 self .num_train_epochs = num_train_epochs
42- self .batch_size = batch_size
4341 self .learning_rate = learning_rate
4442 self .seed = seed
4543 self .report_to = report_to
4644 self .embed_dim = embed_dim
4745 self .kernel_sizes = kernel_sizes
4846 self .num_filters = num_filters
4947 self .dropout = dropout
48+ self .cnn_config = CNNConfig .from_search_config (cnn_config )
5049
5150 # Will be initialized during fit()
5251 self ._model : TextCNN | None = None
5352 self ._vocab : dict [str , int ] | None = None
5453 self ._unk_token = "<UNK>" # noqa: S105
5554 self ._pad_token = "<PAD>" # noqa: S105
56- self ._unk_idx = 1
57- self ._pad_idx = 0
5855 self ._n_classes : int = 0
5956 self ._multilabel : bool = False
57+ self ._pad_idx = self .cnn_config .padding_idx
58+ self ._unk_idx = self .cnn_config .unknown_idx
59+ self .batch_size = self .cnn_config .batch_size
60+ self .max_seq_length = self .cnn_config .max_seq_length
6061
6162 @classmethod
6263 def from_context (
@@ -69,7 +70,8 @@ def from_context(
6970 embed_dim : int = 128 ,
7071 kernel_sizes : list [int ] = [3 , 4 , 5 ], # noqa: B006
7172 num_filters : int = 100 ,
72- dropout : float = 0.1
73+ dropout : float = 0.1 ,
74+ cnn_config : CNNConfig | str | dict [str , Any ] | None = None
7375 ) -> "CNNScorer" :
7476 return cls (
7577 num_train_epochs = num_train_epochs ,
@@ -80,8 +82,23 @@ def from_context(
8082 embed_dim = embed_dim ,
8183 kernel_sizes = kernel_sizes ,
8284 num_filters = num_filters ,
83- dropout = dropout
85+ dropout = dropout ,
86+ cnn_config = cnn_config
8487 )
88+
def get_embedder_config(self) -> dict[str, Any]:
    """Return the CNN configuration merged with this scorer's hyperparameters.

    The result starts from ``self.cnn_config.model_dump()`` and is then
    overlaid with the architecture settings stored on the scorer itself, so
    a scorer-level value wins if the dumped config has a key of the same name.

    Returns:
        A new dictionary; mutating it does not affect the scorer.
    """
    config = self.cnn_config.model_dump()
    # Only report attributes that ``__init__`` actually assigns. The previous
    # version read ``self.hidden_dim`` / ``self.n_layers``, which this scorer
    # never sets, so the call failed with AttributeError.
    config.update(
        {
            "embed_dim": self.embed_dim,
            # Copy so callers cannot mutate the scorer's own list via the result.
            "kernel_sizes": list(self.kernel_sizes),
            "num_filters": self.num_filters,
            "dropout": self.dropout,
        }
    )
    return config
99+
def get_implicit_initialization_params(self) -> dict[str, Any]:
    """Return the dumped CNN config under the ``"cnn_config"`` key.

    The value is whatever ``self.cnn_config.model_dump()`` produces.
    """
    dumped_config = self.cnn_config.model_dump()
    return {"cnn_config": dumped_config}
85102
86103 def fit (self , utterances : list [str ], labels : ListOfLabels ) -> None :
87104 self ._validate_task (labels )
0 commit comments