11from __future__ import annotations
2- from typing import Callable , List , Dict
2+ from typing import Callable , List , Optional , Dict
33import numpy as np
44
55import torch # type: ignore[import-untyped]
66import torch .nn as nn # type: ignore[import-untyped]
77import torch .nn .functional as F # type: ignore[import-untyped]
88
99
class AutoStepper:
    """Configuration object for early-stopping an NCA rollout.

    Holds the parameters a model's forward pass uses to decide when the
    hidden state has plateaued and stepping can stop before ``max_steps``.
    """

    def __init__(
        self,
        min_steps: int = 10,
        max_steps: int = 100,
        plateau: int = 5,
        verbose: bool = False,
        threshold: float = 1e-2,
    ):
        """
        Args:
            min_steps: Minimum number of update steps before early stopping
                is considered. Must be >= 1 (a hidden-state distance needs
                at least one previous step).
            max_steps: Hard upper bound on update steps. Must be strictly
                greater than ``min_steps``.
            plateau: Number of consecutive below-threshold steps required
                before stopping. Must be >= 1.
            verbose: If True, the consumer may print when it stops early.
            threshold: Normalized hidden-state change below which a step
                counts toward the plateau.

        Raises:
            ValueError: If any argument violates the constraints above.
        """
        # Validate with explicit exceptions rather than `assert`:
        # asserts are stripped when Python runs with -O, which would
        # silently accept inconsistent configurations.
        if min_steps < 1:
            raise ValueError("min_steps must be >= 1")
        if plateau < 1:
            raise ValueError("plateau must be >= 1")
        if max_steps <= min_steps:
            raise ValueError("max_steps must be greater than min_steps")
        self.min_steps = min_steps
        self.max_steps = max_steps
        self.plateau = plateau
        self.verbose = verbose
        self.threshold = threshold
