Skip to content

Commit 8c1e3ac

Browse files
committed
mypy3
1 parent 1429ff7 commit 8c1e3ac

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

autointent/modules/scoring/_gcn/gcn_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from pathlib import Path
3+
from typing import cast
34

45
import torch
56
from pydantic import BaseModel
@@ -100,9 +101,9 @@ def forward(self, bert_features: torch.Tensor, label_embeddings: torch.Tensor) -
100101
classifiers = label_embeddings
101102
for i in range(len(self.gcn_layers)):
102103
classifiers = self.gcn_layers[i](self.correlation_matrix, classifiers)
103-
classifiers = self.activations[i](classifiers)
104+
classifiers = self.activations[i](classifiers) # type: ignore[operator]
104105

105-
return torch.matmul(bert_features, classifiers.T)
106+
return torch.matmul(bert_features, cast(torch.Tensor, classifiers).T)
106107

107108
def dump(self, path: Path) -> None:
108109
metadata = GCNModelDumpMetadata(

0 commit comments

Comments
 (0)