
Commit 307c5db

fix type hints, externalize autostep logic
1 parent 1c0a6b9 commit 307c5db


10 files changed: +168 / -140 lines


ncalab/losses.py

Lines changed: 8 additions & 12 deletions
@@ -4,18 +4,16 @@
 
 
 class DiceScore(nn.Module):
-    """_summary_
-
-    Args:
-        nn (_type_): _description_
-    """
+    """ """
 
     def __init__(self):
-        """_summary_"""
+        """"""
         super(DiceScore, self).__init__()
 
-    def forward(self, x: torch.Tensor, y: torch.Tensor, smooth: float = 1):
-        """_summary_
+    def forward(
+        self, x: torch.Tensor, y: torch.Tensor, smooth: float = 1
+    ) -> torch.Tensor:
+        """
 
         Args:
             input (_type_): _description_
@@ -35,10 +33,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, smooth: float = 1):
 
 
 class DiceBCELoss(nn.Module):
-    """Combination of Dice and BCE Loss.
-
-    Args:
-        nn (_type_): _description_
+    """
+    Combination of Dice and BCE Loss.
     """
 
     def __init__(self):
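Note: forward() is now explicitly annotated to return a torch.Tensor. A minimal usage sketch of DiceScore; the tensor shapes and the binary-mask target are illustrative assumptions, only the class name and signature come from the diff above:

    import torch

    from ncalab.losses import DiceScore

    # Hypothetical batch of 4 single-channel probability maps and binary masks.
    pred = torch.rand(4, 1, 32, 32)
    target = (torch.rand(4, 1, 32, 32) > 0.5).float()

    dice = DiceScore()
    score = dice(pred, target)  # smooth defaults to 1; returns a torch.Tensor
    print(score.item())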

ncalab/models/basicNCA.py

Lines changed: 80 additions & 72 deletions
@@ -1,12 +1,31 @@
 from __future__ import annotations
-from typing import Callable, List, Dict
+from typing import Callable, List, Optional, Dict
 import numpy as np
 
 import torch  # type: ignore[import-untyped]
 import torch.nn as nn  # type: ignore[import-untyped]
 import torch.nn.functional as F  # type: ignore[import-untyped]
 
 
+class AutoStepper:
+    def __init__(
+        self,
+        min_steps: int = 10,
+        max_steps: int = 100,
+        plateau: int = 5,
+        verbose: bool = False,
+        threshold: float = 1e-2,
+    ):
+        assert min_steps >= 1
+        assert plateau >= 1
+        assert max_steps > min_steps
+        self.min_steps = min_steps
+        self.max_steps = max_steps
+        self.plateau = plateau
+        self.verbose = verbose
+        self.threshold = threshold
+
+
 class BasicNCAModel(nn.Module):
     def __init__(
         self,
@@ -21,14 +40,10 @@ def __init__(
         num_learned_filters: int = 2,
         dx_noise: float = 0.0,
         filter_padding: str = "reflect",
+        use_laplace: bool = False,
         kernel_size: int = 3,
-        auto_step: bool = False,
-        auto_max_steps: int = 100,
-        auto_min_steps: int = 10,
-        auto_plateau: int = 5,
-        auto_verbose: bool = False,
-        auto_threshold: float = 1e-2,
         pad_noise: bool = False,
+        autostepper: Optional[AutoStepper] = None,
     ):
         """Basic abstract class for NCA models.
 
@@ -57,19 +72,14 @@ def __init__(
             num_image_channels + num_hidden_channels + num_output_channels
         )
         self.fire_rate = fire_rate
+        self.hidden_size = hidden_size
         self.use_alive_mask = use_alive_mask
         self.immutable_image_channels = immutable_image_channels
         self.num_learned_filters = num_learned_filters
+        self.use_laplace = use_laplace
         self.dx_noise = dx_noise
-        self.auto_step = auto_step
-        self.auto_max_steps = auto_max_steps
-        self.auto_min_steps = auto_min_steps
-        self.auto_plateau = auto_plateau
-        self.auto_verbose = auto_verbose
-        self.auto_threshold = auto_threshold
         self.pad_noise = pad_noise
-
-        self.hidden_size = hidden_size
+        self.autostepper = autostepper
 
         self.plot_function: Callable | None = None
 
@@ -93,17 +103,20 @@ def __init__(
             )
             self.filters = nn.ModuleList(filters)
         else:
-            self.num_filters = 2
             sobel_x = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0
             sobel_y = sobel_x.T
+            laplace = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
             self.filters.append(sobel_x)
             self.filters.append(sobel_y)
+            if self.use_laplace:
+                self.filters.append(laplace)
+            self.num_filters = len(self.filters)
 
         self.network = nn.Sequential(
             nn.Linear(
                 self.num_channels * (self.num_filters + 1), self.hidden_size, bias=True
             ),
-            #nn.LazyBatchNorm2d(),
+            nn.LazyBatchNorm2d(),
             nn.ReLU(),
             nn.Linear(self.hidden_size, self.num_channels, bias=False),
         ).to(device)
@@ -175,14 +188,14 @@ def update(self, x):
 
         # Stochastic weight update
         fire_rate = self.fire_rate
-        stochastic = torch.rand([dx.size(0), dx.size(1), dx.size(2), 1]) > fire_rate
+        stochastic = torch.rand([dx.size(0), dx.size(1), dx.size(2), 1]) < fire_rate
         stochastic = stochastic.float().to(self.device)
         dx = dx * stochastic
 
         dx += self.dx_noise * torch.randn([dx.size(0), dx.size(1), dx.size(2), 1]).to(
             self.device
         )
-
+
         if self.immutable_image_channels:
             dx[..., : self.num_image_channels] *= 0
 
@@ -205,65 +218,60 @@ def forward(
         steps: int = 1,
         return_steps: bool = False,
     ):
-        if self.auto_step:
-            # Assumption: min_steps >= 1; otherwise we cannot compute distance
-            assert self.auto_min_steps >= 1
-            assert self.auto_plateau >= 1
-            assert self.auto_max_steps > self.auto_min_steps
-
-            cooldown = 0
-            # invariant: auto_min_steps > 0, so both of these will be set when used
-            hidden_i: torch.Tensor | None = None
-            hidden_i_1: torch.Tensor | None = None
-            for step in range(self.auto_max_steps):
-                with torch.no_grad():
-                    if (
-                        step >= self.auto_min_steps
-                        and hidden_i is not None
-                        and hidden_i_1 is not None
-                    ):
-                        # normalized absolute difference between two hidden states
-                        score = (hidden_i - hidden_i_1).abs().sum() / (
-                            hidden_i.shape[0]
-                            * hidden_i.shape[1]
-                            * hidden_i.shape[2]
-                            * hidden_i.shape[3]
-                        )
-                        if score >= self.auto_threshold:
-                            cooldown = 0
-                        else:
-                            cooldown += 1
-                        if cooldown >= self.auto_plateau:
-                            if self.auto_verbose:
-                                print(f"Breaking after {step} steps.")
-                            if return_steps:
-                                return x, step
-                            return x
-                # save previous hidden state
-                hidden_i_1 = x[
-                    ...,
-                    self.num_image_channels : self.num_image_channels
-                    + self.num_hidden_channels,
-                ]
-                # single inference time step
-                x = self.update(x)
-                # set current hidden state
-                hidden_i = x[
-                    ...,
-                    self.num_image_channels : self.num_image_channels
-                    + self.num_hidden_channels,
-                ]
-            if return_steps:
-                return x, self.auto_max_steps
-            return x
-        else:
+        if self.autostepper is None:
             for step in range(steps):
                 x = self.update(x)
             if return_steps:
                 return x, steps
             return x
 
-    def loss(self, x, target) -> Dict[str, float]:
+        cooldown = 0
+        # invariant: auto_min_steps > 0, so both of these will be defined when used
+        hidden_i: torch.Tensor | None = None
+        hidden_i_1: torch.Tensor | None = None
+        for step in range(self.autostepper.max_steps):
+            with torch.no_grad():
+                if (
+                    step >= self.autostepper.min_steps
+                    and hidden_i is not None
+                    and hidden_i_1 is not None
+                ):
+                    # normalized absolute difference between two hidden states
+                    score = (hidden_i - hidden_i_1).abs().sum() / (
+                        hidden_i.shape[0]
+                        * hidden_i.shape[1]
+                        * hidden_i.shape[2]
+                        * hidden_i.shape[3]
+                    )
+                    if score >= self.autostepper.threshold:
+                        cooldown = 0
+                    else:
+                        cooldown += 1
+                    if cooldown >= self.autostepper.plateau:
+                        if self.autostepper.verbose:
+                            print(f"Breaking after {step} steps.")
+                        if return_steps:
+                            return x, step
+                        return x
+            # save previous hidden state
+            hidden_i_1 = x[
+                ...,
+                self.num_image_channels : self.num_image_channels
+                + self.num_hidden_channels,
+            ]
+            # single inference time step
+            x = self.update(x)
+            # set current hidden state
+            hidden_i = x[
+                ...,
+                self.num_image_channels : self.num_image_channels
+                + self.num_hidden_channels,
+            ]
+        if return_steps:
+            return x, self.autostepper.max_steps
+        return x
+
+    def loss(self, x, target) -> Dict[str, torch.Tensor]:
         """_summary_
 
         Args:
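Note: the six per-model auto_* keyword arguments are gone; the stopping policy now lives in the standalone AutoStepper class above and is injected through the new autostepper argument (None keeps the plain fixed-step forward()). A minimal wiring sketch, assuming the package imports as ncalab; the concrete model construction is left out:

    from ncalab.models.basicNCA import AutoStepper

    # Run between 10 and 100 update steps; stop once the mean absolute change
    # of the hidden channels stays below 1e-2 for 5 consecutive steps.
    stepper = AutoStepper(min_steps=10, max_steps=100, plateau=5, threshold=1e-2, verbose=True)

    # A model built with autostepper=stepper then ignores the `steps` argument:
    #     x, n_steps = model.forward(x, return_steps=True)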

ncalab/models/classificationNCA.py

Lines changed: 6 additions & 3 deletions
@@ -6,11 +6,10 @@
 from tqdm import tqdm  # type: ignore[import-untyped]
 
 from .basicNCA import BasicNCAModel
-from .splitNCA import SplitNCAModel
 from ..utils import pad_input
 
 
-class ClassificationNCAModel(SplitNCAModel):
+class ClassificationNCAModel(BasicNCAModel):
     def __init__(
         self,
         device,
@@ -171,7 +170,11 @@ def loss(self, x, target):
         loss = (
             1 - self.lambda_activity
         ) * loss_classification + self.lambda_activity * loss_activity
-        return { "total": loss }
+        return {
+            "total": loss,
+            "activity": loss_activity,
+            "classification": loss_classification,
+        }
 
     def validate(
         self,
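Note: loss() now returns the classification and activity terms alongside the weighted total, matching the Dict[str, torch.Tensor] annotation on the base class. A small logging helper as a sketch; log_loss_terms is hypothetical, only the dict keys come from this commit:

    from typing import Dict

    import torch


    def log_loss_terms(losses: Dict[str, torch.Tensor]) -> None:
        # Keys returned by ClassificationNCAModel.loss() after this commit.
        for name in ("total", "classification", "activity"):
            print(name, float(losses[name]))

In a training loop, back-propagation still goes through losses["total"] only; the extra entries are for monitoring.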

ncalab/models/depthNCA.py

Lines changed: 9 additions & 18 deletions
@@ -1,5 +1,6 @@
-from .basicNCA import BasicNCAModel
-from .splitNCA import SplitNCAModel
+from typing import Optional
+
+from .basicNCA import BasicNCAModel, AutoStepper
 
 from ..visualization import show_batch_depth
 from ..utils import pad_input
@@ -58,25 +59,20 @@ def forward(self, depth_map, rgb_image):
         return total_loss
 
 
-class DepthNCAModel(SplitNCAModel):
+class DepthNCAModel(BasicNCAModel):
     def __init__(
         self,
         device: torch.device,
         num_image_channels: int,
         num_hidden_channels: int,
         num_classes: int,
-        fire_rate: float = 0.0,
+        fire_rate: float = 0.8,
         hidden_size: int = 128,
         use_alive_mask: bool = False,
         immutable_image_channels: bool = True,
-        learned_filters: int = 4,
+        learned_filters: int = 2,
         lambda_activity: float = 0.0,
-        auto_step: bool = False,
-        auto_max_steps: int = 100,
-        auto_min_steps: int = 10,
-        auto_plateau: int = 5,
-        auto_verbose: bool = False,
-        auto_threshold: float = 1e-2,
+        autostepper: Optional[AutoStepper] = None,
        pad_noise: bool = False,
     ):
         """NCA model for monocular depth estimation.
@@ -105,12 +101,7 @@ def __init__(
             immutable_image_channels,
             learned_filters,
             kernel_size=3,
-            auto_step=auto_step,
-            auto_max_steps=auto_max_steps,
-            auto_min_steps=auto_min_steps,
-            auto_plateau=auto_plateau,
-            auto_verbose=auto_verbose,
-            auto_threshold=auto_threshold,
+            autostepper=autostepper,
             pad_noise=pad_noise,
         )
         self.plot_function = show_batch_depth
@@ -174,7 +165,7 @@ def loss(self, x, y):
             y_SSI,
         )
 
-        loss = 0.5 * loss_tv + loss_depthmap + loss_ssim
+        loss = 0.5 * loss_tv + loss_depthmap + 0.2 * loss_ssim
         return {"total": loss, "tv": loss_tv, "depth": loss_depthmap, "ssim": loss_ssim}
 
     def validate(
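Note: DepthNCAModel now forwards a single autostepper object to the base class instead of six auto_* keywords, and the SSIM term is down-weighted to 0.2 in the total loss. A construction sketch; the channel counts and device choice are placeholders, the parameter names match the signature above:

    import torch

    from ncalab.models.basicNCA import AutoStepper
    from ncalab.models.depthNCA import DepthNCAModel

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DepthNCAModel(
        device,
        num_image_channels=3,    # placeholder: RGB input
        num_hidden_channels=16,  # placeholder: hidden state size
        num_classes=1,
        autostepper=AutoStepper(min_steps=10, max_steps=100, plateau=5),
    )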

ncalab/models/growingNCA.py

Lines changed: 11 additions & 4 deletions
@@ -6,11 +6,10 @@
 import numpy as np
 
 from .basicNCA import BasicNCAModel
-from .splitNCA import SplitNCAModel
 from ..visualization import show_batch_growing
 
 
-class GrowingNCAModel(SplitNCAModel):
+class GrowingNCAModel(BasicNCAModel):
     def __init__(
         self,
         device: torch.device,
@@ -60,7 +59,7 @@ def loss(self, x, y):
             Tensor: MSE Loss
         """
         loss = F.mse_loss(x[..., : self.num_image_channels], y)
-        return { "total": loss }
+        return {"total": loss}
 
     def validate(self, *args, **kwargs):
         """We typically don't validate during training of Growing NCA."""
@@ -88,7 +87,15 @@ def grow(
         for _ in range(steps):
            out = self.forward(out, steps=1)
            step_outs.append(
-                np.clip(out[..., : self.num_image_channels].squeeze().detach().cpu().numpy(), 0, 1)
+                np.clip(
+                    out[..., : self.num_image_channels]
+                    .squeeze()
+                    .detach()
+                    .cpu()
+                    .numpy(),
+                    0,
+                    1,
+                )
             )
         return step_outs
     else:
