Skip to content

Commit 16a7610

Browse files
committed
mypy5
1 parent c5f0f34 commit 16a7610

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

autointent/modules/scoring/_gcn/gcn_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from pathlib import Path
3+
from typing import cast
34

45
import torch
56
from pydantic import BaseModel
@@ -88,7 +89,7 @@ def create_correlation_matrix(
8889
reweighted_adj = adj_matrix_no_self_loop * weights_p.unsqueeze(1)
8990
reweighted_adj.fill_diagonal_(1 - p)
9091

91-
return reweighted_adj
92+
return cast(torch.Tensor, reweighted_adj)
9293

9394
def set_correlation_matrix(self, train_labels: torch.Tensor) -> None:
9495
corr_matrix = self.create_correlation_matrix(
@@ -133,4 +134,4 @@ def load(cls, path: Path, device: str | None = None) -> Self:
133134

134135
instance = instance.to(device)
135136
instance.eval()
136-
return instance
137+
return instance

0 commit comments

Comments
 (0)