 import torch.nn.functional as F  # type: ignore[import-untyped]
 
 from ..autostepper import AutoStepper
+from ..prediction import Prediction
 from ..utils import pad_input
 
 
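The newly imported `Prediction` type is constructed as `Prediction(model, steps, output)` below and exposes an `output_image` attribute in `validate`. A minimal sketch of what `..prediction.Prediction` could look like, assuming a plain dataclass; everything beyond the constructor order and `output_image` is inferred from call sites in this diff, not confirmed:

```python
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class Prediction:
    """Bundles a forward-pass result with the number of steps that produced it.

    Sketch only: the constructor order (model, steps, output_image) and the
    output_image attribute are inferred from the call sites in this diff.
    """

    model: nn.Module
    steps: int
    output_image: torch.Tensor  # BCWH
```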
@@ -22,6 +23,8 @@ def __init__(
         num_image_channels: int,
         num_hidden_channels: int,
         num_output_channels: int,
+        plot_function: Optional[Callable] = None,
+        validation_metric: Optional[str] = None,
         fire_rate: float = 0.5,
         hidden_size: int = 128,
         use_alive_mask: bool = False,
@@ -74,14 +77,15 @@ def __init__(
         self.pad_noise = pad_noise
         self.autostepper = autostepper
         self.use_temporal_encoding = use_temporal_encoding
-
-        # set by subclassing functions
-        self.plot_function: Optional[Callable] = None
-        self.validation_metric: Optional[str] = None
+        self.plot_function = plot_function
+        self.validation_metric = validation_metric
 
         self._define_filters(num_learned_filters)
 
         # define model structure
+        self._define_network()
+
+    def _define_network(self):
         input_vector_size = self.num_channels * (self.num_filters + 1)
         if self.use_temporal_encoding:
             input_vector_size += 1
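Extracting layer construction into a `_define_network` hook means a subclass can swap in a different update network without re-implementing the constructor. A hypothetical example; the base class name `NCAModel`, the extra hidden layer, and the `self.hidden_size` attribute are assumptions for illustration:

```python
class DeeperNCA(NCAModel):  # NCAModel is a stand-in for the class in this file
    def _define_network(self):
        # relies on the attributes __init__ sets before calling this hook;
        # assumes hidden_size is stored as self.hidden_size
        input_vector_size = self.num_channels * (self.num_filters + 1)
        if self.use_temporal_encoding:
            input_vector_size += 1
        self.network = nn.Sequential(
            nn.Conv2d(input_vector_size, self.hidden_size, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(self.hidden_size, self.num_channels, kernel_size=1),
        ).to(self.device)
```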
@@ -103,7 +107,7 @@ def __init__(
                 padding=0,
                 kernel_size=1,
             ),
-        ).to(device)
+        ).to(self.device)
 
         # initialize final layer with 0
         with torch.no_grad():
@@ -137,8 +141,7 @@ def _define_filters(self, num_learned_filters: int):
         sobel_x = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0
         sobel_y = sobel_x.T
         laplace = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])
-        self.filters.append(sobel_x)
-        self.filters.append(sobel_y)
+        self.filters.extend([sobel_x, sobel_y])
         if self.use_laplace:
             self.filters.append(laplace)
         self.num_filters = len(self.filters)
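For reference, the fixed kernels assembled here are the classic 3x3 Sobel pair, divided by 8 (the kernel's absolute sum) so responses stay bounded by the input range, plus an optional Laplacian:

```python
import numpy as np

sobel_x = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0
# [[-0.125  0.     0.125]
#  [-0.25   0.     0.25 ]
#  [-0.125  0.     0.125]]
sobel_y = sobel_x.T                                      # gradient along the other axis
laplace = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]])   # discrete 5-point Laplacian
```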
@@ -154,7 +157,7 @@ def prepare_input(self, x: torch.Tensor) -> torch.Tensor:
         """
         return x
 
-    def __alive(self, x):
+    def _alive(self, x):
         mask = (
             F.max_pool2d(
                 x[:, 3, :, :],
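The `__alive` to `_alive` rename is more than cosmetic: a double leading underscore triggers Python name mangling, so the method is only reachable under its mangled name from outside the defining class, which breaks subclass access. A small demonstration:

```python
class Base:
    def __alive(self, x):       # stored on the class as _Base__alive
        return x

class Child(Base):
    def probe(self, x):
        return self.__alive(x)  # looks up _Child__alive, which does not exist

Child().probe(1)  # raises AttributeError; a single underscore avoids this
```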
@@ -190,11 +193,12 @@ def _perceive_with(x, weight):
         dx = torch.cat(perception, 1)
         return dx
 
-    def _update(self, x: torch.Tensor, step):
+    def _update(self, x: torch.Tensor, step: int) -> torch.Tensor:
         """
         Compute residual cell update.
 
         :param x [torch.Tensor]: Input tensor, BCWH
+        :param step [int]: Current timestep, required for computing temporal encoding.
         """
         assert x.shape[1] == self.num_channels
 
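Since `_define_network` widens the input by one channel when `use_temporal_encoding` is set, `step` presumably enters `_update` as one extra constant channel concatenated onto the perception tensor. A hedged sketch of that wiring; the normalization scale is an assumption, not shown in this diff:

```python
# inside _update, after perception dx with shape (B, num_channels * (num_filters + 1), W, H):
if self.use_temporal_encoding:
    # one constant channel carrying the timestep; the divisor is a placeholder
    # chosen only to keep the encoding in a bounded range
    t = torch.full_like(dx[:, :1, :, :], step / 100.0)
    dx = torch.cat([dx, t], dim=1)
```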
@@ -218,40 +222,33 @@ def forward(
         self,
         x: torch.Tensor,
         steps: int = 1,
-    ) -> torch.Tensor | Tuple[torch.Tensor, int]:
+    ) -> Prediction:
         """
         :param x [torch.Tensor]: Input image, padded along the channel dimension, BCWH.
         :param steps [int]: Time steps in forward pass.
 
-        :returns: Output image (BCWH)
+        :returns [Prediction]: Prediction object.
         """
         if self.autostepper is None:
             for step in range(steps):
                 dx = self._update(x, step)
                 x = x + dx
-            return x, steps
 
-        # invariant: auto_min_steps > 0, so both of these will be defined when used
-        hidden_i: torch.Tensor | None = None
-        hidden_i_1: torch.Tensor | None = None
+            # Alive masking
+            if self.use_alive_mask:
+                life_mask = self._alive(x)
+                x = x.permute(1, 0, 2, 3)  # B C W H --> C B W H
+                x = x * life_mask.float()
+                x = x.permute(1, 0, 2, 3)  # C B W H --> B C W H
+            return Prediction(self, steps, x)
+
         for step in range(self.autostepper.max_steps):
-            with torch.no_grad():
-                if (
-                    step >= self.autostepper.min_steps
-                    and hidden_i is not None
-                    and hidden_i_1 is not None
-                ):
-                    # normalized absolute difference between two hidden states
-                    score = (hidden_i - hidden_i_1).abs().sum() / (
-                        hidden_i.shape[0]
-                        * hidden_i.shape[1]
-                        * hidden_i.shape[2]
-                        * hidden_i.shape[3]
-                    )
-                    if self.autostepper.check(step, score):
-                        return x, step
+            if self.autostepper.check(step):
+                return Prediction(self, step, x)
             # save previous hidden state
-            hidden_i_1 = x[
+            self.autostepper.hidden_i_1 = x[
                 :,
                 self.num_image_channels : self.num_image_channels
                 + self.num_hidden_channels,
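The early-stopping bookkeeping moves out of `forward` and into `AutoStepper`: the model now stashes the hidden-channel slices on `self.autostepper` and calls `check(step)` with no score argument. A sketch of how `AutoStepper` could reproduce the deleted inline logic; `min_steps`, `max_steps`, `hidden_i`, and `hidden_i_1` appear in this diff, while `threshold` and the comparison direction are assumptions:

```python
import torch


class AutoStepper:
    """Sketch of the stopping logic implied by this diff; threshold is assumed."""

    def __init__(self, min_steps: int = 10, max_steps: int = 100, threshold: float = 1e-3):
        self.min_steps = min_steps
        self.max_steps = max_steps
        self.threshold = threshold
        self.hidden_i: torch.Tensor | None = None
        self.hidden_i_1: torch.Tensor | None = None

    def check(self, step: int) -> bool:
        if step < self.min_steps or self.hidden_i is None or self.hidden_i_1 is None:
            return False
        with torch.no_grad():
            # normalized absolute difference between two hidden states,
            # matching the inline code this refactor removes
            score = (self.hidden_i - self.hidden_i_1).abs().sum() / self.hidden_i.numel()
        return score < self.threshold
```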
@@ -264,20 +261,21 @@ def forward(
 
             # Alive masking
             if self.use_alive_mask:
-                life_mask = self.__alive(x)
+                life_mask = self._alive(x)
                 life_mask = life_mask
                 x = x.permute(1, 0, 2, 3)  # B C W H --> C B W H
                 x = x * life_mask.float()
                 x = x.permute(1, 0, 2, 3)  # C B W H --> B C W H
+
             # set current hidden state
-            hidden_i = x[
+            self.autostepper.hidden_i = x[
                 :,
                 self.num_image_channels : self.num_image_channels
                 + self.num_hidden_channels,
                 :,
                 :,
             ]
-        return x, self.autostepper.max_steps
+        return Prediction(self, self.autostepper.max_steps, x)
 
     def loss(self, image: torch.Tensor, label: torch.Tensor) -> Dict[str, torch.Tensor]:
         """
@@ -317,11 +315,11 @@ def metrics(self, pred: torch.Tensor, label: torch.Tensor) -> Dict[str, float]:
         """
         return {}
 
-    def predict(self, image: torch.Tensor, steps: int = 100) -> torch.Tensor:
+    def predict(self, image: torch.Tensor, steps: int = 100) -> Prediction:
         """
         :param image [torch.Tensor]: Input image, BCWH.
 
-        :returns [torch.Tensor]: Output image, BCWH
+        :returns [Prediction]: Prediction object.
         """
         assert steps >= 1
         assert image.shape[1] <= self.num_channels
@@ -330,19 +328,22 @@ def predict(self, image: torch.Tensor, steps: int = 100) -> torch.Tensor:
         x = image.clone()
         x = pad_input(x, self, noise=self.pad_noise)
         x = self.prepare_input(x)
-        x, _ = self.forward(x, steps=steps)  # type: ignore[assignment]
-        return x
+        prediction = self.forward(x, steps=steps)
+        return prediction
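Call-site impact of the new return type: callers that previously unpacked an `(x, steps)` tuple now read attributes off the `Prediction` object. Attribute names beyond `output_image` are assumptions inferred from the constructor's argument order:

```python
prediction = model.predict(image, steps=100)  # model: any instance of this class
output = prediction.output_image              # BCWH tensor, as before
print(f"stopped after {prediction.steps} steps")
```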
 
     def validate(
         self, image: torch.Tensor, label: torch.Tensor, steps: int
-    ) -> Optional[Tuple[Dict[str, float], torch.Tensor]]:
+    ) -> Optional[Tuple[Dict[str, float], Prediction]]:
         """
+        Run prediction on a validation image and compute metrics against its
+        ground-truth label.
+
         :param image [torch.Tensor]: Input image, BCWH
         :param label [torch.Tensor]: Ground truth label
         :param steps [int]: Inference steps
 
-        :returns [Tuple[float, torch.Tensor]]: Validation metric, predicted image BCWH
+        :returns [Tuple[Dict[str, float], Prediction]]: Validation metrics and prediction.
         """
-        pred = self.predict(image.to(self.device), steps=steps)
-        metrics = self.metrics(pred, label.to(self.device))
-        return metrics, pred
+        prediction = self.predict(image.to(self.device), steps=steps)
+        metrics = self.metrics(prediction.output_image, label.to(self.device))
+        return metrics, prediction
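A typical validation loop over this API might look like the following; the dataloader and loop variables are illustrative:

```python
for image, label in val_loader:  # val_loader: an illustrative DataLoader
    metrics, prediction = model.validate(image, label, steps=100)
    for name, value in metrics.items():
        print(f"{name}: {value:.4f}")
```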