Commit f5e02b2 (parent: 87db5f7)

move training/inference steps to models, refactor perception module

15 files changed: +507 / -223 lines

ncalab/autostepper.py

Lines changed: 19 additions & 3 deletions
@@ -1,7 +1,11 @@
 import logging
+from typing import Any
 
 import torch
 
+from .models import BasicNCAModel
+from .prediction import Prediction
+
 
 class AutoStepper:
     """
@@ -41,7 +45,7 @@ def __init__(
         self.hidden_i: torch.Tensor | None = None
         self.hidden_i_1: torch.Tensor | None = None
 
-    def score(self) -> torch.Tensor:
+    def _score(self) -> torch.Tensor:
         """
         Calculates activity score.
 
@@ -57,7 +61,7 @@ def score(self) -> torch.Tensor:
             self.hidden_i
         )
 
-    def check(self, step: int) -> bool:
+    def _check(self, step: int) -> bool:
         """
         Checks whether to interrupt inference after the current step.
 
@@ -74,7 +78,7 @@ def check(self, step: int) -> bool:
             return True
         if self.hidden_i is None or self.hidden_i_1 is None:
             return False
-        if self.score() >= self.threshold:
+        if self._score() >= self.threshold:
             self.cooldown = 0
         else:
             self.cooldown += 1
@@ -83,3 +87,15 @@ def check(self, step: int) -> bool:
             logging.info(f"Breaking after {step} steps.")
             return True
         return False
+
+    def run(self, nca: BasicNCAModel, x):
+        prediction: Prediction = nca(x, steps=self.min_steps)
+        for step in range(self.min_steps, self.max_steps):
+            self.hidden_i_1 = prediction.hidden_channels
+            prediction = nca(x, steps=self.min_steps)
+            self.hidden_i = prediction.hidden_channels
+            if self._check(step):
+                return prediction
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        return self.run(args[0], args[1])
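
Note: step-count control now lives in AutoStepper.run() instead of inside BasicNCAModel.forward(), so an AutoStepper wraps the model from the outside. A minimal usage sketch, assuming the constructor keywords match the fields used above (min_steps, max_steps, threshold); the real signature may differ:

    import torch

    from ncalab.autostepper import AutoStepper
    from ncalab.models import BasicNCAModel


    def predict_with_autostepper(nca: BasicNCAModel, x: torch.Tensor):
        # Assumed constructor keywords; only the fields min_steps,
        # max_steps and threshold are visible in this diff.
        stepper = AutoStepper(min_steps=10, max_steps=100, threshold=0.01)
        return stepper(nca, x)  # __call__ forwards to run(nca, x)

Each loop iteration in run() calls nca(x, steps=min_steps) again, compares the previous and current hidden channels via _score(), and returns the latest Prediction once _check(step) passes; if the check never fires before max_steps, run() falls through and returns None.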

ncalab/models/basicNCA.py

Lines changed: 36 additions & 102 deletions
@@ -7,17 +7,20 @@
 import torch.nn as nn  # type: ignore[import-untyped]
 import torch.nn.functional as F  # type: ignore[import-untyped]
 
-from ..autostepper import AutoStepper
 from ..prediction import Prediction
-from ..utils import pad_input
+from ..utils import intepret_range_parameter, pad_input, unwrap
 from ..visualization import Visual
-from .basicNCArule import BasicNCARule
 from .basicNCAhead import BasicNCAHead
+from .basicNCAperception import BasicNCAPerception
+from .basicNCArule import BasicNCARule
 
 
 class BasicNCAModel(nn.Module):
     """
     Abstract base class for NCA models.
+
+    BasicNCAModel is a composition of an NCA backbone model (called "rule"), and
+    an (optional) head module for downstream tasks.
     """
 
     def __init__(
@@ -37,9 +40,10 @@ def __init__(
         use_laplace: bool = False,
         kernel_size: int = 3,
         pad_noise: bool = False,
-        autostepper: Optional[AutoStepper] = None,
         use_temporal_encoding: bool = False,
         rule_type: type[BasicNCARule] = BasicNCARule,
+        training_timesteps: int | Tuple[int, int] = 100,
+        inference_timesteps: int | Tuple[int, int] = 100,
     ):
         """
         :param device: Pytorch device descriptor.
@@ -55,7 +59,6 @@ def __init__(
         :param use_laplace: Whether to use Laplace filter (only if num_learned_filters == 0)
         :param kernel_size: Filter kernel size (only for learned filters)
         :param pad_noise: Whether to pad input image tensor with noise in hidden / output channels
-        :param autostepper: AutoStepper object to select number of time steps based on activity
         """
         super(BasicNCAModel, self).__init__()
 
@@ -77,61 +80,30 @@ def __init__(
         self.kernel_size = kernel_size
         self.filter_padding = filter_padding
         self.pad_noise = pad_noise
-        self.autostepper = autostepper
         self.use_temporal_encoding = use_temporal_encoding
         self.plot_function = plot_function
         self.validation_metric = validation_metric
-
-        # define input filters
-        self._define_filters(num_learned_filters)
+        self.training_timesteps = training_timesteps
+        self.inference_timesteps = inference_timesteps
 
         # define model structure
-        self.input_vector_size = self.num_channels * (self.num_filters + 1)
+        # perception
+        self.perception = BasicNCAPerception(self)
+        self.input_vector_size = self.num_channels * (self.perception.num_filters + 1)
         if self.use_temporal_encoding:
             self.input_vector_size += 1
+        # rule
         self.rule_type = rule_type
         self.rule = self._define_rule()
         self.head: BasicNCAHead | None = None
+        # pre-compute stochastic weight update
+        self._stochastic: torch.Tensor | None = None
 
     def _define_rule(self):
         return self.rule_type(
             self.device, self.input_vector_size, self.hidden_size, self.num_channels
         )
 
-    def _define_filters(self, num_learned_filters: int):
-        """
-        Define list of perception filters, based on parameters passed in constructor.
-
-        :param num_learned_filters: Number of learned filters in perception filter bank.
-        :type num_learned_filters: int
-        """
-        self.filters: list | nn.ModuleList = []
-        if num_learned_filters > 0:
-            self.num_filters = num_learned_filters
-            filters = []
-            for _ in range(num_learned_filters):
-                filters.append(
-                    nn.Conv2d(
-                        self.num_channels,
-                        self.num_channels,
-                        kernel_size=self.kernel_size,
-                        stride=1,
-                        padding=(self.kernel_size // 2),
-                        padding_mode=self.filter_padding,
-                        groups=self.num_channels,
-                        bias=False,
-                    )
-                )
-            self.filters = nn.ModuleList(filters).to(self.device)
-        else:
-            sobel_x = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0
-            sobel_y = sobel_x.T
-            laplace = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
-            self.filters.extend([sobel_x, sobel_y])
-            if self.use_laplace:
-                self.filters.append(laplace)
-            self.num_filters = len(self.filters)
-
     def prepare_input(self, x: torch.Tensor) -> torch.Tensor:
         """
         Preprocess input. Intended to be overwritten by subclass, if preprocessing
@@ -155,34 +127,6 @@ def _alive(self, x):
         )
         return mask
 
-    def _perceive(self, x, step) -> torch.Tensor:
-        def _perceive_with(x, weight):
-            if isinstance(weight, nn.Conv2d):
-                return weight(x)
-            # if using a hard coded filter matrix.
-            # this is done in the original Growing NCA paper, but learned filters typically
-            # work better.
-            conv_weights = torch.from_numpy(weight.astype(np.float32)).to(self.device)
-            conv_weights = conv_weights.view(1, 1, 3, 3).repeat(
-                self.num_channels, 1, 1, 1
-            )
-            return F.conv2d(x, conv_weights, padding=1, groups=self.num_channels)
-
-        perception = [x]
-        perception.extend([_perceive_with(x, w) for w in self.filters])
-        if self.use_temporal_encoding:
-            normalization = 100
-            if self.autostepper is not None:
-                normalization = self.autostepper.max_steps
-            perception.append(
-                torch.mul(
-                    torch.ones((x.shape[0], 1, x.shape[2], x.shape[3])),
-                    step / normalization,
-                ).to(self.device)
-            )
-        dx = torch.cat(perception, 1)
-        return dx
-
     def _update(self, x: torch.Tensor, step: int) -> torch.Tensor:
         """
         Compute residual cell update.
@@ -193,16 +137,13 @@ def _update(self, x: torch.Tensor, step: int) -> torch.Tensor:
         assert x.shape[1] == self.num_channels
 
         # Perception
-        dx = self._perceive(x, step)
+        dx = self.perception.perceive(x, step)
 
         # Compute delta from FFNN network
         dx = self.rule(dx)
 
         # Stochastic weight update
-        fire_rate = self.fire_rate
-        stochastic = torch.rand([dx.size(0), 1, dx.size(2), dx.size(3)]) < fire_rate
-        stochastic = stochastic.float().to(self.device)
-        dx = dx * stochastic
+        dx = dx * unwrap(self._stochastic)[step % len(unwrap(self._stochastic))]
 
         if self.immutable_image_channels:
             dx[:, : self.num_image_channels, :, :] *= 0
@@ -232,33 +173,23 @@ def forward(
 
         :returns [Prediction]: Prediction object.
         """
-        if self.autostepper is None:
-            for step in range(steps):
-                x = self._forward_step(x, step)
-            return Prediction(self, steps, x)
-
-        for step in range(self.autostepper.max_steps):
-            if self.autostepper.check(step):
-                return Prediction(self, step, x)
-            # save previous hidden state
-            self.autostepper.hidden_i_1 = x[
-                :,
-                self.num_image_channels : self.num_image_channels
-                + self.num_hidden_channels,
-                :,
-                :,
-            ]
+        assert x.shape[1] == self.num_channels
+        S = torch.rand([steps, 1, 1, x.size(2), x.size(3)]) < self.fire_rate
+        self._stochastic = S.float().to(self.device)
+        for step in range(steps):
             x = self._forward_step(x, step)
 
-        # set current hidden state
-        self.autostepper.hidden_i = x[
+        if self.head is not None:
+            hidden = x[
                 :,
                 self.num_image_channels : self.num_image_channels
                 + self.num_hidden_channels,
                 :,
                 :,
             ]
-        return Prediction(self, self.autostepper.max_steps, x)
+            head_prediction = self.head(hidden)
+            return Prediction(self, steps, x, head_prediction)
+        return Prediction(self, steps, x)
@@ -281,9 +212,7 @@ def finetune(self, freeze_head: bool = False):
         and setting to "train" mode.
         """
         self.train()
-        if self.num_learned_filters != 0:
-            for filter in self.filters:
-                filter.requires_grad_(False)
+        self.perception.freeze()
         self.rule.freeze()
         if freeze_head and self.head is not None:
             self.head.freeze()
@@ -316,7 +245,9 @@ def predict(self, image: torch.Tensor, steps: int = 100) -> Prediction:
         prediction = self.forward(x, steps=steps)
         return prediction
 
-    def record(self, image: torch.Tensor, steps: int = 100) -> List[Prediction]:
+    def record(
+        self, image: torch.Tensor, steps: Optional[int] = None
+    ) -> List[Prediction]:
         """
         Record predictions for all time steps and return the resulting
         sequence of predictions.
@@ -325,8 +256,9 @@ def record(self, image: torch.Tensor, steps: int = 100) -> List[Prediction]:
 
         :returns [List[Prediction]]: List of Prediction objects.
         """
-        assert steps >= 1
         assert image.shape[1] <= self.num_channels
+        if steps is None:
+            steps = intepret_range_parameter(self.inference_timesteps)
         self.eval()
         sequence = []
         with torch.no_grad():
@@ -340,7 +272,7 @@ def record(self, image: torch.Tensor, steps: int = 100) -> List[Prediction]:
         return sequence
 
     def validate(
-        self, image: torch.Tensor, label: torch.Tensor, steps: int
+        self, image: torch.Tensor, label: torch.Tensor, steps: Optional[int] = None
     ) -> Optional[Tuple[Dict[str, float], Prediction]]:
         """
         Make a prediction on an image of the validation set and return metrics computed
@@ -352,6 +284,8 @@ def validate(
 
         :returns [Tuple[float, Prediction]]: Validation metric, predicted image BCWH
         """
+        if steps is None:
+            steps = intepret_range_parameter(self.inference_timesteps)
         prediction = self.predict(image.to(self.device), steps=steps)
         metrics = self.metrics(prediction, label.to(self.device))
         return metrics, prediction
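
Note: forward() now pre-draws the stochastic fire-rate masks for all steps in a single tensor, and _update() indexes into it per step, replacing the old per-step torch.rand call. A minimal sketch of the mechanism outside the model class (shapes as in the diff; fire_rate is assumed to be a float in [0, 1]):

    import torch

    steps, B, C, H, W = 8, 2, 16, 32, 32
    fire_rate = 0.5

    # One boolean mask per step, broadcast over batch and channels: [steps, 1, 1, H, W]
    S = (torch.rand(steps, 1, 1, H, W) < fire_rate).float()

    dx = torch.randn(B, C, H, W)  # residual update produced by the rule network
    step = 3
    dx = dx * S[step % len(S)]  # cells masked with 0 skip this update

One visible consequence of the new shape: the mask is shared across the batch dimension ([steps, 1, 1, H, W]), whereas the old code drew a separate mask per batch element ([B, 1, H, W]). The new training_timesteps / inference_timesteps parameters accept an int or a (min, max) tuple; judging by its use in record() and validate(), intepret_range_parameter presumably resolves such a value to a concrete step count (e.g. by sampling from the range), though its definition is outside this diff.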
