Skip to content

Commit 3fa2c65

Browse files
committed
add more type hints to function return values
1 parent c368d3b commit 3fa2c65

File tree

6 files changed

+15
-16
lines changed

6 files changed

+15
-16
lines changed

ncalab/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# TODO implement pattern pool
2-
# TODO implement damage
31
import numpy as np
42
import torch # type: ignore[import-untyped]
53
from torch.utils.data import Dataset # type: ignore[import-untyped]
@@ -11,7 +9,9 @@ def __init__(
119
image: np.ndarray,
1210
num_channels: int,
1311
batch_size: int = 8,
12+
# TODO implement
1413
use_pattern_pool: bool = False,
14+
# TODO implement
1515
damage: bool = False,
1616
):
1717
"""Dedicated dataset for "growing" tasks, like growing emoji.

ncalab/losses.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,10 @@ def __init__(self):
1111
super(DiceScore, self).__init__()
1212

1313
def forward(
14-
self, x: torch.Tensor, y: torch.Tensor, smooth: float = 1
14+
self, x: torch.Tensor, y: torch.Tensor, smooth: float = 1.0
1515
) -> torch.Tensor:
1616
"""
17-
18-
Args:
19-
input (_type_): _description_
20-
target (_type_): _description_
21-
smooth (int, optional): _description_. Defaults to 1.
22-
23-
Returns:
24-
_type_: _description_
17+
:returns: Dice score as Tensor
2518
"""
2619
x = torch.sigmoid(x)
2720
x = torch.flatten(x)

ncalab/models/classificationNCA.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def loss(self, image: torch.Tensor, label: torch.Tensor) -> Dict[str, torch.Tens
165165
"classification": loss_classification,
166166
}
167167

168-
def metrics(self, pred: torch.Tensor, label: torch.Tensor):
168+
def metrics(self, pred: torch.Tensor, label: torch.Tensor) -> Dict[str, float]:
169169
"""
170170
Return dict of standard evaluation metrics.
171171

ncalab/training/earlystopping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, patience: int, min_delta: float = 1e-6):
1616
self.best_accuracy = 0.0
1717
self.counter = 0
1818

19-
def done(self):
19+
def done(self) -> bool:
2020
"""
2121
Checks whether the training can be stopped.
2222

ncalab/training/kfold.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(self):
7272
self.dataloader_test = None
7373

7474
@staticmethod
75-
def read(path: PosixPath):
75+
def read(path: PosixPath) -> "SplitDefinition":
7676
"""
7777
Reads json files with split definitions, similar to those created by nnUNet.
7878
@@ -105,7 +105,7 @@ def read(path: PosixPath):
105105
# TODO validate structure
106106
return sd
107107

108-
def __len__(self):
108+
def __len__(self) -> int:
109109
return len(self.folds)
110110

111111
def __getitem__(self, idx) -> TrainValRecord:

ncalab/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
import torch.nn.functional as F # type: ignore[import-untyped]
88

99

10+
from typing import TYPE_CHECKING
11+
12+
if TYPE_CHECKING:
13+
from .models import BasicNCAModel
14+
15+
1016
def get_compute_device(device: str = "cuda:0") -> torch.device:
1117
"""
1218
Obtain a pytorch compute device handle based on input string.
@@ -23,7 +29,7 @@ def get_compute_device(device: str = "cuda:0") -> torch.device:
2329
return d
2430

2531

26-
def pad_input(x, nca, noise: bool = True):
32+
def pad_input(x: torch.Tensor, nca: BasicNCAModel, noise: bool = True) -> torch.Tensor:
2733
"""
2834
Pads input tensor along channel dimension to match the expected number of
2935
channels required by the NCA model. Pads with either Gaussian noise or zeros,

0 commit comments

Comments
 (0)