Skip to content

Commit 9616c4d

Browse files
committed
let forward() method return Prediction object
1 parent b68041a commit 9616c4d

File tree

17 files changed

+393
-206
lines changed

17 files changed

+393
-206
lines changed

ncalab/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from .autostepper import AutoStepper # noqa: F401,F403
12
from .data import GrowingNCADataset # noqa: F401
23
from .experiment import * # noqa: F401,F403
34
from .losses import * # noqa: F401,F403
45
from .models import * # noqa: F401,F403
56
from .paths import * # noqa: F401,F403
7+
from .prediction import * # noqa: F401,F403
68
from .search import * # noqa: F401,F403
79
from .training import * # noqa: F401,F403
810
from .utils import * # noqa: F401,F403

ncalab/autostepper.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
22

3+
import torch
4+
35

46
class AutoStepper:
57
"""
@@ -20,8 +22,8 @@ def __init__(
2022
:param min_steps [int]: Minimum number of timesteps to always execute. Defaults to 10.
2123
:param max_steps [int]: Terminate after maximum number of steps. Defaults to 100.
2224
:param plateau [int]: _description_. Defaults to 5.
23-
:param verbose [bool]: Whether to communicate. Defaults to False.
24-
threshold (float, optional): _description_. Defaults to 1e-2.
25+
:param verbose [bool]: Whether to log interruption to stdout. Defaults to False.
26+
:param threshold [float]: Score threshold. Defaults to 1e-2.
2527
"""
2628
assert min_steps >= 1
2729
assert plateau >= 1
@@ -32,22 +34,38 @@ def __init__(
3234
self.verbose = verbose
3335
self.threshold = threshold
3436
self.cooldown = 0
37+
# invariant: auto_min_steps > 0, so both of these will be defined when used
38+
self.hidden_i: torch.Tensor | None = None
39+
self.hidden_i_1: torch.Tensor | None = None
40+
41+
def score(self) -> torch.Tensor:
42+
assert self.hidden_i is not None
43+
assert self.hidden_i_1 is not None
44+
# normalized absolute difference between two hidden states
45+
return (self.hidden_i - self.hidden_i_1).abs().sum() / torch.numel(
46+
self.hidden_i
47+
)
3548

36-
def check(self, step, score):
49+
def check(self, step: int) -> bool:
3750
"""
38-
_summary_
51+
Checks whether to interrupt inference after the current step.
3952
40-
:param score: _description_
41-
:type score: _type_
42-
:return: _description_
43-
:rtype: _type_
53+
:param score [int]: Current NCA inference step.
54+
:return [bool]: Whether to interrupt inference after the current step.
4455
"""
45-
if score >= self.threshold:
46-
self.cooldown = 0
47-
else:
48-
self.cooldown += 1
49-
if self.cooldown >= self.plateau:
50-
if self.verbose:
51-
logging.info(f"Breaking after {step} steps.")
52-
return True
53-
return False
56+
with torch.no_grad():
57+
if step < self.min_steps:
58+
return False
59+
if step >= self.max_steps:
60+
return True
61+
if self.hidden_i is None or self.hidden_i_1 is None:
62+
return False
63+
if self.score() >= self.threshold:
64+
self.cooldown = 0
65+
else:
66+
self.cooldown += 1
67+
if self.cooldown >= self.plateau:
68+
if self.verbose:
69+
logging.info(f"Breaking after {step} steps.")
70+
return True
71+
return False

ncalab/export/header.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
2-
import torch
2+
import os
33
from pathlib import PosixPath, Path
44

5+
import torch
6+
57
from ..models.basicNCA import BasicNCAModel
68

79

@@ -34,12 +36,12 @@ def export_header(
3436
):
3537
## prepare preamble
3638
# guard
37-
preamble = "#pragma once\n"
39+
preamble = f"#pragma once{os.linesep}"
3840
# add imports if any
3941
if imports:
4042
for header in imports:
41-
preamble += f'#include "{header}"'
42-
preamble += "\n\n"
43+
preamble += f'#include "{header}"{os.linesep}'
44+
preamble += f"{os.linesep}{os.linesep}"
4345

4446
with open(outfile, "w") as f:
4547
f.write(preamble)

ncalab/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .hook import Hook # noqa: F401

ncalab/hooks/hook.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import torch
2+
3+
4+
class Hook:
5+
def __init__(self, *args, **kwargs):
6+
pass
7+
8+
def pre_forward(self, x: torch.Tensor) -> torch.Tensor:
9+
return x
10+
11+
def post_forward(self, x: torch.Tensor) -> torch.Tensor:
12+
return x
13+
14+
def pre_perceive(self, x: torch.Tensor) -> torch.Tensor:
15+
return x
16+
17+
def pre_update(self, x: torch.Tensor) -> torch.Tensor:
18+
return x

ncalab/hooks/hook_output_noise.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
3+
from .hook import Hook
4+
5+
class OutputNoiseHook(Hook):
6+
def __init__(self, dx_noise: float):
7+
self.dx_noise = dx_noise
8+
9+
def pre_forward(self, x: torch.Tensor) -> torch.Tensor:
10+
return x
11+
12+
def post_forward(self, x: torch.Tensor) -> torch.Tensor:
13+
return x
14+
15+
def pre_perceive(self, x: torch.Tensor) -> torch.Tensor:
16+
return x
17+
18+
def pre_update(self, x: torch.Tensor) -> torch.Tensor:
19+
return x

ncalab/models/basicNCA.py

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch.nn.functional as F # type: ignore[import-untyped]
99

1010
from ..autostepper import AutoStepper
11+
from ..prediction import Prediction
1112
from ..utils import pad_input
1213

1314

@@ -22,6 +23,8 @@ def __init__(
2223
num_image_channels: int,
2324
num_hidden_channels: int,
2425
num_output_channels: int,
26+
plot_function: Optional[Callable] = None,
27+
validation_metric: Optional[str] = None,
2528
fire_rate: float = 0.5,
2629
hidden_size: int = 128,
2730
use_alive_mask: bool = False,
@@ -74,14 +77,15 @@ def __init__(
7477
self.pad_noise = pad_noise
7578
self.autostepper = autostepper
7679
self.use_temporal_encoding = use_temporal_encoding
77-
78-
# set by subclassing functions
79-
self.plot_function: Optional[Callable] = None
80-
self.validation_metric: Optional[str] = None
80+
self.plot_function = plot_function
81+
self.validation_metric = validation_metric
8182

8283
self._define_filters(num_learned_filters)
8384

8485
# define model structure
86+
self._define_network()
87+
88+
def _define_network(self):
8589
input_vector_size = self.num_channels * (self.num_filters + 1)
8690
if self.use_temporal_encoding:
8791
input_vector_size += 1
@@ -103,7 +107,7 @@ def __init__(
103107
padding=0,
104108
kernel_size=1,
105109
),
106-
).to(device)
110+
).to(self.device)
107111

108112
# initialize final layer with 0
109113
with torch.no_grad():
@@ -137,8 +141,7 @@ def _define_filters(self, num_learned_filters: int):
137141
sobel_x = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0
138142
sobel_y = sobel_x.T
139143
laplace = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
140-
self.filters.append(sobel_x)
141-
self.filters.append(sobel_y)
144+
self.filters.extend([sobel_x, sobel_y])
142145
if self.use_laplace:
143146
self.filters.append(laplace)
144147
self.num_filters = len(self.filters)
@@ -154,7 +157,7 @@ def prepare_input(self, x: torch.Tensor) -> torch.Tensor:
154157
"""
155158
return x
156159

157-
def __alive(self, x):
160+
def _alive(self, x):
158161
mask = (
159162
F.max_pool2d(
160163
x[:, 3, :, :],
@@ -190,11 +193,12 @@ def _perceive_with(x, weight):
190193
dx = torch.cat(perception, 1)
191194
return dx
192195

193-
def _update(self, x: torch.Tensor, step):
196+
def _update(self, x: torch.Tensor, step: int) -> torch.Tensor:
194197
"""
195198
Compute residual cell update.
196199
197200
:param x [torch.Tensor]: Input tensor, BCWH
201+
:param step [int]: Current timestep, required for computing temporal encoding.
198202
"""
199203
assert x.shape[1] == self.num_channels
200204

@@ -218,40 +222,33 @@ def forward(
218222
self,
219223
x: torch.Tensor,
220224
steps: int = 1,
221-
) -> torch.Tensor | Tuple[torch.Tensor, int]:
225+
) -> Prediction:
222226
"""
223227
:param x [torch.Tensor]: Input image, padded along the channel dimension, BCWH.
224228
:param steps [int]: Time steps in forward pass.
225229
226-
:returns: Output image (BCWH)
230+
:returns [Prediction]: Prediction object.
227231
"""
228232
if self.autostepper is None:
229233
for step in range(steps):
230234
dx = self._update(x, step)
231235
x = x + dx
232-
return x, steps
233236

234-
# invariant: auto_min_steps > 0, so both of these will be defined when used
235-
hidden_i: torch.Tensor | None = None
236-
hidden_i_1: torch.Tensor | None = None
237+
# Alive masking
238+
if self.use_alive_mask:
239+
life_mask = self._alive(x)
240+
life_mask = life_mask
241+
x = x.permute(1, 0, 2, 3) # B C W H --> C B W H
242+
x = x * life_mask.float()
243+
x = x.permute(1, 0, 2, 3) # C B W H --> B C W H
244+
return Prediction(self, steps, x)
245+
246+
237247
for step in range(self.autostepper.max_steps):
238-
with torch.no_grad():
239-
if (
240-
step >= self.autostepper.min_steps
241-
and hidden_i is not None
242-
and hidden_i_1 is not None
243-
):
244-
# normalized absolute difference between two hidden states
245-
score = (hidden_i - hidden_i_1).abs().sum() / (
246-
hidden_i.shape[0]
247-
* hidden_i.shape[1]
248-
* hidden_i.shape[2]
249-
* hidden_i.shape[3]
250-
)
251-
if self.autostepper.check(step, score):
252-
return x, step
248+
if self.autostepper.check(step):
249+
return Prediction(self, step, x)
253250
# save previous hidden state
254-
hidden_i_1 = x[
251+
self.autostepper.hidden_i_1 = x[
255252
:,
256253
self.num_image_channels : self.num_image_channels
257254
+ self.num_hidden_channels,
@@ -264,20 +261,21 @@ def forward(
264261

265262
# Alive masking
266263
if self.use_alive_mask:
267-
life_mask = self.__alive(x)
264+
life_mask = self._alive(x)
268265
life_mask = life_mask
269266
x = x.permute(1, 0, 2, 3) # B C W H --> C B W H
270267
x = x * life_mask.float()
271268
x = x.permute(1, 0, 2, 3) # C B W H --> B C W H
269+
272270
# set current hidden state
273-
hidden_i = x[
271+
self.autostepper.hidden_i = x[
274272
:,
275273
self.num_image_channels : self.num_image_channels
276274
+ self.num_hidden_channels,
277275
:,
278276
:,
279277
]
280-
return x, self.autostepper.max_steps
278+
return Prediction(self, self.autostepper.max_steps, x)
281279

282280
def loss(self, image: torch.Tensor, label: torch.Tensor) -> Dict[str, torch.Tensor]:
283281
"""
@@ -317,11 +315,11 @@ def metrics(self, pred: torch.Tensor, label: torch.Tensor) -> Dict[str, float]:
317315
"""
318316
return {}
319317

320-
def predict(self, image: torch.Tensor, steps: int = 100) -> torch.Tensor:
318+
def predict(self, image: torch.Tensor, steps: int = 100) -> Prediction:
321319
"""
322320
:param image [torch.Tensor]: Input image, BCWH.
323321
324-
:returns [torch.Tensor]: Output image, BCWH
322+
:returns [Prediction]: Prediction object.
325323
"""
326324
assert steps >= 1
327325
assert image.shape[1] <= self.num_channels
@@ -330,19 +328,22 @@ def predict(self, image: torch.Tensor, steps: int = 100) -> torch.Tensor:
330328
x = image.clone()
331329
x = pad_input(x, self, noise=self.pad_noise)
332330
x = self.prepare_input(x)
333-
x, _ = self.forward(x, steps=steps) # type: ignore[assignment]
334-
return x
331+
prediction = self.forward(x, steps=steps)
332+
return prediction
335333

336334
def validate(
337335
self, image: torch.Tensor, label: torch.Tensor, steps: int
338-
) -> Optional[Tuple[Dict[str, float], torch.Tensor]]:
336+
) -> Optional[Tuple[Dict[str, float], Prediction]]:
339337
"""
338+
Make a prediction on an image of the validation set and return metrics computed
339+
with respect to a labelled validation image.
340+
340341
:param image [torch.Tensor]: Input image, BCWH
341342
:param label [torch.Tensor]: Ground truth label
342343
:param steps [int]: Inference steps
343344
344-
:returns [Tuple[float, torch.Tensor]]: Validation metric, predicted image BCWH
345+
:returns [Tuple[float, Prediction]]: Validation metric, predicted image BCWH
345346
"""
346-
pred = self.predict(image.to(self.device), steps=steps)
347-
metrics = self.metrics(pred, label.to(self.device))
348-
return metrics, pred
347+
prediction = self.predict(image.to(self.device), steps=steps)
348+
metrics = self.metrics(prediction.output_image, label.to(self.device))
349+
return metrics, prediction

0 commit comments

Comments
 (0)