@@ -13,47 +13,33 @@ def __init__(
1313 vocab_size : int = 0 ,
1414 n_classes : int = 0 ,
1515 embed_dim : int = 128 ,
16- kernel_sizes : list [int ] = [3 , 4 , 5 ], # noqa: B006
16+ kernel_sizes : list [int ] = [3 , 4 , 5 ], # noqa: B006
1717 num_filters : int = 100 ,
1818 dropout : float = 0.1 ,
1919 padding_idx : int = 0 ,
2020 pretrained_embs : torch .Tensor | None = None ,
2121 ) -> None :
22- """Initialize TextCNN model."""
2322 super ().__init__ ()
2423
25- # Register model hyperparameters as buffers
26- self .register_buffer ("vocab_size" , torch .tensor (vocab_size ))
27- self .register_buffer ("n_classes" , torch .tensor (n_classes ))
28- self .register_buffer ("embed_dim" , torch .tensor (embed_dim ))
29- self .register_buffer ("kernel_sizes" , torch .tensor (kernel_sizes ))
30- self .register_buffer ("num_filters" , torch .tensor (num_filters ))
31- self .register_buffer ("dropout_rate" , torch .tensor (dropout ))
32- self .register_buffer ("padding_idx" , torch .tensor (padding_idx ))
24+ self .vocab_size = vocab_size
25+ self .n_classes = n_classes
26+ self .embed_dim = embed_dim
27+ self .kernel_sizes = kernel_sizes
28+ self .num_filters = num_filters
29+ self .dropout_rate = dropout
30+ self .padding_idx = padding_idx
3331
3432 if pretrained_embs is not None :
3533 _ , embed_dim = pretrained_embs .shape
36- self .embedding = nn .Embedding .from_pretrained (pretrained_embs , freeze = True ) # type: ignore[no-untyped-call]
37- # Register pretrained embeddings as buffer if they exist
38- self .register_buffer ("pretrained_embs" , pretrained_embs )
34+ self .embedding = nn .Embedding .from_pretrained (pretrained_embs , freeze = True )
35+ self .pretrained_embs = pretrained_embs
3936 else :
4037 self .embedding = nn .Embedding (
4138 num_embeddings = vocab_size ,
4239 embedding_dim = embed_dim ,
43- padding_idx = padding_idx
40+ padding_idx = padding_idx ,
4441 )
45- # Register None for pretrained_embs buffer
46- self .register_buffer ("pretrained_embs" , None )
42+ self .pretrained_embs = None
4743
4844 self .convs = nn .ModuleList ([
4945 nn .Conv1d (
5046 in_channels = embed_dim ,
5147 out_channels = num_filters ,
5248 kernel_size = k
5349 ) for k in kernel_sizes
5450 ])
5551 self .dropout = nn .Dropout (dropout )
5652 self .fc = nn .Linear (num_filters * len (kernel_sizes ), n_classes )
5743
5844 def forward (self , x : torch .Tensor ) -> torch .Tensor :
5945 """Forward pass of the model."""
@@ -73,7 +59,7 @@ def load(self, model_path: str) -> None:
7359 state_dict = torch .load (model_path )
7460 self .load_state_dict (state_dict )
7561
76- def get_config (self ) -> dict :
62+ def get_config (self ) -> - > dict [ str , int | list [ int ] | torch . Tensor | None ] :
7763 return {
78- "vocab_size" : self .vocab_size .item (),
79- "n_classes" : self .n_classes .item (),
64+ "vocab_size" : self .vocab_size ,
65+ "n_classes" : self .n_classes ,
0 commit comments