Skip to content

Commit 72f5c6a

Browse files
author
Alexander Ororbia
committed
updates to exp-syn/testing lesson
1 parent db4cb81 commit 72f5c6a

File tree

3 files changed

+36
-25
lines changed

3 files changed

+36
-25
lines changed

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,29 @@
1414
#from ngcsimlib.component import Component
1515
from ngcsimlib.compartment import Compartment
1616

17-
#@jit
18-
def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
19-
mask = (rfr >= refract_T) * 1. # get refractory mask
20-
## update voltage / membrane potential
21-
dv_dt = (v_rest - v) * v_decay + (j * mask)
22-
dv_dt = dv_dt * (1./tau_m)
23-
return dv_dt
17+
# def _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay=1.): ## raw voltage dynamics
18+
# mask = (rfr >= refract_T) * 1. # get refractory mask
19+
# ## update voltage / membrane potential
20+
# dv_dt = (v_rest - v) * v_decay + (j * mask)
21+
# dv_dt = dv_dt * (1./tau_m)
22+
# return dv_dt
23+
#
24+
# def _dfv(t, v, params): ## voltage dynamics wrapper
25+
# j, rfr, tau_m, refract_T, v_rest, v_decay = params
26+
# dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay)
27+
# return dv_dt
28+
29+
2430

2531
def _dfv(t, v, params): ## voltage dynamics wrapper
26-
j, rfr, tau_m, refract_T, v_rest, v_decay = params
27-
dv_dt = _dfv_internal(j, v, rfr, tau_m, refract_T, v_rest, v_decay)
32+
j, rfr, tau_m, refract_T, v_rest, g_L = params
33+
mask = (rfr >= refract_T) * 1. # get refractory mask
34+
## update voltage / membrane potential
35+
dv_dt = (v_rest - v) * g_L + (j * mask)
36+
dv_dt = dv_dt * (1. / tau_m)
2837
return dv_dt
2938

39+
3040
#@partial(jit, static_argnums=[3, 4])
3141
def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
3242
### Runs homeostatic threshold update dynamics one step (via Euler integration).
@@ -38,6 +48,7 @@ def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
3848
#_V_theta = V_theta + -V_theta * (dt/tau_theta) + S * alpha
3949
return _v_theta
4050

51+
4152
class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
4253
"""
4354
A spiking cell based on leaky integrate-and-fire (LIF) neuronal dynamics.
@@ -73,14 +84,14 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
7384
thr: base value for adaptive thresholds that govern short-term
7485
plasticity (in milliVolts, or mV; default: -52. mV)
7586
76-
v_rest: membrane resting potential (in mV; default: -65 mV)
87+
v_rest: reversal potential or membrane resting potential (in mV; default: -65 mV)
7788
7889
v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
7990
a neuronal cell's membrane potential will be set to this value;
8091
(default: -60 mV)
8192
82-
v_decay: decay factor applied to voltage leak (Default: 1.); setting this
83-
to 0 mV recovers pure integrate-and-fire (IF) dynamics
93+
conduct_leak: leak conductance (g_L) value or decay factor applied to voltage leak
94+
(Default: 1.); setting this to 0 mV recovers pure integrate-and-fire (IF) dynamics
8495
8596
tau_theta: homeostatic threshold time constant
8697
@@ -116,12 +127,12 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
116127
the value of `v_rest` (default: True)
117128
""" ## batch_size arg?
118129

119-
@deprecate_args(thr_jitter=None)
120-
def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
121-
v_reset=-60., v_decay=1., tau_theta=1e7, theta_plus=0.05,
122-
refract_time=5., one_spike=False, integration_type="euler",
123-
surrogate_type="straight_through", lower_clamp_voltage=True,
124-
**kwargs):
130+
@deprecate_args(thr_jitter=None, v_decay="conduct_leak")
131+
def __init__(
132+
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7,
133+
theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through",
134+
lower_clamp_voltage=True, **kwargs
135+
):
125136
super().__init__(name, **kwargs)
126137

127138
## Integration properties
@@ -136,7 +147,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
136147

137148
self.v_rest = v_rest #-65. # mV
138149
self.v_reset = v_reset # -60. # -65. # mV (milli-volts)
139-
self.v_decay = v_decay ## controls strength of voltage leak (1 -> LIF, 0 => IF)
150+
self.g_L = conduct_leak ## controls strength of voltage leak (1 -> LIF, 0 => IF)
140151
## basic asserts to prevent neuronal dynamics breaking...
141152
#assert (self.v_decay * self.dt / self.tau_m) <= 1. ## <-- to integrate in verify...
142153
assert self.resist_m > 0.
@@ -178,7 +189,7 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
178189
@transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
179190
@staticmethod
180191
def advance_state(
181-
t, dt, tau_m, resist_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta, theta_plus,
192+
t, dt, tau_m, resist_m, v_rest, v_reset, g_L, refract_T, thr, tau_theta, theta_plus,
182193
one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
183194
):
184195
skey = None ## this is an empty dkey if single_spike mode turned off
@@ -191,7 +202,7 @@ def advance_state(
191202
_v_thr = thr_theta + thr ## calc present voltage threshold
192203
#mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
193204
## update voltage / membrane potential
194-
v_params = (j, rfr, tau_m, refract_T, v_rest, v_decay)
205+
v_params = (j, rfr, tau_m, refract_T, v_rest, g_L)
195206
if intgFlag == 1:
196207
_, _v = step_rk2(0., v, _dfv, dt, v_params)
197208
else:

ngclearn/components/synapses/exponentialSynapse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def advance_state(
8888
_out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
8989
dgsyn_dt = _out * g_syn_bar - g_syn/tau_syn
9090
g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance
91-
i_syn = -g_syn * (v - syn_rest)
91+
i_syn = -g_syn * (v - syn_rest)
9292
outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
9393
return outputs, i_syn, g_syn
9494

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "ngclearn"
7-
version = "2.0.0"
7+
version = "2.0.1"
88
description = "Simulation software for building and analyzing arbitrary predictive coding, spiking network, and biomimetic neural systems."
99
authors = [
1010
{name = "Alexander Ororbia", email = "[email protected]"},
@@ -14,13 +14,13 @@ readme = "README.md"
1414
keywords = ['python', 'ngc-learn', 'predictive-processing', 'predictive-coding', 'neuro-ai',
1515
'jax', 'spiking-neural-networks', 'biomimetics', 'bionics', 'computational-neuroscience']
1616
requires-python = ">=3.10" #3.8
17-
license = {text = "BSD-3-Clause License"}
17+
license = "BSD-3-Clause" # {text = "BSD-3-Clause License"}
1818
classifiers=[
1919
"Development Status :: 4 - Beta", #3 - Alpha", # 5 - Production/Stable
2020
"Intended Audience :: Education",
2121
"Intended Audience :: Science/Research",
2222
"Intended Audience :: Developers",
23-
"License :: OSI Approved :: BSD License",
23+
#"License :: OSI Approved :: BSD License",
2424
"Topic :: Scientific/Engineering",
2525
"Topic :: Scientific/Engineering :: Mathematics",
2626
"Topic :: Scientific/Engineering :: Artificial Intelligence",

0 commit comments

Comments
 (0)