Skip to content

Commit af283dc

Browse files
author
Alexander Ororbia
committed
Merge branch 'major_release_update' of github.com:NACLab/ngc-learn into major_release_update
2 parents b3c47a2 + 4c22428 commit af283dc

File tree

7 files changed

+419
-25
lines changed

7 files changed

+419
-25
lines changed

ngclearn/components/neurons/graded/gaussianErrorCell.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from ngclearn.components.jaxComponent import JaxComponent
33
from jax import numpy as jnp, jit
44
from ngclearn.utils import tensorstats
5+
from ngcsimlib.compilers.process import transition
56

67
class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
78
"""
@@ -64,8 +65,9 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
6465
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
6566
self.mask = Compartment(restVals + 1.0)
6667

68+
@transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"])
6769
@staticmethod
68-
def _advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
70+
def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
6971
# Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
7072
# behavior of the local cost functional:
7173
# FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
@@ -83,16 +85,9 @@ def _advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian
8385
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
8486
return dmu, dtarget, dSigma, jnp.squeeze(L), mask
8587

86-
@resolver(_advance_state)
87-
def advance_state(self, dmu, dtarget, dSigma, L, mask):
88-
self.dmu.set(dmu)
89-
self.dtarget.set(dtarget)
90-
self.dSigma.set(dSigma)
91-
self.L.set(L)
92-
self.mask.set(mask)
93-
88+
@transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
9489
@staticmethod
95-
def _reset(batch_size, shape, sigma_shape): ## reset core components/statistics
90+
def reset(batch_size, shape, sigma_shape): ## reset core components/statistics
9691
_shape = (batch_size, shape[0])
9792
if len(shape) > 1:
9893
_shape = (batch_size, shape[0], shape[1], shape[2])
@@ -107,17 +102,6 @@ def _reset(batch_size, shape, sigma_shape): ## reset core components/statistics
107102
mask = jnp.ones(_shape)
108103
return dmu, dtarget, dSigma, target, mu, modulator, L, mask
109104

110-
@resolver(_reset)
111-
def reset(self, dmu, dtarget, dSigma, target, mu, modulator, L, mask):
112-
self.dmu.set(dmu)
113-
self.dtarget.set(dtarget)
114-
self.dSigma.set(dSigma)
115-
self.target.set(target)
116-
self.mu.set(mu)
117-
self.modulator.set(modulator)
118-
self.L.set(L)
119-
self.mask.set(mask)
120-
121105
@classmethod
122106
def help(cls): ## component help function
123107
properties = {

ngclearn/components/neurons/graded/rateCell.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,15 @@
1313
step_euler, step_rk2, step_rk4
1414

1515
## rewritten code
16-
@partial(jit, static_argnums=[5])
16+
# @partial(jit, static_argnums=[5])
1717
def _dfz_internal(z, j, j_td, tau_m, leak_gamma, prior_type=None): ## raw dynamics
1818
z_leak = z # * 2 ## Default: assume Gaussian
19+
prior_type_dict = {
20+
0: "laplacian",
21+
1: "cauchy",
22+
2: "exp"
23+
}
24+
prior_type = prior_type_dict.get(prior_type, None)
1925
if prior_type != None:
2026
if prior_type == "laplacian": ## Laplace dist
2127
z_leak = jnp.sign(z) ## d/dx of Laplace is signum
@@ -31,7 +37,7 @@ def _dfz(t, z, params): ## diff-eq dynamics wrapper
3137
dz_dt = _dfz_internal(z, j, j_td, tau_m, leak_gamma, priorType)
3238
return dz_dt
3339

34-
@jit
40+
# @jit
3541
def _modulate(j, dfx_val):
3642
"""
3743
Apply a signal modulator to j (typically of the form of a derivative/dampening function)
@@ -46,6 +52,7 @@ def _modulate(j, dfx_val):
4652
"""
4753
return j * dfx_val
4854

55+
@partial(jit, static_argnames=["integType", "priorType"])
4956
def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=None):
5057
"""
5158
Runs leaky rate-coded state dynamics one step in time.
@@ -81,7 +88,7 @@ def _run_cell(dt, j, j_td, z, tau_m, leak_gamma=0., integType=0, priorType=None)
8188
_, _z = step_euler(0., z, _dfz, dt, params)
8289
return _z
8390

