1+ from jax import numpy as jnp , random , jit
2+ from functools import partial
3+ from ngclearn .utils import tensorstats
4+ from ngclearn import resolver , Component , Compartment
5+ from ngclearn .components .jaxComponent import JaxComponent
6+ from ngclearn .utils .model_utils import create_function , threshold_soft , \
7+ threshold_cauchy
8+ from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
9+ step_euler , step_rk2 , step_rk4
10+
11+ def _dfz_internal_gaussian (z , j , j_td , tau_m , leak_gamma ):
12+ z_leak = z # * 2 ## Default: assume Gaussian
13+ dz_dt = (- z_leak * leak_gamma + (j + j_td )) * (1. / tau_m )
14+ return dz_dt
15+
16+ def _dfz_internal_laplacian (z , j , j_td , tau_m , leak_gamma ):
17+ z_leak = jnp .sign (z ) ## d/dx of Laplace is signum
18+ dz_dt = (- z_leak * leak_gamma + (j + j_td )) * (1. / tau_m )
19+ return dz_dt
20+
21+ def _dfz_internal_cauchy (z , j , j_td , tau_m , leak_gamma ):
22+ z_leak = (z * 2 )/ (1. + jnp .square (z ))
23+ dz_dt = (- z_leak * leak_gamma + (j + j_td )) * (1. / tau_m )
24+ return dz_dt
25+
26+ def _dfz_internal_exp (z , j , j_td , tau_m , leak_gamma ):
27+ z_leak = jnp .exp (- jnp .square (z )) * z * 2
28+ dz_dt = (- z_leak * leak_gamma + (j + j_td )) * (1. / tau_m )
29+ return dz_dt
30+
31+
32+ def _dfz_gaussian (t , z , params ): ## diff-eq dynamics wrapper
33+ j , j_td , tau_m , leak_gamma = params
34+ dz_dt = _dfz_internal_gaussian (z , j , j_td , tau_m , leak_gamma )
35+ return dz_dt
36+
37+ def _dfz_laplacian (t , z , params ): ## diff-eq dynamics wrapper
38+ j , j_td , tau_m , leak_gamma = params
39+ dz_dt = _dfz_internal_laplacian (z , j , j_td , tau_m , leak_gamma )
40+ return dz_dt
41+
42+ def _dfz_cauchy (t , z , params ): ## diff-eq dynamics wrapper
43+ j , j_td , tau_m , leak_gamma = params
44+ dz_dt = _dfz_internal_cauchy (z , j , j_td , tau_m , leak_gamma )
45+ return dz_dt
46+
47+ def _dfz_exp (t , z , params ): ## diff-eq dynamics wrapper
48+ j , j_td , tau_m , leak_gamma = params
49+ dz_dt = _dfz_internal_exp (z , j , j_td , tau_m , leak_gamma )
50+ return dz_dt
51+
52+ @jit
53+ def _modulate (j , dfx_val ):
54+ """
55+ Apply a signal modulator to j (typically of the form of a derivative/dampening function)
56+
57+ Args:
58+ j: current/stimulus value to modulate
59+
60+ dfx_val: modulator signal
61+
62+ Returns:
63+ modulated j value
64+ """
65+ return j * dfx_val
66+
67+ def _run_cell (dt , j , j_td , z , tau_m , leak_gamma = 0. , integType = 0 , priorType = 0 ):
68+ """
69+ Runs leaky rate-coded state dynamics one step in time.
70+
71+ Args:
72+ dt: integration time constant
73+
74+ j: input (bottom-up) electrical/stimulus current
75+
76+ j_td: modulatory (top-down) electrical/stimulus pressure
77+
78+ z: current value of membrane/state
79+
80+ tau_m: membrane/state time constant
81+
82+ leak_gamma: strength of leak to apply to membrane/state
83+
84+ integType: integration type to use (0 --> Euler/RK1, 1 --> Midpoint/RK2, 2 --> RK4)
85+
86+ priorType: scale-shift prior distribution to impose over neural dynamics
87+
88+ Returns:
89+ New value of membrane/state for next time step
90+ """
91+ _dfz = {
92+ 0 : _dfz_gaussian ,
93+ 1 : _dfz_laplacian ,
94+ 2 : _dfz_cauchy ,
95+ 3 : _dfz_exp
96+ }.get (priorType , _dfz_gaussian )
97+ if integType == 1 :
98+ params = (j , j_td , tau_m , leak_gamma )
99+ _ , _z = step_rk2 (0. , z , _dfz , dt , params )
100+ elif integType == 2 :
101+ params = (j , j_td , tau_m , leak_gamma )
102+ _ , _z = step_rk4 (0. , z , _dfz , dt , params )
103+ else :
104+ params = (j , j_td , tau_m , leak_gamma )
105+ _ , _z = step_euler (0. , z , _dfz , dt , params )
106+ return _z
107+
108+ @jit
109+ def _run_cell_stateless (j ):
110+ """
111+ A simplification of running a stateless set of dynamics over j (an identity
112+ functional form of dynamics).
113+
114+ Args:
115+ j: stimulus to do nothing to
116+
117+ Returns:
118+ the stimulus
119+ """
120+ return j + 0
121+
122+ class RateCell (JaxComponent ): ## Rate-coded/real-valued cell
123+ """
124+ A non-spiking cell driven by the gradient dynamics of neural generative
125+ coding-driven predictive processing.
126+
127+ The specific differential equation that characterizes this cell
128+ is (for adjusting v, given current j, over time) is:
129+
130+ | tau_m * dz/dt = lambda * prior(z) + (j + j_td)
131+ | where j is the set of general incoming input signals (e.g., message-passed signals)
132+ | and j_td is taken to be the set of top-down pressure signals
133+
134+ | --- Cell Input Compartments: ---
135+ | j - input pressure (takes in external signals)
136+ | j_td - input/top-down pressure input (takes in external signals)
137+ | --- Cell State Compartments ---
138+ | z - rate activity
139+ | --- Cell Output Compartments: ---
140+ | zF - post-activation function activity, i.e., fx(z)
141+
142+ Args:
143+ name: the string name of this cell
144+
145+ n_units: number of cellular entities (neural population size)
146+
147+ tau_m: membrane/state time constant (milliseconds)
148+
149+ prior: a kernel for specifying the type of centered scale-shift distribution
150+ to impose over neuronal dynamics, applied to each neuron or
151+ dimension within this component (Default: ("gaussian", 0)); this is
152+ a tuple with 1st element containing a string name of the distribution
153+ one wants to use while the second value is a `leak rate` scalar
154+ that controls the influence/weighting that this distribution
155+ has on the dynamics; for example, ("laplacian, 0.001") means that a
156+ centered laplacian distribution scaled by `0.001` will be injected
157+ into this cell's dynamics ODE each step of simulated time
158+
159+ :Note: supported scale-shift distributions include "laplacian",
160+ "cauchy", "exp", and "gaussian"
161+
162+ act_fx: string name of activation function/nonlinearity to use
163+
164+ integration_type: type of integration to use for this cell's dynamics;
165+ current supported forms include "euler" (Euler/RK-1 integration)
166+ and "midpoint" or "rk2" (midpoint method/RK-2 integration) (Default: "euler")
167+
168+ :Note: setting the integration type to the midpoint method will
169+ increase the accuray of the estimate of the cell's evolution
170+ at an increase in computational cost (and simulation time)
171+
172+ resist_scale: a scaling factor applied to incoming pressure `j` (default: 1)
173+ """
174+
175+ # Define Functions
176+ def __init__ (self , name , n_units , tau_m , prior = ("gaussian" , 0. ), act_fx = "identity" ,
177+ threshold = ("none" , 0. ), integration_type = "euler" ,
178+ batch_size = 1 , resist_scale = 1. , shape = None , is_stateful = True , ** kwargs ):
179+ super ().__init__ (name , ** kwargs )
180+
181+ ## membrane parameter setup (affects ODE integration)
182+ self .tau_m = tau_m ## membrane time constant -- setting to 0 triggers "stateless" mode
183+ self .is_stateful = is_stateful
184+ if isinstance (tau_m , float ):
185+ if tau_m <= 0 : ## trigger stateless mode
186+ self .is_stateful = False
187+ priorType , leakRate = prior
188+ self .priorType = {
189+ "gaussian" : 0 ,
190+ "laplacian" : 1 ,
191+ "cauchy" : 2 ,
192+ "exp" : 3
193+ }.get (priorType , 0 ) ## type of scale-shift prior to impose over the leak
194+ self .priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
195+ thresholdType , thr_lmbda = threshold
196+ self .thresholdType = thresholdType ## type of thresholding function to use
197+ self .thr_lmbda = thr_lmbda ## scale to drive thresholding dynamics
198+ self .resist_scale = resist_scale ## a "resistance" scaling factor
199+
200+ ## integration properties
201+ self .integrationType = integration_type
202+ self .intgFlag = get_integrator_code (self .integrationType )
203+
204+ ## Layer size setup
205+ _shape = (batch_size , n_units ) ## default shape is 2D/matrix
206+ if shape is None :
207+ shape = (n_units ,) ## we set shape to be equal to n_units if nothing provided
208+ else :
209+ _shape = (batch_size , shape [0 ], shape [1 ], shape [2 ]) ## shape is 4D tensor
210+ self .shape = shape
211+ self .n_units = n_units
212+ self .batch_size = batch_size
213+
214+ omega_0 = None
215+ if act_fx == "sine" :
216+ omega_0 = kwargs ["omega_0" ]
217+ self .fx , self .dfx = create_function (fun_name = act_fx , args = omega_0 )
218+
219+ # compartments (state of the cell & parameters will be updated through stateless calls)
220+ restVals = jnp .zeros (_shape )
221+ self .j = Compartment (restVals , display_name = "Input Stimulus Current" , units = "mA" ) # electrical current
222+ self .zF = Compartment (restVals , display_name = "Transformed Rate Activity" ) # rate-coded output - activity
223+ self .j_td = Compartment (restVals , display_name = "Modulatory Stimulus Current" , units = "mA" ) # top-down electrical current - pressure
224+ self .z = Compartment (restVals , display_name = "Rate Activity" , units = "mA" ) # rate activity
225+
226+ @staticmethod
227+ def _advance_state (dt , fx , dfx , tau_m , priorLeakRate , intgFlag , priorType ,
228+ resist_scale , thresholdType , thr_lmbda , is_stateful , j , j_td , z ):
229+ #if tau_m > 0.:
230+ if is_stateful :
231+ ### run a step of integration over neuronal dynamics
232+ ## Notes:
233+ ## self.pressure <-- "top-down" expectation / contextual pressure
234+ ## self.current <-- "bottom-up" data-dependent signal
235+ dfx_val = dfx (z )
236+ j = _modulate (j , dfx_val )
237+ j = j * resist_scale
238+ tmp_z = _run_cell (dt , j , j_td , z ,
239+ tau_m , leak_gamma = priorLeakRate ,
240+ integType = intgFlag , priorType = priorType )
241+ ## apply optional thresholding sub-dynamics
242+ if thresholdType == "soft_threshold" :
243+ tmp_z = threshold_soft (tmp_z , thr_lmbda )
244+ elif thresholdType == "cauchy_threshold" :
245+ tmp_z = threshold_cauchy (tmp_z , thr_lmbda )
246+ z = tmp_z ## pre-activation function value(s)
247+ zF = fx (z ) ## post-activation function value(s)
248+ else :
249+ ## run in "stateless" mode (when no membrane time constant provided)
250+ j_total = j + j_td
251+ z = _run_cell_stateless (j_total )
252+ zF = fx (z )
253+ return j , j_td , z , zF
254+
255+ @resolver (_advance_state )
256+ def advance_state (self , j , j_td , z , zF ):
257+ self .j .set (j )
258+ self .j_td .set (j_td )
259+ self .z .set (z )
260+ self .zF .set (zF )
261+
262+ @staticmethod
263+ def _reset (batch_size , shape ): #n_units
264+ _shape = (batch_size , shape [0 ])
265+ if len (shape ) > 1 :
266+ _shape = (batch_size , shape [0 ], shape [1 ], shape [2 ])
267+ restVals = jnp .zeros (_shape )
268+ return tuple ([restVals for _ in range (4 )])
269+
270+ @resolver (_reset )
271+ def reset (self , j , zF , j_td , z ):
272+ self .j .set (j ) # electrical current
273+ self .zF .set (zF ) # rate-coded output - activity
274+ self .j_td .set (j_td ) # top-down electrical current - pressure
275+ self .z .set (z ) # rate activity
276+
277+ def save (self , directory , ** kwargs ):
278+ ## do a protected save of constants, depending on whether they are floats or arrays
279+ tau_m = (self .tau_m if isinstance (self .tau_m , float )
280+ else jnp .ones ([[self .tau_m ]]))
281+ priorLeakRate = (self .priorLeakRate if isinstance (self .priorLeakRate , float )
282+ else jnp .ones ([[self .priorLeakRate ]]))
283+ resist_scale = (self .resist_scale if isinstance (self .resist_scale , float )
284+ else jnp .ones ([[self .resist_scale ]]))
285+
286+ file_name = directory + "/" + self .name + ".npz"
287+ jnp .savez (file_name ,
288+ tau_m = tau_m , priorLeakRate = priorLeakRate ,
289+ resist_scale = resist_scale ) #, key=self.key.value)
290+
291+ def load (self , directory , seeded = False , ** kwargs ):
292+ file_name = directory + "/" + self .name + ".npz"
293+ data = jnp .load (file_name )
294+ ## constants loaded in
295+ self .tau_m = data ['tau_m' ]
296+ self .priorLeakRate = data ['priorLeakRate' ]
297+ self .resist_scale = data ['resist_scale' ]
298+ #if seeded:
299+ # self.key.set(data['key'])
300+
301+ @classmethod
302+ def help (cls ): ## component help function
303+ properties = {
304+ "cell_type" : "RateCell - evolves neurons according to rate-coded/"
305+ "continuous dynamics "
306+ }
307+ compartment_props = {
308+ "inputs" :
309+ {"j" : "External input stimulus value(s)" ,
310+ "j_td" : "External top-down input stimulus value(s); these get "
311+ "multiplied by the derivative of f(x), i.e., df(x)" },
312+ "states" :
313+ {"z" : "Update to rate-coded continuous dynamics; value at time t" },
314+ "outputs" :
315+ {"zF" : "Nonlinearity/function applied to rate-coded dynamics; f(z)" },
316+ }
317+ hyperparams = {
318+ "n_units" : "Number of neuronal cells to model in this layer" ,
319+ "batch_size" : "Batch size dimension of this component" ,
320+ "tau_m" : "Cell state/membrane time constant" ,
321+ "prior" : "What kind of kurtotic prior to place over neuronal dynamics?" ,
322+ "act_fx" : "Elementwise activation function to apply over cell state `z`" ,
323+ "threshold" : "What kind of iterative thresholding function to place over neuronal dynamics?" ,
324+ "integration_type" : "Type of numerical integration to use for the cell dynamics" ,
325+ }
326+ info = {cls .__name__ : properties ,
327+ "compartments" : compartment_props ,
328+ "dynamics" : "tau_m * dz/dt = Prior(z; gamma) + (j + j_td)" ,
329+ "hyperparameters" : hyperparams }
330+ return info
331+
332+ def __repr__ (self ):
333+ comps = [varname for varname in dir (self ) if Compartment .is_compartment (getattr (self , varname ))]
334+ maxlen = max (len (c ) for c in comps ) + 5
335+ lines = f"[{ self .__class__ .__name__ } ] PATH: { self .name } \n "
336+ for c in comps :
337+ stats = tensorstats (getattr (self , c ).value )
338+ if stats is not None :
339+ line = [f"{ k } : { v } " for k , v in stats .items ()]
340+ line = ", " .join (line )
341+ else :
342+ line = "None"
343+ lines += f" { f'({ c } )' .ljust (maxlen )} { line } \n "
344+ return lines
345+
346+ if __name__ == '__main__' :
347+ from ngcsimlib .context import Context
348+ with Context ("Bar" ) as bar :
349+ X = RateCell ("X" , 9 , 0.03 )
350+ print (X )
0 commit comments