@@ -52,21 +52,13 @@ def __init__(
 
         if pretrained_embs is not None:
            _, embed_dim = pretrained_embs.shape
-            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True) # type: ignore[no-untyped-call]
+            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)  # type: ignore[no-untyped-call]
         else:
-            self.embedding = nn.Embedding(
-                num_embeddings=vocab_size,
-                embedding_dim=embed_dim,
-                padding_idx=padding_idx
-            )
-
-        self.convs = nn.ModuleList([
-            nn.Conv1d(
-                in_channels=embed_dim,
-                out_channels=num_filters,
-                kernel_size=k
-            ) for k in kernel_sizes
-        ])
+            self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=padding_idx)
+
+        self.convs = nn.ModuleList(
+            [nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k) for k in kernel_sizes]
+        )
         self.dropout = nn.Dropout(dropout)
         self.fc = nn.Linear(num_filters * len(kernel_sizes), n_classes)
 
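For reference, a minimal, self-contained sketch of the two embedding-initialization branches in the hunk above; the random pretrained_embs matrix is only a stand-in for real pretrained vectors, and the hyperparameter values are illustrative.

import torch
import torch.nn as nn

vocab_size, embed_dim, padding_idx = 10_000, 128, 0

# Branch 1: pretrained vectors, frozen so the optimizer never updates them.
pretrained_embs = torch.randn(vocab_size, embed_dim)  # stand-in for loaded vectors
frozen = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)
print(frozen.weight.requires_grad)  # False

# Branch 2: embeddings learned from scratch; the row at padding_idx is
# zero-initialized and excluded from gradient updates.
trainable = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=padding_idx)
print(trainable.weight.requires_grad)  # True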
@@ -77,7 +69,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         conved: list[torch.Tensor] = [F.relu(conv(embedded)).max(dim=2)[0] for conv in self.convs]
         concatenated: torch.Tensor = torch.cat(conved, dim=1)
         dropped: torch.Tensor = self.dropout(concatenated)
-        return self.fc(dropped) # type: ignore[no-any-return]
+        return self.fc(dropped)  # type: ignore[no-any-return]
 
     def dump(self, path: Path) -> None:
         metadata = {
@@ -87,7 +79,7 @@ def dump(self, path: Path) -> None:
             "kernel_sizes": self.kernel_sizes,
             "num_filters": self.num_filters,
             "dropout": self.dropout_rate,
-            "padding_idx": self.padding_idx
+            "padding_idx": self.padding_idx,
         }
         with (path / self._metadata_dict_name).open("w") as file:
             json.dump(metadata, file, indent=4)
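And a minimal, self-contained sketch of the forward path these hunks touch: embed, convolve with each kernel size, max-pool over time, concatenate, dropout, then classify. Shapes and hyperparameters are illustrative, and the transpose before the convolutions is an assumption, since that line falls outside the hunks shown here.

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, embed_dim, num_filters, kernel_sizes, n_classes = 10_000, 128, 100, [3, 4, 5], 4

embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=0)
convs = nn.ModuleList(
    [nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k) for k in kernel_sizes]
)
dropout = nn.Dropout(0.5)
fc = nn.Linear(num_filters * len(kernel_sizes), n_classes)

tokens = torch.randint(0, vocab_size, (32, 50))                     # (batch, seq_len) token ids
embedded = embedding(tokens).permute(0, 2, 1)                       # (batch, embed_dim, seq_len), transpose assumed
conved = [F.relu(conv(embedded)).max(dim=2)[0] for conv in convs]   # each: (batch, num_filters)
concatenated = torch.cat(conved, dim=1)                             # (batch, num_filters * len(kernel_sizes))
logits = fc(dropout(concatenated))                                  # (batch, n_classes)
print(logits.shape)                                                 # torch.Size([32, 4])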