Skip to content

Commit ffd8f0e

Browse files
ago109willgebhardt
andauthored
Merging over Dynamics feature branch to main (#92)
* modded bernoulli-cell to include max-frequency constraint * added warning check to bernoulli, some cleanup * integrated if-cell, cleaned up lif and inits * mod to latency-cell * updated the poissonCell to be a true poisson * fixed minor bug in deprecation for poiss/bern * fixed minor bug in deprecation for poiss/bern * fixed validation fun in bern/poiss * moved back and cleaned up bernoulli and poisson cells * added threshold-clipping to latency cell * updates to if/lif * added batch-size arg to slif * fixed minor load bug in lif-cell * fixed a blocking jit-partial call in lif update_theta method; when loading * minor edit to dim-reduce * updated monitor plot code * update to dim-reduce * integrated phasor-cell, minor cleanup of latency * tweak to adex thr arg * tweak to adex thr arg * integrated resonate-and-fire neuronal cell * mod to raf-cell * cleaned up raf * cleaned up raf * cleaned up raf-cell * cleaned up raf-cell * cleaned up raf-cell * minor tweak to dim-reduce in utils * Additions for inhibition stuff * update to API modeling docs to reflect RAF neuronal cell --------- Co-authored-by: Alexander Ororbia <[email protected]> Co-authored-by: Will Gebhardt <[email protected]>
1 parent da2f24e commit ffd8f0e

File tree

20 files changed

+944
-129
lines changed

20 files changed

+944
-129
lines changed

docs/modeling/neurons.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,23 @@ cell supports either Euler or midpoint method / RK-2 integration.)
191191
:noindex:
192192
```
193193

194+
### The Resonate-and-Fire (RAF) Cell
195+
196+
This cell models dynamics over voltage `v` and a angular driver state/variable `w`; these
197+
two variables result in a dampened oscillatory spiking neuronal cell). In effect, the
198+
resonatoe-and-fire (RAF) model (or "resonator") evolves as a result of two coupled
199+
differential equations. (Note that this cell supports either Euler or RK-2 integration.)
200+
201+
```{eval-rst}
202+
.. autoclass:: ngclearn.components.RAFCell
203+
:noindex:
204+
205+
.. automethod:: advance_state
206+
:noindex:
207+
.. automethod:: reset
208+
:noindex:
209+
```
210+
194211
### The Izhikevich Cell
195212

196213
This cell models dynamics over voltage `v` and a recover variable `w` (where `w`

ngclearn/components/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
## point to standard spiking cell component types
1212
from .neurons.spiking.sLIFCell import SLIFCell
13+
from .neurons.spiking.IFCell import IFCell
1314
from .neurons.spiking.LIFCell import LIFCell
1415
from .neurons.spiking.WTASCell import WTASCell
1516
from .neurons.spiking.quadLIFCell import QuadLIFCell
1617
from .neurons.spiking.adExCell import AdExCell
1718
from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
1819
from .neurons.spiking.izhikevichCell import IzhikevichCell
19-
20-
20+
from .neurons.spiking.RAFCell import RAFCell
2121
## point to transformer/operater component types
2222
from .other.varTrace import VarTrace
2323
from .other.expKernel import ExpKernel
@@ -28,7 +28,7 @@
2828
from .input_encoders.bernoulliCell import BernoulliCell
2929
from .input_encoders.poissonCell import PoissonCell
3030
from .input_encoders.latencyCell import LatencyCell
31-
31+
from .input_encoders.phasorCell import PhasorCell
3232

3333
## point to synapse component types
3434
from .synapses.denseSynapse import DenseSynapse

ngclearn/components/base_monitor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ def watch(self, compartment, window_length):
124124
"""
125125
cs, end = self._add_path(compartment.path)
126126

127+
dtype = compartment.value.dtype
127128
shape = compartment.value.shape
128-
new_comp = Compartment(np.zeros(shape))
129-
new_comp_store = Compartment(np.zeros((window_length, *shape)))
129+
new_comp = Compartment(np.zeros(shape, dtype=dtype))
130+
new_comp_store = Compartment(np.zeros((window_length, *shape), dtype=dtype))
130131

131132
comp_key = "*".join(compartment.path.split("/"))
132133
store_comp_key = comp_key + "*store"
@@ -310,4 +311,4 @@ def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None,
310311
for k in range(n):
311312
_ax.plot(vals[:, 0, k])
312313
else:
313-
plot_func(vals, ax=_ax)
314+
plot_func(vals[:, :, 0:n], ax=_ax)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .bernoulliCell import BernoulliCell
22
from .poissonCell import PoissonCell
33
from .latencyCell import LatencyCell
4+
from .phasorCell import PhasorCell

ngclearn/components/input_encoders/bernoulliCell.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from ngclearn.components.jaxComponent import JaxComponent
33
from jax import numpy as jnp, random, jit
44
from ngclearn.utils import tensorstats
5+
from functools import partial
6+
from ngcsimlib.deprecators import deprecate_args
7+
from ngcsimlib.logger import info, warn
58

69
@jit
710
def _update_times(t, s, tols):
@@ -21,25 +24,10 @@ def _update_times(t, s, tols):
2124
_tols = (1. - s) * tols + (s * t)
2225
return _tols
2326

24-
@jit
25-
def _sample_bernoulli(dkey, data):
26-
"""
27-
Samples a Bernoulli spike train on-the-fly
28-
29-
Args:
30-
dkey: JAX key to drive stochasticity/noise
31-
32-
data: sensory data (vector/matrix)
33-
34-
Returns:
35-
binary spikes
36-
"""
37-
s_t = random.bernoulli(dkey, p=data).astype(jnp.float32)
38-
return s_t
39-
4027
class BernoulliCell(JaxComponent):
4128
"""
42-
A Bernoulli cell that produces Bernoulli-distributed spikes on-the-fly.
29+
A Bernoulli cell that produces spikes by sampling a Bernoulli distribution
30+
on-the-fly (to produce data-scaled Bernoulli spike trains).
4331
4432
| --- Cell Input Compartments: ---
4533
| inputs - input (takes in external signals)
@@ -55,7 +43,6 @@ class BernoulliCell(JaxComponent):
5543
n_units: number of cellular entities (neural population size)
5644
"""
5745

58-
# Define Functions
5946
def __init__(self, name, n_units, batch_size=1, **kwargs):
6047
super().__init__(name, **kwargs)
6148

@@ -72,9 +59,9 @@ def __init__(self, name, n_units, batch_size=1, **kwargs):
7259
@staticmethod
7360
def _advance_state(t, key, inputs, tols):
7461
key, *subkeys = random.split(key, 2)
75-
outputs = _sample_bernoulli(subkeys[0], data=inputs)
76-
timeOfLastSpike = _update_times(t, outputs, tols)
77-
return outputs, timeOfLastSpike, key
62+
outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32)
63+
tols = _update_times(t, outputs, tols)
64+
return outputs, tols, key
7865

7966
@resolver(_advance_state)
8067
def advance_state(self, outputs, tols, key):

ngclearn/components/input_encoders/latencyCell.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ngclearn.utils.model_utils import clamp_min, clamp_max
55
from jax import numpy as jnp, random, jit
66
from functools import partial
7+
from ngcsimlib.logger import info
78

89
@jit
910
def _update_times(t, s, tols):
@@ -47,7 +48,7 @@ def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1.,
4748
projected spike times
4849
"""
4950
_tau = tau
50-
if normalize == True:
51+
if normalize:
5152
_tau = num_steps - 1. - first_spk_t ## linear normalization
5253
#torch.clamp_max((-tau * (data - 1)), -tau * (threshold - 1))
5354
stimes = -_tau * (data - 1.) ## calc raw latency code values
@@ -84,7 +85,7 @@ def _calc_spike_times_nonlinear(data, tau, thr, first_spk_t, eps=1e-7,
8485
stimes = jnp.log(_data / (_data - thr)) * tau ## calc spike times
8586
stimes = stimes + first_spk_t
8687

87-
if normalize == True:
88+
if normalize:
8889
term1 = (stimes - first_spk_t)
8990
term2 = (num_steps - first_spk_t - 1.)
9091
term3 = jnp.max(stimes - first_spk_t)
@@ -148,13 +149,16 @@ class LatencyCell(JaxComponent):
148149
:Note: if this set to True, you will need to choose a useful value
149150
for the "num_steps" argument (>1), depending on how many steps simulated
150151
152+
clip_spikes: should values under threshold be removed/suppressed?
153+
(default: False)
154+
151155
num_steps: number of discrete time steps to consider for normalized latency
152156
code (only useful if "normalize" is set to True) (Default: 1)
153157
"""
154158

155159
# Define Functions
156160
def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
157-
linearize=False, normalize=False, num_steps=1.,
161+
linearize=False, normalize=False, clip_spikes=False, num_steps=1.,
158162
batch_size=1, **kwargs):
159163
super().__init__(name, **kwargs)
160164

@@ -163,6 +167,7 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
163167
self.tau = tau
164168
self.threshold = threshold
165169
self.linearize = linearize
170+
self.clip_spikes = clip_spikes
166171
## normalize latency code s.t. final spike(s) occur w/in num_steps
167172
self.normalize = normalize
168173
self.num_steps = num_steps
@@ -175,17 +180,22 @@ def __init__(self, name, n_units, tau=1., threshold=0.01, first_spike_time=0.,
175180
restVals = jnp.zeros((self.batch_size, self.n_units))
176181
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
177182
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
178-
self.mask = Compartment(restVals, display_name="Mask Variable") # output compartment
183+
self.mask = Compartment(restVals, display_name="Spike Time Mask")
184+
self.clip_mask = Compartment(restVals, display_name="Clip Mask")
179185
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
180186
self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")
181187
#self.reset()
182188

183189
@staticmethod
184190
def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
185-
normalize, inputs):
191+
normalize, clip_spikes, inputs):
186192
## would call this function before processing a spike train (at start)
187193
data = inputs
188-
if linearize == True: ## linearize spike time calculation
194+
if clip_spikes:
195+
clip_mask = (data < threshold) * 1. ## find values under threshold
196+
else:
197+
clip_mask = data * 0. ## all values allowed to fire spikes
198+
if linearize: ## linearize spike time calculation
189199
stimes = _calc_spike_times_linear(data, tau, threshold,
190200
first_spike_time,
191201
num_steps, normalize)
@@ -196,18 +206,20 @@ def _calc_spike_times(linearize, tau, threshold, first_spike_time, num_steps,
196206
num_steps=num_steps,
197207
normalize=normalize)
198208
targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
199-
return targ_sp_times
209+
return targ_sp_times, clip_mask
200210

201211
@resolver(_calc_spike_times)
202-
def calc_spike_times(self, targ_sp_times):
212+
def calc_spike_times(self, targ_sp_times, clip_mask):
203213
self.targ_sp_times.set(targ_sp_times)
214+
self.clip_mask.set(clip_mask)
204215

205216
@staticmethod
206-
def _advance_state(t, dt, key, inputs, mask, targ_sp_times, tols):
217+
def _advance_state(t, dt, key, inputs, mask, clip_mask, targ_sp_times, tols):
207218
key, *subkeys = random.split(key, 2)
208219
data = inputs ## get sensory pattern data / features
209220
spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t
210221
tols = _update_times(t, spikes, tols)
222+
spikes = spikes * (1. - clip_mask)
211223
return spikes, tols, spk_mask, targ_sp_times, key
212224

213225
@resolver(_advance_state)
@@ -221,14 +233,15 @@ def advance_state(self, outputs, tols, mask, targ_sp_times, key):
221233
@staticmethod
222234
def _reset(batch_size, n_units):
223235
restVals = jnp.zeros((batch_size, n_units))
224-
return (restVals, restVals, restVals, restVals, restVals)
236+
return (restVals, restVals, restVals, restVals, restVals, restVals)
225237

226238
@resolver(_reset)
227-
def reset(self, inputs, outputs, tols, mask, targ_sp_times):
239+
def reset(self, inputs, outputs, tols, mask, clip_mask, targ_sp_times):
228240
self.inputs.set(inputs)
229241
self.outputs.set(outputs)
230242
self.tols.set(tols)
231243
self.mask.set(mask)
244+
self.clip_mask.set(clip_mask)
232245
self.targ_sp_times.set(targ_sp_times)
233246

234247
def save(self, directory, **kwargs):

0 commit comments

Comments
 (0)