77import torch .nn as nn # type: ignore[import-untyped]
88import torch .nn .functional as F # type: ignore[import-untyped]
99
10- from ..autostepper import AutoStepper
1110from ..prediction import Prediction
12- from ..utils import pad_input
11+ from ..utils import intepret_range_parameter , pad_input , unwrap
1312from ..visualization import Visual
14- from .basicNCArule import BasicNCARule
1513from .basicNCAhead import BasicNCAHead
14+ from .basicNCAperception import BasicNCAPerception
15+ from .basicNCArule import BasicNCARule
1616
1717
1818class BasicNCAModel (nn .Module ):
1919 """
2020 Abstract base class for NCA models.
21+
22+ BasicNCAModel is a composition of an NCA backbone model (called "rule"), and
23+ an (optional) head module for downstream tasks.
2124 """
2225
2326 def __init__ (
@@ -37,9 +40,10 @@ def __init__(
3740 use_laplace : bool = False ,
3841 kernel_size : int = 3 ,
3942 pad_noise : bool = False ,
40- autostepper : Optional [AutoStepper ] = None ,
4143 use_temporal_encoding : bool = False ,
4244 rule_type : type [BasicNCARule ] = BasicNCARule ,
45+ training_timesteps : int | Tuple [int , int ] = 100 ,
46+ inference_timesteps : int | Tuple [int , int ] = 100 ,
4347 ):
4448 """
4549 :param device: Pytorch device descriptor.
@@ -55,7 +59,6 @@ def __init__(
5559 :param use_laplace: Whether to use Laplace filter (only if num_learned_filters == 0)
5660 :param kernel_size: Filter kernel size (only for learned filters)
5761 :param pad_noise: Whether to pad input image tensor with noise in hidden / output channels
58- :param autostepper: AutoStepper object to select number of time steps based on activity
5962 """
6063 super (BasicNCAModel , self ).__init__ ()
6164
@@ -77,61 +80,30 @@ def __init__(
7780 self .kernel_size = kernel_size
7881 self .filter_padding = filter_padding
7982 self .pad_noise = pad_noise
80- self .autostepper = autostepper
8183 self .use_temporal_encoding = use_temporal_encoding
8284 self .plot_function = plot_function
8385 self .validation_metric = validation_metric
84-
85- # define input filters
86- self ._define_filters (num_learned_filters )
86+ self .training_timesteps = training_timesteps
87+ self .inference_timesteps = inference_timesteps
8788
8889 # define model structure
89- self .input_vector_size = self .num_channels * (self .num_filters + 1 )
90+ # perception
91+ self .perception = BasicNCAPerception (self )
92+ self .input_vector_size = self .num_channels * (self .perception .num_filters + 1 )
9093 if self .use_temporal_encoding :
9194 self .input_vector_size += 1
95+ # rule
9296 self .rule_type = rule_type
9397 self .rule = self ._define_rule ()
9498 self .head : BasicNCAHead | None = None
99+ # pre-compute stochastic weight update
100+ self ._stochastic : torch .Tensor | None = None
95101
96102 def _define_rule (self ):
97103 return self .rule_type (
98104 self .device , self .input_vector_size , self .hidden_size , self .num_channels
99105 )
100106
101- def _define_filters (self , num_learned_filters : int ):
102- """
103- Define list of perception filters, based on parameters passed in constructor.
104-
105- :param num_learned_filters: Number of learned filters in perception filter bank.
106- :type num_learned_filters: int
107- """
108- self .filters : list | nn .ModuleList = []
109- if num_learned_filters > 0 :
110- self .num_filters = num_learned_filters
111- filters = []
112- for _ in range (num_learned_filters ):
113- filters .append (
114- nn .Conv2d (
115- self .num_channels ,
116- self .num_channels ,
117- kernel_size = self .kernel_size ,
118- stride = 1 ,
119- padding = (self .kernel_size // 2 ),
120- padding_mode = self .filter_padding ,
121- groups = self .num_channels ,
122- bias = False ,
123- )
124- )
125- self .filters = nn .ModuleList (filters ).to (self .device )
126- else :
127- sobel_x = np .outer ([1 , 2 , 1 ], [- 1 , 0 , 1 ]) / 8.0
128- sobel_y = sobel_x .T
129- laplace = np .array ([[0 , 1 , 0 ], [1 , - 4 , 1 ], [0 , 1 , 0 ]])
130- self .filters .extend ([sobel_x , sobel_y ])
131- if self .use_laplace :
132- self .filters .append (laplace )
133- self .num_filters = len (self .filters )
134-
135107 def prepare_input (self , x : torch .Tensor ) -> torch .Tensor :
136108 """
137109 Preprocess input. Intended to be overwritten by subclass, if preprocessing
@@ -155,34 +127,6 @@ def _alive(self, x):
155127 )
156128 return mask
157129
158- def _perceive (self , x , step ) -> torch .Tensor :
159- def _perceive_with (x , weight ):
160- if isinstance (weight , nn .Conv2d ):
161- return weight (x )
162- # if using a hard coded filter matrix.
163- # this is done in the original Growing NCA paper, but learned filters typically
164- # work better.
165- conv_weights = torch .from_numpy (weight .astype (np .float32 )).to (self .device )
166- conv_weights = conv_weights .view (1 , 1 , 3 , 3 ).repeat (
167- self .num_channels , 1 , 1 , 1
168- )
169- return F .conv2d (x , conv_weights , padding = 1 , groups = self .num_channels )
170-
171- perception = [x ]
172- perception .extend ([_perceive_with (x , w ) for w in self .filters ])
173- if self .use_temporal_encoding :
174- normalization = 100
175- if self .autostepper is not None :
176- normalization = self .autostepper .max_steps
177- perception .append (
178- torch .mul (
179- torch .ones ((x .shape [0 ], 1 , x .shape [2 ], x .shape [3 ])),
180- step / normalization ,
181- ).to (self .device )
182- )
183- dx = torch .cat (perception , 1 )
184- return dx
185-
186130 def _update (self , x : torch .Tensor , step : int ) -> torch .Tensor :
187131 """
188132 Compute residual cell update.
@@ -193,16 +137,13 @@ def _update(self, x: torch.Tensor, step: int) -> torch.Tensor:
193137 assert x .shape [1 ] == self .num_channels
194138
195139 # Perception
196- dx = self ._perceive (x , step )
140+ dx = self .perception . perceive (x , step )
197141
198142 # Compute delta from FFNN network
199143 dx = self .rule (dx )
200144
201145 # Stochastic weight update
202- fire_rate = self .fire_rate
203- stochastic = torch .rand ([dx .size (0 ), 1 , dx .size (2 ), dx .size (3 )]) < fire_rate
204- stochastic = stochastic .float ().to (self .device )
205- dx = dx * stochastic
146+ dx = dx * unwrap (self ._stochastic )[step % len (unwrap (self ._stochastic ))]
206147
207148 if self .immutable_image_channels :
208149 dx [:, : self .num_image_channels , :, :] *= 0
@@ -232,33 +173,23 @@ def forward(
232173
233174 :returns [Prediction]: Prediction object.
234175 """
235- if self .autostepper is None :
236- for step in range (steps ):
237- x = self ._forward_step (x , step )
238- return Prediction (self , steps , x )
239-
240- for step in range (self .autostepper .max_steps ):
241- if self .autostepper .check (step ):
242- return Prediction (self , step , x )
243- # save previous hidden state
244- self .autostepper .hidden_i_1 = x [
245- :,
246- self .num_image_channels : self .num_image_channels
247- + self .num_hidden_channels ,
248- :,
249- :,
250- ]
176+ assert x .shape [1 ] == self .num_channels
177+ S = torch .rand ([steps , 1 , 1 , x .size (2 ), x .size (3 )]) < self .fire_rate
178+ self ._stochastic = S .float ().to (self .device )
179+ for step in range (steps ):
251180 x = self ._forward_step (x , step )
252181
253- # set current hidden state
254- self . autostepper . hidden_i = x [
182+ if self . head is not None :
183+ hidden = x [
255184 :,
256185 self .num_image_channels : self .num_image_channels
257186 + self .num_hidden_channels ,
258187 :,
259188 :,
260189 ]
261- return Prediction (self , self .autostepper .max_steps , x )
190+ head_prediction = self .head (hidden )
191+ return Prediction (self , steps , x , head_prediction )
192+ return Prediction (self , steps , x )
262193
263194 def loss (self , pred : Prediction , label : torch .Tensor ) -> Dict [str , torch .Tensor ]:
264195 """
@@ -281,9 +212,7 @@ def finetune(self, freeze_head: bool = False):
281212 and setting to "train" mode.
282213 """
283214 self .train ()
284- if self .num_learned_filters != 0 :
285- for filter in self .filters :
286- filter .requires_grad_ (False )
215+ self .perception .freeze ()
287216 self .rule .freeze ()
288217 if freeze_head and self .head is not None :
289218 self .head .freeze ()
@@ -316,7 +245,9 @@ def predict(self, image: torch.Tensor, steps: int = 100) -> Prediction:
316245 prediction = self .forward (x , steps = steps )
317246 return prediction
318247
319- def record (self , image : torch .Tensor , steps : int = 100 ) -> List [Prediction ]:
248+ def record (
249+ self , image : torch .Tensor , steps : Optional [int ] = None
250+ ) -> List [Prediction ]:
320251 """
321252 Record predictions for all time steps and return the resulting
322253 sequence of predictions.
@@ -325,8 +256,9 @@ def record(self, image: torch.Tensor, steps: int = 100) -> List[Prediction]:
325256
326257 :returns [List[Prediction]]: List of Prediction objects.
327258 """
328- assert steps >= 1
329259 assert image .shape [1 ] <= self .num_channels
260+ if steps is None :
261+ steps = intepret_range_parameter (self .inference_timesteps )
330262 self .eval ()
331263 sequence = []
332264 with torch .no_grad ():
@@ -340,7 +272,7 @@ def record(self, image: torch.Tensor, steps: int = 100) -> List[Prediction]:
340272 return sequence
341273
342274 def validate (
343- self , image : torch .Tensor , label : torch .Tensor , steps : int
275+ self , image : torch .Tensor , label : torch .Tensor , steps : Optional [ int ] = None
344276 ) -> Optional [Tuple [Dict [str , float ], Prediction ]]:
345277 """
346278 Make a prediction on an image of the validation set and return metrics computed
@@ -352,6 +284,8 @@ def validate(
352284
353285 :returns [Tuple[float, Prediction]]: Validation metric, predicted image BCWH
354286 """
287+ if steps is None :
288+ steps = intepret_range_parameter (self .inference_timesteps )
355289 prediction = self .predict (image .to (self .device ), steps = steps )
356290 metrics = self .metrics (prediction , label .to (self .device ))
357291 return metrics , prediction
0 commit comments