@@ -35,6 +35,7 @@ def __init__(
3535 kernel_sizes : list [int ] = [3 , 4 , 5 ], # noqa: B006
3636 num_filters : int = 100 ,
3737 dropout : float = 0.1 ,
38+ batch_size : int = 8 ,
3839 cnn_config : CNNConfig | str | dict [str , Any ] | None = None ,
3940 ) -> None :
4041 self .num_train_epochs = num_train_epochs
@@ -56,7 +57,7 @@ def __init__(
5657 self ._multilabel : bool = False
5758 self ._pad_idx = self .cnn_config .padding_idx
5859 self ._unk_idx = self .cnn_config .unknown_idx
59- self .batch_size = self . cnn_config . batch_size
60+ self .batch_size = batch_size
6061 self .max_seq_length = self .cnn_config .max_seq_length
6162
6263 @classmethod
@@ -86,17 +87,6 @@ def from_context(
8687 cnn_config = cnn_config
8788 )
8889
89- def get_embedder_config (self ) -> dict [str , Any ]:
90- """Get the configuration of the embedder."""
91- config = self .cnn_config .model_dump ()
92- config .update ({
93- "embed_dim" : self .embed_dim ,
94- "hidden_dim" : self .hidden_dim ,
95- "n_layers" : self .n_layers ,
96- "dropout" : self .dropout ,
97- })
98- return config
99-
10090 def get_implicit_initialization_params (self ) -> dict [str , Any ]:
10191 return {"cnn_config" : self .cnn_config .model_dump ()}
10292
0 commit comments