Skip to content

Commit 8a5958d

Browse files
committed
update rate cell
1 parent 098f3db commit 8a5958d

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313
step_euler, step_rk2, step_rk4
1414

1515
## rewritten code
16-
@partial(jit, static_argnums=[5])
16+
# @partial(jit, static_argnums=[5])
1717
def _dfz_internal(z, j, j_td, tau_m, leak_gamma, prior_type=None): ## raw dynamics
1818
z_leak = z # * 2 ## Default: assume Gaussian
19+
prior_type_dict = {
20+
0: "laplacian",
21+
1: "cauchy",
22+
2: "exp"
23+
}
24+
prior_type = prior_type_dict.get(prior_type, None)
1925
if prior_type != None:
2026
if prior_type == "laplacian": ## Laplace dist
2127
z_leak = jnp.sign(z) ## d/dx of Laplace is signum
@@ -31,7 +37,7 @@ def _dfz(t, z, params): ## diff-eq dynamics wrapper
3137
dz_dt = _dfz_internal(z, j, j_td, tau_m, leak_gamma, priorType)
3238
return dz_dt
3339

34-
@jit
40+
# @jit
3541
def _modulate(j, dfx_val):
3642
"""
3743
Apply a signal modulator to j (typically of the form of a derivative/dampening function)
@@ -46,6 +52,7 @@ def _modulate(j, dfx_val):
4652
"""
4753
return j * dfx_val
4854

55+
@partial(jit, static_argnames=["integType", "priorType"])
4956
def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=None):
5057
"""
5158
Runs leaky rate-coded state dynamics one step in time.
@@ -81,7 +88,7 @@ def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=None)
8188
_, _z = step_euler(0., z, _dfz, dt, params)
8289
return _z
8390

84-
@jit
91+
# @jit
8592
def _run_cell_stateless(j):
8693
"""
8794
A simplification of running a stateless set of dynamics over j (an identity
@@ -161,7 +168,12 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
161168
if tau_m <= 0: ## trigger stateless mode
162169
self.is_stateful = False
163170
priorType, leakRate = prior
164-
self.priorType = priorType ## type of scale-shift prior to impose over the leak
171+
priorTypeDict = {
172+
"laplacian": 0,
173+
"cauchy": 1,
174+
"exp": 2
175+
}
176+
self.priorType = priorTypeDict.get(priorType, -1)
165177
self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
166178
thresholdType, thr_lmbda = threshold
167179
self.thresholdType = thresholdType ## type of thresholding function to use
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import RateCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_RateCell1():
19+
name = "rate_ctx"
20+
dkey = random.PRNGKey(42)
21+
dkey, *subkeys = random.split(dkey, 100)
22+
dt = 1. # ms
23+
with Context(name) as ctx:
24+
a = RateCell(
25+
name="a", n_units=1, tau_m=50., prior=("gaussian", 0.), act_fx="identity",
26+
threshold=("none", 0.), integration_type="euler",
27+
batch_size=1, resist_scale=1., shape=None, is_stateful=True
28+
)
29+
advance_process = (Process() >> a.advance_state)
30+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
31+
reset_process = (Process() >> a.reset)
32+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
33+
34+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
35+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
36+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
37+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
38+
39+
@Context.dynamicCommand
40+
def clamp(x):
41+
a.j.set(x)
42+
43+
## input spike train
44+
x_seq = jnp.ones((1, 10))
45+
## desired output/epsp pulses
46+
y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32)
47+
48+
outs = []
49+
ctx.reset()
50+
for ts in range(x_seq.shape[1]):
51+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
52+
ctx.clamp(x_t)
53+
ctx.run(t=ts * 1., dt=dt)
54+
outs.append(a.z.value)
55+
outs = jnp.concatenate(outs, axis=1)
56+
57+
## output should equal input
58+
# assert_array_equal(outs, y_seq, tol=1e-3)
59+
np.testing.assert_allclose(outs, y_seq, atol=1e-3)
60+

0 commit comments

Comments
 (0)