Skip to content

Commit 74840d9

Browse files
author
Alexander Ororbia
committed
latency-cell refactored and unit-tested
1 parent 479d94a commit 74840d9

File tree

2 files changed

+99
-51
lines changed

2 files changed

+99
-51
lines changed

ngclearn/components/input_encoders/latencyCell.py

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,16 @@
1-
from ngclearn import resolver, Component, Compartment
21
from ngclearn.components.jaxComponent import JaxComponent
3-
from ngclearn.utils import tensorstats
4-
from ngclearn.utils.model_utils import clamp_min, clamp_max
52
from jax import numpy as jnp, random, jit
63
from functools import partial
7-
from ngcsimlib.logger import info
8-
9-
@jit
10-
def _update_times(t, s, tols):
11-
"""
12-
Updates time-of-last-spike (tols) variable.
13-
14-
Args:
15-
t: current time (a scalar/int value)
4+
from ngclearn.utils import tensorstats
5+
from ngcsimlib.deprecators import deprecate_args
6+
from ngcsimlib.logger import info, warn
167

17-
s: binary spike vector
8+
from ngcsimlib.compilers.process import transition
9+
#from ngcsimlib.component import Component
10+
from ngcsimlib.compartment import Compartment
1811

19-
tols: current time-of-last-spike variable
12+
from ngclearn.utils.model_utils import clamp_min, clamp_max
2013

21-
Returns:
22-
updated tols variable
23-
"""
24-
_tols = (1. - s) * tols + (s * t)
25-
return _tols
2614

2715
@partial(jit, static_argnums=[5])
2816
def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1.,
@@ -157,9 +145,10 @@ class LatencyCell(JaxComponent):
157145
"""
158146

159147
# Define Functions
160-
def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
161-
linearize=False, normalize=False, clip_spikes=False, num_steps=1.,
162-
batch_size=1, **kwargs):
148+
def __init__(
149+
self, name, n_units, tau=1., threshold=0.01, first_spike_time=0., linearize=False, normalize=False,
150+
clip_spikes=False, num_steps=1., batch_size=1, **kwargs
151+
):
163152
super().__init__(name, **kwargs)
164153

165154
## latency meta-parameters
@@ -186,9 +175,11 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
186175
self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")
187176
#self.reset()
188177

178+
@transition(output_compartments=["targ_sp_times", "clip_mask"])
189179
@staticmethod
190-
def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
191-
normalize, clip_spikes, inputs):
180+
def calc_spike_times(
181+
linearize, tau, threshold, first_spike_time, num_steps, normalize, clip_spikes, inputs
182+
):
192183
## would call this function before processing a spike train (at start)
193184
data = inputs
194185
if clip_spikes:
@@ -208,42 +199,27 @@ def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
208199
targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
209200
return targ_sp_times, clip_mask
210201

211-
@resolver(_calc_spike_times)
212-
def calc_spike_times(self, targ_sp_times, clip_mask):
213-
self.targ_sp_times.set(targ_sp_times)
214-
self.clip_mask.set(clip_mask)
215-
202+
@transition(output_compartments=["outputs", "tols", "mask", "targ_sp_times", "key"])
216203
@staticmethod
217-
def _advance_state(t, dt, key, inputs, mask, clip_mask, targ_sp_times, tols):
204+
def advance_state(t, dt, key, inputs, mask, clip_mask, targ_sp_times, tols):
218205
key, *subkeys = random.split(key, 2)
219-
data = inputs ## get sensory pattern data / features
220-
spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t
221-
tols = _update_times(t, spikes, tols)
206+
data = inputs ## get sensory pattern data / features
207+
spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t
208+
209+
# Updates time-of-last-spike (tols) variable:
210+
# output = s = binary spike vector
211+
# tols = current time-of-last-spike variable
212+
tols = (1. - spikes) * tols + (spikes * t)
213+
222214
spikes = spikes * (1. - clip_mask)
223215
return spikes, tols, spk_mask, targ_sp_times, key
224216

225-
@resolver(_advance_state)
226-
def advance_state(self, outputs, tols, mask, targ_sp_times, key):
227-
self.outputs.set(outputs)
228-
self.tols.set(tols)
229-
self.mask.set(mask)
230-
self.targ_sp_times.set(targ_sp_times)
231-
self.key.set(key)
232-
217+
@transition(output_compartments=["inputs", "outputs", "tols", "mask", "clip_mask", "targ_sp_times"])
233218
@staticmethod
234-
def _reset(batch_size, n_units):
219+
def reset(batch_size, n_units):
235220
restVals = jnp.zeros((batch_size, n_units))
236221
return (restVals, restVals, restVals, restVals, restVals, restVals)
237222

238-
@resolver(_reset)
239-
def reset(self, inputs, outputs, tols, mask, clip_mask, targ_sp_times):
240-
self.inputs.set(inputs)
241-
self.outputs.set(outputs)
242-
self.tols.set(tols)
243-
self.mask.set(mask)
244-
self.clip_mask.set(clip_mask)
245-
self.targ_sp_times.set(targ_sp_times)
246-
247223
def save(self, directory, **kwargs):
248224
file_name = directory + "/" + self.name + ".npz"
249225
jnp.savez(file_name, key=self.key.value)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import LatencyCell
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
from ngcsimlib.utils.compartment import Get_Compartment_Batch
14+
15+
16+
def test_latencyCell():
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
T = 50 # 100 #5 ## number of simulation steps to run
21+
dt = 1. # 0.1 # ms ## compute integration time constant
22+
tau = 1.
23+
# ---- build a simple Poisson cell system ----
24+
with Context("Circuit") as ctx:
25+
a = LatencyCell(
26+
"a", n_units=4, tau=tau, threshold=0.01, linearize=True,
27+
normalize=True, num_steps=T, clip_spikes=False
28+
)
29+
30+
## create and compile core simulation commands
31+
advance_process = (Process()
32+
>> a.advance_state)
33+
ctx.wrap_and_add_command(advance_process.pure, name="advance")
34+
calc_spike_times_process = (Process()
35+
>> a.calc_spike_times)
36+
ctx.wrap_and_add_command(calc_spike_times_process.pure, name="calc_spike_times")
37+
reset_process = (Process()
38+
>> a.reset)
39+
ctx.wrap_and_add_command(reset_process.pure, name="reset")
40+
41+
## set up non-compiled utility commands
42+
@Context.dynamicCommand
43+
def clamp(x):
44+
a.inputs.set(x)
45+
46+
## input spike train
47+
inputs = jnp.asarray([[0.02, 0.5, 1., 0.0]])
48+
49+
targets = np.zeros((T, 4))
50+
targets[0, 2] = 1.
51+
targets[24, 1] = 1.
52+
targets[48, 0] = 1.
53+
targets[49, 3] = 1.
54+
targets = jnp.array(targets) ## gold-standard solution to check against
55+
56+
outs = []
57+
ctx.reset()
58+
ctx.clamp(inputs)
59+
ctx.calc_spike_times()
60+
for ts in range(T):
61+
ctx.clamp(inputs)
62+
ctx.advance(t=ts * dt, dt=dt)
63+
## naively extract simple statistics at time ts and print them to I/O
64+
s = a.outputs.value
65+
outs.append(s)
66+
#print(" {}: s {} ".format(ts, jnp.squeeze(s)))
67+
outs = jnp.concatenate(outs, axis=0)
68+
69+
## output should equal input
70+
assert_array_equal(outs, targets)
71+
72+
test_latencyCell()

0 commit comments

Comments
 (0)