1212from ngclearn .utils .diffeq .ode_utils import get_integrator_code , \
1313 step_euler , step_rk2 , step_rk4
1414
15- ## rewritten code
16- # @partial(jit, static_argnums=[5])
17- def _dfz_internal (z , j , j_td , tau_m , leak_gamma , prior_type = None ): ## raw dynamics
18- z_leak = z # * 2 ## Default: assume Gaussian
19- prior_type_dict = {
20- 0 : "laplacian" ,
21- 1 : "cauchy" ,
22- 2 : "exp"
23- }
24- prior_type = prior_type_dict .get (prior_type , None )
25- if prior_type != None :
26- if prior_type == "laplacian" : ## Laplace dist
27- z_leak = jnp .sign (z ) ## d/dx of Laplace is signum
28- elif prior_type == "cauchy" : ## Cauchy dist: x ~ (1.0 + tf.math.square(z))
29- z_leak = (z * 2 )/ (1. + jnp .square (z ))
30- elif prior_type == "exp" : ## Exp dist: x ~ -exp(-x^2)
31- z_leak = jnp .exp (- jnp .square (z )) * z * 2
15+ def _dfz_internal_laplace (z , j , j_td , tau_m , leak_gamma ): ## raw dynamics
16+ z_leak = jnp .sign (z ) ## d/dx of Laplace is signum
17+ dz_dt = (- z_leak * leak_gamma + (j + j_td )) * (1. / tau_m )
18+ return dz_dt
19+
20+ def _dfz_internal_cauchy (z , j , j_td , tau_m , leak_gamma ): ## raw dynamics
21+ z_leak = (z * 2 )/ (1. + jnp .square (z ))
22+ dz_dt = (- z_leak * leak_gamma + (j + j_td )) * (1. / tau_m )
23+ return dz_dt
24+
25+ def _dfz_internal_exp (z , j , j_td , tau_m , leak_gamma ): ## raw dynamics
26+ z_leak = jnp .exp (- jnp .square (z )) * z * 2
3227 dz_dt = (- z_leak * leak_gamma + (j + j_td )) * (1. / tau_m )
3328 return dz_dt
3429
35- def _dfz ( t , z , params ): ## diff-eq dynamics wrapper
36- j , j_td , tau_m , leak_gamma , priorType = params
37- dz_dt = _dfz_internal ( z , j , j_td , tau_m , leak_gamma , priorType )
30+ def _dfz_internal_gaussian ( z , j , j_td , tau_m , leak_gamma ): ## raw dynamics
31+ z_leak = z # * 2 ## Default: assume Gaussian
32+ dz_dt = ( - z_leak * leak_gamma + ( j + j_td )) * ( 1. / tau_m )
3833 return dz_dt
3934
4035# @jit
@@ -52,8 +47,8 @@ def _modulate(j, dfx_val):
5247 """
5348 return j * dfx_val
5449
55- @partial (jit , static_argnames = ["integType" , "priorType" ])
56- def _run_cell (dt , j , j_td , z , tau_m , leak_gamma = 0. , integType = 0 , priorType = None ):
50+ # @partial(jit, static_argnames=["integType", "priorType"])
51+ def _run_cell (dt , j , j_td , z , tau_m , leak_gamma = 0. , integType = 0 , priorType = 0 ):
5752 """
5853 Runs leaky rate-coded state dynamics one step in time.
5954
@@ -77,15 +72,21 @@ def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=None)
7772 Returns:
7873 New value of membrane/state for next time step
7974 """
80- if integType == 1 :
81- params = (j , j_td , tau_m , leak_gamma , priorType )
82- _ , _z = step_rk2 (0. , z , _dfz , dt , params )
83- elif integType == 2 :
84- params = (j , j_td , tau_m , leak_gamma , priorType )
85- _ , _z = step_rk4 (0. , z , _dfz , dt , params )
86- else :
87- params = (j , j_td , tau_m , leak_gamma , priorType )
88- _ , _z = step_euler (0. , z , _dfz , dt , params )
75+ _dfz_fns = {
76+ 0 : lambda t , z , params : _dfz_internal_gaussian (z , * params ),
77+ 1 : lambda t , z , params : _dfz_internal_laplace (z , * params ),
78+ 2 : lambda t , z , params : _dfz_internal_cauchy (z , * params ),
79+ 3 : lambda t , z , params : _dfz_internal_exp (z , * params ),
80+ }
81+ _dfz_fn = _dfz_fns .get (priorType , _dfz_internal_gaussian )
82+ _step_fns = {
83+ 0 : step_euler ,
84+ 1 : step_rk2 ,
85+ 2 : step_rk4 ,
86+ }
87+ _step_fn = _step_fns .get (integType , step_euler )
88+ params = (j , j_td , tau_m , leak_gamma )
89+ _ , _z = _step_fn (0. , z , _dfz_fn , dt , params )
8990 return _z
9091
9192# @jit
@@ -169,11 +170,12 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
169170 self .is_stateful = False
170171 priorType , leakRate = prior
171172 priorTypeDict = {
172- "laplacian" : 0 ,
173- "cauchy" : 1 ,
174- "exp" : 2
173+ "gaussian" : 0 ,
174+ "laplacian" : 1 ,
175+ "cauchy" : 2 ,
176+ "exp" : 3
175177 }
176- self .priorType = priorTypeDict .get (priorType , - 1 )
178+ self .priorType = priorTypeDict .get (priorType , 0 )
177179 self .priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
178180 thresholdType , thr_lmbda = threshold
179181 self .thresholdType = thresholdType ## type of thresholding function to use
0 commit comments