Skip to content

Commit c6210ef

Browse files
committed
added CNNconfig
1 parent 245d686 commit c6210ef

File tree

3 files changed

+43
-10
lines changed

3 files changed

+43
-10
lines changed

autointent/configs/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
"EmbedderConfig",
1111
"HFModelConfig",
1212
"InferenceNodeConfig",
13-
"InferenceNodeConfig",
1413
"LoggingConfig",
1514
"TaskTypeEnum",
1615
"TokenizerConfig",

autointent/configs/_transformers.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,20 @@ class CrossEncoderConfig(HFModelConfig):
122122
tokenizer_config: TokenizerConfig = Field(
123123
default_factory=lambda: TokenizerConfig(max_length=512)
124124
) # this is because sentence-transformers doesn't allow you to customize tokenizer settings properly
125+
126+
class СNNConfig(BaseModel):
127+
model_config = ConfigDict(extra="forbid")
128+
max_seq_length: int = Field(128, description="Maximum sequence length.")
129+
padding_idx: int = Field(0, description="Index used for padding.")
130+
unknown_idx: int = Field(1, description="Index used for unknown.")
131+
batch_size: PositiveInt = Field(32, description="Batch size for model inference.")
132+
133+
@classmethod
134+
def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) -> "СNNConfig":
135+
if values is None:
136+
return cls()
137+
if isinstance(values, BaseModel):
138+
return values # type: ignore[return-value]
139+
if isinstance(values, str):
140+
return cls()
141+
return cls(**values)

autointent/modules/scoring/_cnn/cnn.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from autointent import Context
1414
from autointent._callbacks import REPORTERS_NAMES
15+
from autointent.configs import CNNConfig
1516
from autointent.custom_types import ListOfLabels
1617
from autointent.modules.base import BaseScorer
1718
from autointent.modules.scoring._cnn.textcnn import TextCNN
@@ -26,37 +27,37 @@ class CNNScorer(BaseScorer):
2627

2728
def __init__(
2829
self,
29-
max_seq_length: int = 50,
3030
num_train_epochs: int = 3,
31-
batch_size: int = 8,
3231
learning_rate: float = 5e-5,
3332
seed: int = 0,
3433
report_to: REPORTERS_NAMES | None = None, # type: ignore[valid-type]
3534
embed_dim: int = 128,
3635
kernel_sizes: list[int] = [3, 4, 5], # noqa: B006
3736
num_filters: int = 100,
38-
dropout: float = 0.1
37+
dropout: float = 0.1,
38+
cnn_config: CNNConfig | str | dict[str, Any] | None = None,
3939
) -> None:
40-
self.max_seq_length = max_seq_length
4140
self.num_train_epochs = num_train_epochs
42-
self.batch_size = batch_size
4341
self.learning_rate = learning_rate
4442
self.seed = seed
4543
self.report_to = report_to
4644
self.embed_dim = embed_dim
4745
self.kernel_sizes = kernel_sizes
4846
self.num_filters = num_filters
4947
self.dropout = dropout
48+
self.cnn_config = CNNConfig.from_search_config(cnn_config)
5049

5150
# Will be initialized during fit()
5251
self._model: TextCNN | None = None
5352
self._vocab: dict[str, int] | None = None
5453
self._unk_token = "<UNK>" # noqa: S105
5554
self._pad_token = "<PAD>" # noqa: S105
56-
self._unk_idx = 1
57-
self._pad_idx = 0
5855
self._n_classes: int = 0
5956
self._multilabel: bool = False
57+
self._pad_idx = self.cnn_config.padding_idx
58+
self._unk_idx = self.cnn_config.unknown_idx
59+
self.batch_size = self.cnn_config.batch_size
60+
self.max_seq_length = self.cnn_config.max_seq_length
6061

6162
@classmethod
6263
def from_context(
@@ -69,7 +70,8 @@ def from_context(
6970
embed_dim: int = 128,
7071
kernel_sizes: list[int] = [3, 4, 5], # noqa: B006
7172
num_filters: int = 100,
72-
dropout: float = 0.1
73+
dropout: float = 0.1,
74+
cnn_config: CNNConfig | str | dict[str, Any] | None = None
7375
) -> "CNNScorer":
7476
return cls(
7577
num_train_epochs=num_train_epochs,
@@ -80,8 +82,23 @@ def from_context(
8082
embed_dim=embed_dim,
8183
kernel_sizes=kernel_sizes,
8284
num_filters=num_filters,
83-
dropout=dropout
85+
dropout=dropout,
86+
cnn_config=cnn_config
8487
)
88+
89+
def get_embedder_config(self) -> dict[str, Any]:
90+
"""Get the configuration of the embedder."""
91+
config = self.cnn_config.model_dump()
92+
config.update({
93+
"embed_dim": self.embed_dim,
94+
"hidden_dim": self.hidden_dim,
95+
"n_layers": self.n_layers,
96+
"dropout": self.dropout,
97+
})
98+
return config
99+
100+
def get_implicit_initialization_params(self) -> dict[str, Any]:
101+
return {"cnn_config": self.cnn_config.model_dump()}
85102

86103
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
87104
self._validate_task(labels)

0 commit comments

Comments
 (0)