Skip to content

Commit 135f86f

Browse files
committed
Update textcnn.py
1 parent 665b750 commit 135f86f

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,23 @@ def __init__(
3131

3232
if pretrained_embs is not None:
3333
_, embed_dim = pretrained_embs.shape
34-
self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)
35-
self.pretrained_embs = pretrained_embs
34+
self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True) # type: ignore[no-untyped-call]
3635
else:
3736
self.embedding = nn.Embedding(
3837
num_embeddings=vocab_size,
3938
embedding_dim=embed_dim,
40-
padding_idx=padding_idx,
39+
padding_idx=padding_idx
4140
)
42-
self.pretrained_embs = None
41+
42+
self.convs = nn.ModuleList([
43+
nn.Conv1d(
44+
in_channels=embed_dim,
45+
out_channels=num_filters,
46+
kernel_size=k
47+
) for k in kernel_sizes
48+
])
49+
self.dropout = nn.Dropout(dropout)
50+
self.fc = nn.Linear(num_filters * len(kernel_sizes), n_classes)
4351

4452
def forward(self, x: torch.Tensor) -> torch.Tensor:
4553
"""Forward pass of the model."""

0 commit comments

Comments
 (0)