File tree Expand file tree Collapse file tree 1 file changed +12
-4
lines changed
autointent/modules/scoring/_cnn Expand file tree Collapse file tree 1 file changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -31,15 +31,23 @@ def __init__(
3131
3232 if pretrained_embs is not None :
3333 _ , embed_dim = pretrained_embs .shape
34- self .embedding = nn .Embedding .from_pretrained (pretrained_embs , freeze = True )
35- self .pretrained_embs = pretrained_embs
34+ self .embedding = nn .Embedding .from_pretrained (pretrained_embs , freeze = True ) # type: ignore[no-untyped-call]
3635 else :
3736 self .embedding = nn .Embedding (
3837 num_embeddings = vocab_size ,
3938 embedding_dim = embed_dim ,
40- padding_idx = padding_idx ,
39+ padding_idx = padding_idx
4140 )
42- self .pretrained_embs = None
41+
42+ self .convs = nn .ModuleList ([
43+ nn .Conv1d (
44+ in_channels = embed_dim ,
45+ out_channels = num_filters ,
46+ kernel_size = k
47+ ) for k in kernel_sizes
48+ ])
49+ self .dropout = nn .Dropout (dropout )
50+ self .fc = nn .Linear (num_filters * len (kernel_sizes ), n_classes )
4351
4452 def forward (self , x : torch .Tensor ) -> torch .Tensor :
4553 """Forward pass of the model."""
You can’t perform that action at this time.
0 commit comments