File tree Expand file tree Collapse file tree 1 file changed +3
-4
lines changed
autointent/modules/scoring/_gcn Expand file tree Collapse file tree 1 file changed +3
-4
lines changed Original file line number Diff line number Diff line change 11import json
22from pathlib import Path
3- from typing import cast
43
54import torch
65from pydantic import BaseModel
@@ -98,12 +97,12 @@ def set_correlation_matrix(self, train_labels: torch.Tensor) -> None:
9897 self .correlation_matrix .data .copy_ (corr_matrix )
9998
10099 def forward (self , bert_features : torch .Tensor , label_embeddings : torch .Tensor ) -> torch .Tensor :
101- classifiers = label_embeddings
100+ classifiers : torch . Tensor = label_embeddings
102101 for i in range (len (self .gcn_layers )):
103102 classifiers = self .gcn_layers [i ](self .correlation_matrix , classifiers )
104- classifiers = self .activations [i ](classifiers ) # type: ignore[operator]
103+ classifiers = self .activations [i ](classifiers )
105104
106- return torch .matmul (bert_features , cast ( torch . Tensor , classifiers ) .T )
105+ return torch .matmul (bert_features , classifiers .T )
107106
108107 def dump (self , path : Path ) -> None :
109108 metadata = GCNModelDumpMetadata (
You can’t perform that action at this time.
0 commit comments