Commit 9815080

create class-based API for tensorboard visuals

1 parent e4d3e35 commit 9815080

19 files changed: +426 −353 lines changed
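The commit swaps the free plotting callbacks (show_batch_depth, show_batch_growing, show_batch_binary_segmentation, ...) for instances of Visual subclasses imported from ncalab's visualization module. That module is among the 19 changed files but is not part of this excerpt, so the following is only a minimal sketch of the interface the diffs below appear to rely on; the __call__ signature and return type are assumptions, not taken from the commit.

# Hypothetical sketch of the class-based Visual API. Only the names
# Visual / VisualGrowing appear in the diffs below; the __call__
# signature and figure-based return type are assumptions.
import matplotlib.pyplot as plt
import torch


class Visual:
    """Callable that renders a TensorBoard figure for a batch."""

    def __call__(self, model, images: torch.Tensor, prediction) -> plt.Figure:
        raise NotImplementedError


class VisualGrowing(Visual):
    """Example subclass: draws the grown image channels per sample."""

    def __call__(self, model, images: torch.Tensor, prediction) -> plt.Figure:
        fig, axes = plt.subplots(1, images.shape[0], squeeze=False)
        for i, ax in enumerate(axes[0]):
            ax.imshow(images[i].permute(1, 2, 0).cpu().numpy())
            ax.axis("off")
        return fig

Unlike a bare Callable, such objects can carry per-task state (colormaps, class names) and be constructed once, which is what the plot_function: Optional[Visual] type change below enables.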

ncalab/models/basicNCA.py

Lines changed: 20 additions & 12 deletions

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import Callable, Optional, Dict, Tuple
+from typing import Any, Dict, Optional, Tuple

 import numpy as np

@@ -10,6 +10,7 @@
 from ..autostepper import AutoStepper
 from ..prediction import Prediction
 from ..utils import pad_input
+from ..visualization import Visual


 class BasicNCAModel(nn.Module):
@@ -23,7 +24,7 @@ def __init__(
         num_image_channels: int,
         num_hidden_channels: int,
         num_output_channels: int,
-        plot_function: Optional[Callable] = None,
+        plot_function: Optional[Visual] = None,
         validation_metric: Optional[str] = None,
         fire_rate: float = 0.5,
         hidden_size: int = 128,
@@ -46,7 +47,7 @@ def __init__(
         :param num_output_channels [int]: Number of output channels.
         :param fire_rate [float]: Fire rate for stochastic weight update. Defaults to 0.5.
         :param hidden_size [int]: Number of neurons in hidden layer. Defaults to 128.
-        :param use_alive_mask [bool]: Whether to use alive masking during training. Defaults to False.
+        :param use_alive_mask [bool]: Whether to use alive masking (channel 3) during training. Defaults to False.
         :param immutable_image_channels [bool]: If image channels should be fixed during inference, which is the case for most segmentation or classification problems. Defaults to True.
         :param num_learned_filters [int]: Number of learned filters. If zero, use two sobel filters instead. Defaults to 2.
         :param filter_padding [str]: Padding type to use. Might affect reliance on spatial cues. Defaults to "circular".
@@ -80,16 +81,17 @@ def __init__(
         self.plot_function = plot_function
         self.validation_metric = validation_metric

+        # define input filters
         self._define_filters(num_learned_filters)

         # define model structure
-        self._define_network()
+        self.network = self._define_network().to(self.device)

     def _define_network(self):
         input_vector_size = self.num_channels * (self.num_filters + 1)
         if self.use_temporal_encoding:
             input_vector_size += 1
-        self.network = nn.Sequential(
+        network = nn.Sequential(
             nn.Conv2d(
                 in_channels=input_vector_size,
                 out_channels=self.hidden_size,
@@ -107,11 +109,11 @@ def _define_network(self):
                 padding=0,
                 kernel_size=1,
             ),
-        ).to(self.device)
-
+        )
         # initialize final layer with 0
         with torch.no_grad():
-            self.network[-1].weight.data.fill_(0)
+            network[-1].weight.data.fill_(0)
+        return network

     def _define_filters(self, num_learned_filters: int):
         """
@@ -134,9 +136,9 @@ def _define_filters(self, num_learned_filters: int):
                         padding_mode=self.filter_padding,
                         groups=self.num_channels,
                         bias=False,
-                    ).to(self.device)
+                    )
                 )
-            self.filters = nn.ModuleList(filters)
+            self.filters = nn.ModuleList(filters).to(self.device)
         else:
             sobel_x = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0
             sobel_y = sobel_x.T
@@ -185,9 +187,13 @@ def _perceive_with(x, weight):
         perception = [x]
         perception.extend([_perceive_with(x, w) for w in self.filters])
         if self.use_temporal_encoding:
+            normalization = 100
+            if self.autostepper is not None:
+                normalization = self.autostepper.max_steps
             perception.append(
                 torch.mul(
-                    torch.ones((x.shape[0], 1, x.shape[2], x.shape[3])), step / 100
+                    torch.ones((x.shape[0], 1, x.shape[2], x.shape[3])),
+                    step / normalization,
                 ).to(self.device)
             )
         dx = torch.cat(perception, 1)
@@ -243,7 +249,6 @@ def forward(
             x = x.permute(1, 0, 2, 3)  # C B W H --> B C W H
             return Prediction(self, steps, x)

-
         for step in range(self.autostepper.max_steps):
             if self.autostepper.check(step):
                 return Prediction(self, step, x)
@@ -347,3 +352,6 @@ def validate(
         prediction = self.predict(image.to(self.device), steps=steps)
         metrics = self.metrics(prediction.output_image, label.to(self.device))
         return metrics, prediction
+
+    def to_dict(self) -> Dict[str, Any]:
+        return dict()
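The temporal-encoding hunk above replaces the hard-coded step / 100 normalization with the autostepper's step budget whenever an AutoStepper is attached. A standalone sketch of the resulting encoding, with a hypothetical helper name:

import torch


def temporal_encoding_plane(x: torch.Tensor, step: int, max_steps: int = 100) -> torch.Tensor:
    # One constant extra channel whose value encodes progress through
    # the rollout; dividing by the step budget keeps it in [0, 1]
    # regardless of how many steps the AutoStepper allows.
    return torch.ones((x.shape[0], 1, x.shape[2], x.shape[3])) * (step / max_steps)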

ncalab/models/cascadeNCA.py

Lines changed: 1 addition & 3 deletions

@@ -92,9 +92,7 @@ def __init__(self, backbone: BasicNCAModel, scales: List[int], steps: List[int])
         models = [backbone for _ in scales]
         self.models = nn.ModuleList(models)

-    def forward(
-        self, x: torch.Tensor, *args, **kwargs
-    ) -> Prediction:
+    def forward(self, x: torch.Tensor, *args, **kwargs) -> Prediction:
         """
         :param x [torch.Tensor]: Input image tensor, BCWH.
         :param steps [int]: Unused, as steps are defined in constructor.

ncalab/models/classificationNCA.py

Lines changed: 44 additions & 13 deletions

@@ -1,12 +1,14 @@
-from typing import Dict
+from typing import Dict, Optional

 import torch  # type: ignore[import-untyped]
 import torch.nn.functional as F  # type: ignore[import-untyped]

 import torchmetrics
 import torchmetrics.classification

+from ..autostepper import AutoStepper
 from .basicNCA import BasicNCAModel
+from ..visualization import VisualBinaryImageClassification, VisualMultiImageClassification


 class ClassificationNCAModel(BasicNCAModel):
@@ -17,39 +19,66 @@ def __init__(
         num_hidden_channels: int,
         num_classes: int,
         fire_rate: float = 0.8,
+        hidden_size: int = 128,
         use_alive_mask: bool = False,
         pixel_wise_loss: bool = False,
+        num_learned_filters: int = 2,
         filter_padding: str = "reflect",
+        use_laplace: bool = False,
+        kernel_size: int = 3,
         pad_noise: bool = False,
+        autostepper: Optional[AutoStepper] = None,
+        use_temporal_encoding: bool = False,
         **kwargs,
     ):
         """
-        :param device [torch.device]: Compute device.
+        Constructor.
+
+        :param device [device]: Pytorch device descriptor.
         :param num_image_channels [int]: _description_
         :param num_hidden_channels [int]: _description_
         :param num_classes [int]: _description_
-        :param fire_rate [float]: _description_. Defaults to 0.8.
-        :param use_alive_mask [bool]: _description_. Defaults to False.
+        :param fire_rate [float]: Fire rate for stochastic weight update. Defaults to 0.8.
+        :param hidden_size [int]: Number of neurons in hidden layer. Defaults to 128.
+        :param use_alive_mask [bool]: Whether to use alive masking (channel 3) during training. Defaults to False.
         :param pixel_wise_loss [bool]: Whether a prediction per pixel is desired, like in self-classifying MNIST. Defaults to False.
-        :param filter_padding [str]: _description_. Defaults to "reflect".
-        :param pad_noise [bool]: _description_. Defaults to False.
+        :param num_learned_filters [int]: Number of learned filters. If zero, use two sobel filters instead. Defaults to 2.
+        :param filter_padding [str]: Padding type to use. Might affect reliance on spatial cues. Defaults to "circular".
+        :param pad_noise [bool]: Whether to pad input image tensor with noise in hidden / output channels
         """
         super(ClassificationNCAModel, self).__init__(
-            device,
-            num_image_channels,
-            num_hidden_channels,
-            num_classes,
+            device=device,
+            num_image_channels=num_image_channels,
+            num_hidden_channels=num_hidden_channels,
+            num_output_channels=num_classes,
             fire_rate=fire_rate,
+            hidden_size=hidden_size,
             use_alive_mask=use_alive_mask,
             immutable_image_channels=True,
             plot_function=None,
             validation_metric="accuracy_micro",
             filter_padding=filter_padding,
+            use_laplace=use_laplace,
+            kernel_size=kernel_size,
            pad_noise=pad_noise,
-            **kwargs,
+            autostepper=autostepper,
+            use_temporal_encoding=use_temporal_encoding,
         )
-        self.num_classes = num_classes
+        self._num_classes = num_classes
         self.pixel_wise_loss = pixel_wise_loss
+        if num_classes < 2:
+            self.plot_function = VisualBinaryImageClassification()
+        else:
+            self.plot_function = VisualMultiImageClassification()
+
+    @property
+    def num_classes(self) -> int:
+        return self._num_classes
+
+    @num_classes.setter
+    def num_classes(self, x: int):
+        self._num_classes = x
+        self.num_output_channels = x

     def classify(
         self, image: torch.Tensor, steps: int = 100, reduce: bool = False
@@ -181,7 +210,9 @@ def metrics(self, pred: torch.Tensor, label: torch.Tensor) -> Dict[str, float]:
         ]
         y_prob = class_channels
         y_prob = torch.mean(y_prob, dim=(2, 3))
-        y_true = label.squeeze(1)
+        y_true = label
+        if len(y_true.shape) == 2:
+            y_true = label.squeeze(1)

         accuracy_macro_metric.update(y_prob, y_true)
         accuracy_micro_metric.update(y_prob, y_true)
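The new num_classes property keeps num_output_channels in sync when the class count is changed after construction, which a plain attribute could not guarantee. A usage sketch with illustrative channel counts:

import torch

from ncalab.models.classificationNCA import ClassificationNCAModel

model = ClassificationNCAModel(
    device=torch.device("cpu"),
    num_image_channels=1,
    num_hidden_channels=16,
    num_classes=10,
)
model.num_classes = 4                  # setter updates the backing field ...
assert model.num_output_channels == 4  # ... and the output-channel count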

ncalab/models/depthNCA.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@

 from .basicNCA import BasicNCAModel, AutoStepper

-from ..visualization import show_batch_depth
+from ..visualization import VisualDepthEstimation

 import torch  # type: ignore[import-untyped]
 import torch.nn as nn  # type: ignore[import-untyped]
@@ -83,7 +83,7 @@ def __init__(
             device,
             num_image_channels,
             num_hidden_channels,
-            plot_function=show_batch_depth,
+            plot_function=VisualDepthEstimation(),
             validation_metric="ssim",
             num_output_channels=1,
             fire_rate=fire_rate,

ncalab/models/growingNCA.py

Lines changed: 4 additions & 10 deletions

@@ -7,7 +7,7 @@

 from .basicNCA import AutoStepper, BasicNCAModel
 from ..prediction import Prediction
-from ..visualization import show_batch_growing
+from ..visualization import VisualGrowing


 class GrowingNCAModel(BasicNCAModel):
@@ -40,7 +40,7 @@ def __init__(
             device,
             num_image_channels,
             num_hidden_channels,
-            plot_function=show_batch_growing,
+            plot_function=VisualGrowing(),
             num_output_channels=0,
             fire_rate=fire_rate,
             hidden_size=hidden_size,
@@ -99,11 +99,7 @@ def grow(
                 prediction = self.forward(x, steps=1)  # type: ignore[assignment]
                 step_outs.append(
                     np.clip(
-                        prediction.image_channels
-                        .squeeze(0)
-                        .detach()
-                        .cpu()
-                        .numpy(),
+                        prediction.image_channels.squeeze(0).detach().cpu().numpy(),
                         0,
                         1,
                     )
@@ -112,8 +108,6 @@ def grow(
             return step_outs
         else:
             prediction = self.forward(x, steps=steps)  # type: ignore[assignment]
-            out_np = (
-                prediction.image_channels.detach().cpu().numpy().squeeze(0)
-            )
+            out_np = prediction.image_channels.detach().cpu().numpy().squeeze(0)
             out_np = np.clip(out_np, 0, 1)
             return out_np

ncalab/models/segmentationNCA.py

Lines changed: 4 additions & 7 deletions

@@ -5,7 +5,7 @@

 from .basicNCA import AutoStepper, BasicNCAModel
 from ..losses import DiceBCELoss
-from ..visualization import show_batch_binary_segmentation
+from ..visualization import VisualBinaryImageSegmentation


 class SegmentationNCAModel(BasicNCAModel):
@@ -46,7 +46,7 @@ def __init__(
             num_image_channels,
             num_hidden_channels,
             num_output_channels=num_classes,
-            plot_function=show_batch_binary_segmentation,
+            plot_function=VisualBinaryImageSegmentation(),
             validation_metric="Dice",
             fire_rate=fire_rate,
             hidden_size=hidden_size,
@@ -58,7 +58,6 @@ def __init__(
             **kwargs,
         )

-
     def loss(self, image: torch.Tensor, label: torch.Tensor) -> Dict[str, torch.Tensor]:
         """
         Compute Dice + BCE loss.
@@ -85,16 +84,14 @@ def loss(self, image: torch.Tensor, label: torch.Tensor) -> Dict[str, torch.Tensor]:
         loss = loss_segmentation
         return {"total": loss}

-    def metrics(self, pred: torch.Tensor, label: torch.Tensor):
+    def metrics(self, pred: torch.Tensor, label: torch.Tensor) -> Dict[str, float]:
         """
         Return dict of standard evaluation metrics.

         :param pred [torch.Tensor]: Predicted image.
         :param label [torch.Tensor]: Ground truth label.
         """
-        outputs = pred[
-            :, self.num_image_channels + self.num_hidden_channels :, :, :
-        ]
+        outputs = pred[:, self.num_image_channels + self.num_hidden_channels :, :, :]
         tp, fp, fn, tn = smp.metrics.get_stats(
             outputs.cpu(),
             label[:, None, :, :].cpu().long(),
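The metrics() slice above extracts everything past the image and hidden channels, which reflects an NCA state laid out [image | hidden | output] along the channel dimension. A shape walk-through with illustrative sizes:

import torch

num_image_channels, num_hidden_channels, num_output_channels = 3, 16, 1
state = torch.rand(4, num_image_channels + num_hidden_channels + num_output_channels, 64, 64)

# Everything after the image and hidden channels is model output.
outputs = state[:, num_image_channels + num_hidden_channels :, :, :]
assert outputs.shape == (4, num_output_channels, 64, 64)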
