Skip to content

Commit c5f0f34

Browse files
committed
mypy4
1 parent 8c1e3ac commit c5f0f34

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

autointent/modules/scoring/_gcn/gcn_model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
from pathlib import Path
3-
from typing import cast
43

54
import torch
65
from pydantic import BaseModel
@@ -98,12 +97,12 @@ def set_correlation_matrix(self, train_labels: torch.Tensor) -> None:
9897
self.correlation_matrix.data.copy_(corr_matrix)
9998

10099
def forward(self, bert_features: torch.Tensor, label_embeddings: torch.Tensor) -> torch.Tensor:
101-
classifiers = label_embeddings
100+
classifiers: torch.Tensor = label_embeddings
102101
for i in range(len(self.gcn_layers)):
103102
classifiers = self.gcn_layers[i](self.correlation_matrix, classifiers)
104-
classifiers = self.activations[i](classifiers) # type: ignore[operator]
103+
classifiers = self.activations[i](classifiers)
105104

106-
return torch.matmul(bert_features, cast(torch.Tensor, classifiers).T)
105+
return torch.matmul(bert_features, classifiers.T)
107106

108107
def dump(self, path: Path) -> None:
109108
metadata = GCNModelDumpMetadata(

0 commit comments

Comments
 (0)