28+
1029class BasicNCAModel (nn .Module ):
1130 def __init__ (
1231 self ,
@@ -21,14 +40,10 @@ def __init__(
2140 num_learned_filters : int = 2 ,
2241 dx_noise : float = 0.0 ,
2342 filter_padding : str = "reflect" ,
43+ use_laplace : bool = False ,
2444 kernel_size : int = 3 ,
25- auto_step : bool = False ,
26- auto_max_steps : int = 100 ,
27- auto_min_steps : int = 10 ,
28- auto_plateau : int = 5 ,
29- auto_verbose : bool = False ,
30- auto_threshold : float = 1e-2 ,
3145 pad_noise : bool = False ,
46+ autostepper : Optional [AutoStepper ] = None ,
3247 ):
3348 """Basic abstract class for NCA models.
3449
@@ -57,19 +72,14 @@ def __init__(
5772 num_image_channels + num_hidden_channels + num_output_channels
5873 )
5974 self .fire_rate = fire_rate
75+ self .hidden_size = hidden_size
6076 self .use_alive_mask = use_alive_mask
6177 self .immutable_image_channels = immutable_image_channels
6278 self .num_learned_filters = num_learned_filters
79+ self .use_laplace = use_laplace
6380 self .dx_noise = dx_noise
64- self .auto_step = auto_step
65- self .auto_max_steps = auto_max_steps
66- self .auto_min_steps = auto_min_steps
67- self .auto_plateau = auto_plateau
68- self .auto_verbose = auto_verbose
69- self .auto_threshold = auto_threshold
7081 self .pad_noise = pad_noise
71-
72- self .hidden_size = hidden_size
82+ self .autostepper = autostepper
7383
7484 self .plot_function : Callable | None = None
7585
@@ -93,17 +103,20 @@ def __init__(
93103 )
94104 self .filters = nn .ModuleList (filters )
95105 else :
96- self .num_filters = 2
97106 sobel_x = np .outer ([1 , 2 , 1 ], [- 1 , 0 , 1 ]) / 8.0
98107 sobel_y = sobel_x .T
108+ laplace = np .array ([[0 , 1 , 0 ], [1 , - 4 , 1 ], [0 , 1 , 0 ]])
99109 self .filters .append (sobel_x )
100110 self .filters .append (sobel_y )
111+ if self .use_laplace :
112+ self .filters .append (laplace )
113+ self .num_filters = len (self .filters )
101114
102115 self .network = nn .Sequential (
103116 nn .Linear (
104117 self .num_channels * (self .num_filters + 1 ), self .hidden_size , bias = True
105118 ),
106- # nn.LazyBatchNorm2d(),
119+ nn .LazyBatchNorm2d (),
107120 nn .ReLU (),
108121 nn .Linear (self .hidden_size , self .num_channels , bias = False ),
109122 ).to (device )
@@ -175,14 +188,14 @@ def update(self, x):
175188
176189 # Stochastic weight update
177190 fire_rate = self .fire_rate
178- stochastic = torch .rand ([dx .size (0 ), dx .size (1 ), dx .size (2 ), 1 ]) > fire_rate
191+ stochastic = torch .rand ([dx .size (0 ), dx .size (1 ), dx .size (2 ), 1 ]) < fire_rate
179192 stochastic = stochastic .float ().to (self .device )
180193 dx = dx * stochastic
181194
182195 dx += self .dx_noise * torch .randn ([dx .size (0 ), dx .size (1 ), dx .size (2 ), 1 ]).to (
183196 self .device
184197 )
185-
198+
186199 if self .immutable_image_channels :
187200 dx [..., : self .num_image_channels ] *= 0
188201
@@ -205,65 +218,60 @@ def forward(
205218 steps : int = 1 ,
206219 return_steps : bool = False ,
207220 ):
208- if self .auto_step :
209- # Assumption: min_steps >= 1; otherwise we cannot compute distance
210- assert self .auto_min_steps >= 1
211- assert self .auto_plateau >= 1
212- assert self .auto_max_steps > self .auto_min_steps
213-
214- cooldown = 0
215- # invariant: auto_min_steps > 0, so both of these will be set when used
216- hidden_i : torch .Tensor | None = None
217- hidden_i_1 : torch .Tensor | None = None
218- for step in range (self .auto_max_steps ):
219- with torch .no_grad ():
220- if (
221- step >= self .auto_min_steps
222- and hidden_i is not None
223- and hidden_i_1 is not None
224- ):
225- # normalized absolute difference between two hidden states
226- score = (hidden_i - hidden_i_1 ).abs ().sum () / (
227- hidden_i .shape [0 ]
228- * hidden_i .shape [1 ]
229- * hidden_i .shape [2 ]
230- * hidden_i .shape [3 ]
231- )
232- if score >= self .auto_threshold :
233- cooldown = 0
234- else :
235- cooldown += 1
236- if cooldown >= self .auto_plateau :
237- if self .auto_verbose :
238- print (f"Breaking after { step } steps." )
239- if return_steps :
240- return x , step
241- return x
242- # save previous hidden state
243- hidden_i_1 = x [
244- ...,
245- self .num_image_channels : self .num_image_channels
246- + self .num_hidden_channels ,
247- ]
248- # single inference time step
249- x = self .update (x )
250- # set current hidden state
251- hidden_i = x [
252- ...,
253- self .num_image_channels : self .num_image_channels
254- + self .num_hidden_channels ,
255- ]
256- if return_steps :
257- return x , self .auto_max_steps
258- return x
259- else :
221+ if self .autostepper is None :
260222 for step in range (steps ):
261223 x = self .update (x )
262224 if return_steps :
263225 return x , steps
264226 return x
265227
266- def loss (self , x , target ) -> Dict [str , float ]:
228+ cooldown = 0
229+ # invariant: auto_min_steps > 0, so both of these will be defined when used
230+ hidden_i : torch .Tensor | None = None
231+ hidden_i_1 : torch .Tensor | None = None
232+ for step in range (self .autostepper .max_steps ):
233+ with torch .no_grad ():
234+ if (
235+ step >= self .autostepper .min_steps
236+ and hidden_i is not None
237+ and hidden_i_1 is not None
238+ ):
239+ # normalized absolute difference between two hidden states
240+ score = (hidden_i - hidden_i_1 ).abs ().sum () / (
241+ hidden_i .shape [0 ]
242+ * hidden_i .shape [1 ]
243+ * hidden_i .shape [2 ]
244+ * hidden_i .shape [3 ]
245+ )
246+ if score >= self .autostepper .threshold :
247+ cooldown = 0
248+ else :
249+ cooldown += 1
250+ if cooldown >= self .autostepper .plateau :
251+ if self .autostepper .verbose :
252+ print (f"Breaking after { step } steps." )
253+ if return_steps :
254+ return x , step
255+ return x
256+ # save previous hidden state
257+ hidden_i_1 = x [
258+ ...,
259+ self .num_image_channels : self .num_image_channels
260+ + self .num_hidden_channels ,
261+ ]
262+ # single inference time step
263+ x = self .update (x )
264+ # set current hidden state
265+ hidden_i = x [
266+ ...,
267+ self .num_image_channels : self .num_image_channels
268+ + self .num_hidden_channels ,
269+ ]
270+ if return_steps :
271+ return x , self .autostepper .max_steps
272+ return x
273+
274+ def loss (self , x , target ) -> Dict [str , torch .Tensor ]:
267275 """_summary_
268276
269277 Args: