Skip to content

Commit d5287b6

Browse files
committed
fix type errors in cascaded model
1 parent 07361d8 commit d5287b6

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

ncalab/models/basicNCA.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
use_alive_mask: bool = False,
3232
immutable_image_channels: bool = True,
3333
num_learned_filters: int = 2,
34-
filter_padding: str = "circular",
34+
filter_padding: str = "reflect",
3535
use_laplace: bool = False,
3636
kernel_size: int = 3,
3737
pad_noise: bool = False,

ncalab/models/cascadeNCA.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import Dict, Optional, Tuple, List
22

33
import numpy as np
44
import torch # type: ignore[import-untyped]
@@ -99,12 +99,13 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Prediction:
9999
assert len(self.scales) > 0
100100
assert len(self.models) > 0
101101
assert len(self.steps) > 0
102+
prediction = None
102103
x_scaled = downscale(x, self.scales[0])
103104
for i, (model, scale, scale_steps) in enumerate(
104105
zip(self.models, self.scales, self.steps)
105106
):
106107
steps = scale_steps + np.random.randint(
107-
-int(scale_steps * 0.2), int(scale_steps * 0.2)
108+
-int(scale_steps * 0.2), int(scale_steps * 0.2) + 1
108109
)
109110
if steps <= 0:
110111
steps = 1
@@ -117,12 +118,13 @@ def forward(self, x: torch.Tensor, *args, **kwargs) -> Prediction:
117118
self.scales[i + 1],
118119
)
119120
# TODO prediction has incorrect number of steps
120-
return prediction
121+
return unwrap(prediction)
121122

122123
def record_steps(self, x: torch.Tensor):
123124
# TODO let "Prediction" class record steps
124125
step_outputs = []
125126
x_scaled = downscale(x, self.scales[0])
127+
prediction = None
126128
for i, (model, scale, scale_steps) in enumerate(
127129
zip(self.models, self.scales, self.steps)
128130
):
@@ -132,15 +134,15 @@ def record_steps(self, x: torch.Tensor):
132134
step_outputs.append(upscale(prediction.output_image, scale))
133135
x_in = prediction.output_image
134136
if i < len(self.scales) - 1:
135-
x_scaled = upscale(prediction.output_image, scale / self.scales[i + 1])
137+
x_scaled = upscale(unwrap(prediction).output_image, scale / self.scales[i + 1])
136138
# replace input with downscaled variant of original image
137139
x_scaled[:, : model.num_image_channels, :, :] = downscale(
138140
x[:, : model.num_image_channels, :, :],
139141
self.scales[i + 1],
140142
)
141143
return step_outputs
142144

143-
def validate(self, image: torch.Tensor, label: torch.Tensor, steps: int = 1):
145+
def validate(self, image: torch.Tensor, label: torch.Tensor, steps: int = 1) -> Optional[Tuple[Dict[str, float], Prediction]]:
144146
"""
145147
Validation method.
146148
@@ -175,4 +177,4 @@ def validate(self, image: torch.Tensor, label: torch.Tensor, steps: int = 1):
175177
image[:, : model.num_image_channels, :, :],
176178
self.scales[i + 1],
177179
)
178-
return metrics, prediction
180+
return unwrap(metrics), unwrap(prediction)

ncalab/models/segmentationNCA.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
num_learned_filters: int = 2,
2828
pad_noise: bool = True,
2929
autostepper: Optional[AutoStepper] = None,
30+
filter_padding: str = "reflect",
3031
**kwargs,
3132
):
3233
"""
@@ -38,6 +39,7 @@ def __init__(
3839
:param hidden_size [int]: Number of neurons in hidden layer. Defaults to 128.
3940
:param learned_filters [int]: Number of learned filters. If 0, use sobel. Defaults to 2.
4041
:param pad_noise [bool]: Whether to pad input images with noise. Defaults to True.
42+
:param filter_padding [str]: Padding type to use. Might affect reliance on spatial cues. Defaults to "circular".
4143
"""
4244
self.num_classes = num_classes
4345
super(SegmentationNCAModel, self).__init__(
@@ -54,6 +56,7 @@ def __init__(
5456
num_learned_filters=num_learned_filters,
5557
pad_noise=pad_noise,
5658
autostepper=autostepper,
59+
filter_padding=filter_padding,
5760
**kwargs,
5861
)
5962

@@ -92,15 +95,15 @@ def metrics(self, pred: torch.Tensor, label: torch.Tensor) -> Dict[str, float]:
9295
"""
9396
outputs = pred[:, self.num_image_channels + self.num_hidden_channels :, :, :]
9497
tp, fp, fn, tn = smp.metrics.get_stats(
95-
outputs.cpu(),
98+
outputs.cpu().float(),
9699
label[:, None, :, :].cpu().long(),
97100
mode="binary",
98101
threshold=0.1,
99102
)
100-
tp = tp.squeeze()
101-
fp = fp.squeeze()
102-
fn = fn.squeeze()
103-
tn = tn.squeeze()
103+
tp = tp.squeeze().long()
104+
fp = fp.squeeze().long()
105+
fn = fn.squeeze().long()
106+
tn = tn.squeeze().long()
104107
iou_score = smp.metrics.iou_score(
105108
tp,
106109
fp,

0 commit comments

Comments
 (0)