Skip to content

Commit 43fbd9b

Browse files
committed
update rate cell and fix bug of passing a tuple of (jax Array -- not hashable) to jax jit functions. Basically, simplify the codebase by using a hashmap of functions
1 parent bfc200c commit 43fbd9b

File tree

1 file changed

+37
-35
lines changed

1 file changed

+37
-35
lines changed

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,24 @@
1212
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
1313
step_euler, step_rk2, step_rk4
1414

15-
## rewritten code
16-
# @partial(jit, static_argnums=[5])
17-
def _dfz_internal(z, j, j_td, tau_m, leak_gamma, prior_type=None): ## raw dynamics
18-
z_leak = z # * 2 ## Default: assume Gaussian
19-
prior_type_dict = {
20-
0: "laplacian",
21-
1: "cauchy",
22-
2: "exp"
23-
}
24-
prior_type = prior_type_dict.get(prior_type, None)
25-
if prior_type != None:
26-
if prior_type == "laplacian": ## Laplace dist
27-
z_leak = jnp.sign(z) ## d/dx of Laplace is signum
28-
elif prior_type == "cauchy": ## Cauchy dist: x ~ (1.0 + tf.math.square(z))
29-
z_leak = (z * 2)/(1. + jnp.square(z))
30-
elif prior_type == "exp": ## Exp dist: x ~ -exp(-x^2)
31-
z_leak = jnp.exp(-jnp.square(z)) * z * 2
15+
def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
16+
z_leak = jnp.sign(z) ## d/dx of Laplace is signum
17+
dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
18+
return dz_dt
19+
20+
def _dfz_internal_cauchy(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
21+
z_leak = (z * 2)/(1. + jnp.square(z))
22+
dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
23+
return dz_dt
24+
25+
def _dfz_internal_exp(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
26+
z_leak = jnp.exp(-jnp.square(z)) * z * 2
3227
dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
3328
return dz_dt
3429

35-
def _dfz(t, z, params): ## diff-eq dynamics wrapper
36-
j, j_td, tau_m, leak_gamma, priorType = params
37-
dz_dt = _dfz_internal(z, j, j_td, tau_m, leak_gamma, priorType)
30+
def _dfz_internal_gaussian(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
31+
z_leak = z # * 2 ## Default: assume Gaussian
32+
dz_dt = (-z_leak * leak_gamma + (j + j_td)) * (1./tau_m)
3833
return dz_dt
3934

4035
# @jit
@@ -52,8 +47,8 @@ def _modulate(j, dfx_val):
5247
"""
5348
return j * dfx_val
5449

55-
@partial(jit, static_argnames=["integType", "priorType"])
56-
def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=None):
50+
# @partial(jit, static_argnames=["integType", "priorType"])
51+
def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=0):
5752
"""
5853
Runs leaky rate-coded state dynamics one step in time.
5954
@@ -77,15 +72,21 @@ def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=None)
7772
Returns:
7873
New value of membrane/state for next time step
7974
"""
80-
if integType == 1:
81-
params = (j, j_td, tau_m, leak_gamma, priorType)
82-
_, _z = step_rk2(0., z, _dfz, dt, params)
83-
elif integType == 2:
84-
params = (j, j_td, tau_m, leak_gamma, priorType)
85-
_, _z = step_rk4(0., z, _dfz, dt, params)
86-
else:
87-
params = (j, j_td, tau_m, leak_gamma, priorType)
88-
_, _z = step_euler(0., z, _dfz, dt, params)
75+
_dfz_fns = {
76+
0: lambda t, z, params: _dfz_internal_gaussian(z, *params),
77+
1: lambda t, z, params: _dfz_internal_laplace(z, *params),
78+
2: lambda t, z, params: _dfz_internal_cauchy(z, *params),
79+
3: lambda t, z, params: _dfz_internal_exp(z, *params),
80+
}
81+
_dfz_fn = _dfz_fns.get(priorType, _dfz_internal_gaussian)
82+
_step_fns = {
83+
0: step_euler,
84+
1: step_rk2,
85+
2: step_rk4,
86+
}
87+
_step_fn = _step_fns.get(integType, step_euler)
88+
params = (j, j_td, tau_m, leak_gamma)
89+
_, _z = _step_fn(0., z, _dfz_fn, dt, params)
8990
return _z
9091

9192
# @jit
@@ -169,11 +170,12 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
169170
self.is_stateful = False
170171
priorType, leakRate = prior
171172
priorTypeDict = {
172-
"laplacian": 0,
173-
"cauchy": 1,
174-
"exp": 2
173+
"gaussian": 0,
174+
"laplacian": 1,
175+
"cauchy": 2,
176+
"exp": 3
175177
}
176-
self.priorType = priorTypeDict.get(priorType, -1)
178+
self.priorType = priorTypeDict.get(priorType, 0)
177179
self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
178180
thresholdType, thr_lmbda = threshold
179181
self.thresholdType = thresholdType ## type of thresholding function to use

0 commit comments

Comments
 (0)