84-
@jit
91+
# @jit
8592
def _run_cell_stateless(j):
8693
"""
8794
A simplification of running a stateless set of dynamics over j (an identity
@@ -161,7 +168,12 @@ def __init__(self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identit
161168
if tau_m <= 0: ## trigger stateless mode
162169
self.is_stateful = False
163170
priorType, leakRate = prior
164-
self.priorType = priorType ## type of scale-shift prior to impose over the leak
171+
priorTypeDict = {
172+
"laplacian": 0,
173+
"cauchy": 1,
174+
"exp": 2
175+
}
176+
self.priorType = priorTypeDict.get(priorType, -1)
165177
self.priorLeakRate = leakRate ## degree to which rate neurons leak (according to prior)
166178
thresholdType, thr_lmbda = threshold
167179
self.thresholdType = thresholdType ## type of thresholding function to use
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import RateCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_RateCell1():
19+
name = "rate_ctx"
20+
dkey = random.PRNGKey(42)
21+
dkey, *subkeys = random.split(dkey, 100)
22+
dt = 1. # ms
23+
with Context(name) as ctx:
24+
a = RateCell(
25+
name="a", n_units=1, tau_m=50., prior=("gaussian", 0.), act_fx="identity",
26+
threshold=("none", 0.), integration_type="euler",
27+
batch_size=1, resist_scale=1., shape=None, is_stateful=True
28+
)
29+
advance_process = (Process() >> a.advance_state)
30+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
31+
reset_process = (Process() >> a.reset)
32+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
33+
34+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
35+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
36+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
37+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
38+
39+
@Context.dynamicCommand
40+
def clamp(x):
41+
a.j.set(x)
42+
43+
## input spike train
44+
x_seq = jnp.ones((1, 10))
45+
## desired output/epsp pulses
46+
y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32)
47+
48+
outs = []
49+
ctx.reset()
50+
for ts in range(x_seq.shape[1]):
51+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
52+
ctx.clamp(x_t)
53+
ctx.run(t=ts * 1., dt=dt)
54+
outs.append(a.z.value)
55+
outs = jnp.concatenate(outs, axis=1)
56+
print(outs)
57+
## output should equal input
58+
# assert_array_equal(outs, y_seq, tol=1e-3)
59+
np.testing.assert_allclose(outs, y_seq, atol=1e-3)
60+
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import BernoulliErrorCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_bernoulliErrorCell():
19+
np.random.seed(42)
20+
name = "bernoulli_error_ctx"
21+
dkey = random.PRNGKey(42)
22+
dkey, *subkeys = random.split(dkey, 100)
23+
dt = 1. # ms
24+
with Context(name) as ctx:
25+
a = BernoulliErrorCell(
26+
name="a", n_units=1, batch_size=1, input_logits=False, shape=None
27+
)
28+
advance_process = (Process() >> a.advance_state)
29+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
30+
reset_process = (Process() >> a.reset)
31+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
32+
33+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
34+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
35+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
36+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
37+
38+
@Context.dynamicCommand
39+
def clamp(x):
40+
a.p.set(x)
41+
42+
@Context.dynamicCommand
43+
def clamp_target(x):
44+
a.target.set(x)
45+
46+
## input spike train
47+
x_seq = jnp.asarray(np.random.randn(1, 10))
48+
target_seq = (jnp.arange(10)[None] - 5.0) / 2.0
49+
## desired output/epsp pulses
50+
y_seq = jnp.asarray([[-2.8193381, -4976.9263, -2.1224928, -2939.0425, -1233.3916, -0.24662945, -708.30042, 0.28213939, 3550.8477, 1.3651246]], dtype=jnp.float32)
51+
52+
outs = []
53+
ctx.reset()
54+
for ts in range(x_seq.shape[1]):
55+
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
56+
ctx.clamp(x_t)
57+
target_xt = jnp.array([[target_seq[0, ts]]])
58+
ctx.clamp_target(target_xt)
59+
ctx.run(t=ts * 1., dt=dt)
60+
outs.append(a.dp.value)
61+
outs = jnp.concatenate(outs, axis=1)
62+
# print(outs)
63+
## output should equal input
64+
np.testing.assert_allclose(outs, y_seq, atol=1e-7)
65+
66+
# test_bernoulliErrorCell()
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# %%
2+
3+
from jax import numpy as jnp, random, jit
4+
from ngcsimlib.context import Context
5+
import numpy as np
6+
np.random.seed(42)
7+
from ngclearn.components import GaussianErrorCell
8+
from ngcsimlib.compilers import compile_command, wrap_command
9+
from numpy.testing import assert_array_equal
10+
11+
from ngcsimlib.compilers.process import Process, transition
12+
from ngcsimlib.component import Component
13+
from ngcsimlib.compartment import Compartment
14+
from ngcsimlib.context import Context
15+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
16+
17+
18+
def test_gaussianErrorCell():
19+
np.random.seed(42)
20+
name = "gaussian_error_ctx"
21+
dkey = random.PRNGKey(42)
22+
dkey, *subkeys = random.split(dkey, 100)
23+
dt = 1. # ms
24+
with Context(name) as ctx:
25+
a = GaussianErrorCell(
26+
name="a", n_units=1, batch_size=1, sigma=1.0, shape=None
27+
)
28+
advance_process = (Process() >> a.advance_state)
29+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
30+
reset_process = (Process() >> a.reset)
31+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
32+
33+
# reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
34+
# ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
35+
# advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
36+
# ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
37+
38+
@Context.dynamicCommand
39+
def clamp_mu(x):
40+
a.mu.set(x)
41+
42+
@Context.dynamicCommand
43+
def clamp_target(x):
44+
a.target.set(x)
45+
46+
## input sequence
47+
mu_seq = jnp.asarray(np.random.randn(1, 10))
48+
target_seq = (jnp.arange(10)[None] - 5.0) / 2.0
49+
## expected output based on the Gaussian error cell formula
50+
## L = -0.5 * (target - mu)^2 / sigma, dmu = (target - mu) / sigma
51+
expected_dmu = (target_seq - mu_seq) / 1.0 # sigma = 1.0
52+
expected_L = -0.5 * jnp.square(target_seq - mu_seq) / 1.0
53+
54+
dmu_outs = []
55+
L_outs = []
56+
ctx.reset()
57+
for ts in range(mu_seq.shape[1]):
58+
mu_t = jnp.array([[mu_seq[0, ts]]]) ## get data at time t
59+
ctx.clamp_mu(mu_t)
60+
target_t = jnp.array([[target_seq[0, ts]]])
61+
ctx.clamp_target(target_t)
62+
ctx.run(t=ts * 1., dt=dt)
63+
dmu_outs.append(a.dmu.value)
64+
L_outs.append(a.L.value)
65+
66+
dmu_outs = jnp.concatenate(dmu_outs, axis=1)
67+
L_outs = jnp.array(L_outs)[None] # (1, 10)
68+
# print(dmu_outs.shape)
69+
# print(L_outs.shape)
70+
# print(expected_dmu.shape)
71+
# print(expected_L.shape)
72+
73+
## verify outputs match expected values
74+
np.testing.assert_allclose(dmu_outs, expected_dmu, atol=1e-5)
75+
np.testing.assert_allclose(L_outs, expected_L, atol=1e-5)
76+
77+
# test_gaussianErrorCell()

0 commit comments

Comments
 (0)