Skip to content

Commit 33599fd

Browse files
committed
Merge branch 'v3' of github.com:NACLab/ngc-learn into v3
2 parents 058bf90 + 947d2cf commit 33599fd

24 files changed

+644
-794
lines changed

ngclearn/components/__init__.py

Lines changed: 65 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,65 @@
1-
# from .jaxComponent import JaxComponent
2-
#
3-
# ## point to rate-coded cell component types
4-
# from .neurons.graded.rateCell import RateCell
5-
# from .neurons.graded.gaussianErrorCell import GaussianErrorCell
6-
# from .neurons.graded.laplacianErrorCell import LaplacianErrorCell
7-
# from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell
8-
# from .neurons.graded.rewardErrorCell import RewardErrorCell
9-
#
10-
# ## point to standard spiking cell component types
11-
# from .neurons.spiking.sLIFCell import SLIFCell
12-
# from .neurons.spiking.IFCell import IFCell
13-
# from .neurons.spiking.LIFCell import LIFCell
14-
# from .neurons.spiking.WTASCell import WTASCell
15-
# from .neurons.spiking.quadLIFCell import QuadLIFCell
16-
# from .neurons.spiking.adExCell import AdExCell
17-
# from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
18-
# from .neurons.spiking.izhikevichCell import IzhikevichCell
19-
# from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
20-
# from .neurons.spiking.RAFCell import RAFCell
21-
#
22-
# ## point to transformer/operator component types
23-
# from .other.varTrace import VarTrace
24-
# from .other.expKernel import ExpKernel
25-
#
26-
# ## point to input encoder component types
27-
# from .input_encoders.bernoulliCell import BernoulliCell
28-
# from .input_encoders.poissonCell import PoissonCell
29-
# from .input_encoders.latencyCell import LatencyCell
30-
# from .input_encoders.phasorCell import PhasorCell
31-
#
32-
# ## point to synapse component types
33-
# from .synapses.denseSynapse import DenseSynapse
34-
# from .synapses.staticSynapse import StaticSynapse
35-
# from .synapses.hebbian.hebbianSynapse import HebbianSynapse
36-
# from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
37-
# from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
38-
# from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
39-
# from .synapses.hebbian.BCMSynapse import BCMSynapse
40-
# from .synapses.STPDenseSynapse import STPDenseSynapse
41-
# from .synapses.exponentialSynapse import ExponentialSynapse
42-
# from .synapses.doubleExpSynapse import DoupleExpSynapse
43-
# from .synapses.alphaSynapse import AlphaSynapse
44-
#
45-
# ## point to convolutional component types
46-
# from .synapses.convolution.convSynapse import ConvSynapse
47-
# from .synapses.convolution.staticConvSynapse import StaticConvSynapse
48-
# from .synapses.convolution.hebbianConvSynapse import HebbianConvSynapse
49-
# from .synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
50-
# from .synapses.convolution.deconvSynapse import DeconvSynapse
51-
# from .synapses.convolution.staticDeconvSynapse import StaticDeconvSynapse
52-
# from .synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
53-
# from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
54-
# ## point to modulated component types
55-
# from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse
56-
# from .synapses.modulated.REINFORCESynapse import REINFORCESynapse
57-
#
58-
# ## point to monitors
59-
# from .monitor import Monitor
60-
#
61-
# ## point to patched component types
62-
# from .synapses.patched.patchedSynapse import PatchedSynapse
63-
# from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse
64-
# from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse
65-
#
1+
from .jaxComponent import JaxComponent
2+
3+
## point to rate-coded cell component types
4+
from .neurons.graded.rateCell import RateCell
5+
from .neurons.graded.gaussianErrorCell import GaussianErrorCell
6+
from .neurons.graded.laplacianErrorCell import LaplacianErrorCell
7+
from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell
8+
from .neurons.graded.rewardErrorCell import RewardErrorCell
9+
10+
## point to standard spiking cell component types
11+
from .neurons.spiking.sLIFCell import SLIFCell
12+
from .neurons.spiking.IFCell import IFCell
13+
from .neurons.spiking.LIFCell import LIFCell
14+
from .neurons.spiking.WTASCell import WTASCell
15+
from .neurons.spiking.quadLIFCell import QuadLIFCell
16+
from .neurons.spiking.adExCell import AdExCell
17+
from .neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
18+
from .neurons.spiking.izhikevichCell import IzhikevichCell
19+
from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
20+
from .neurons.spiking.RAFCell import RAFCell
21+
22+
## point to transformer/operator component types
23+
from .other.varTrace import VarTrace
24+
from .other.expKernel import ExpKernel
25+
26+
## point to input encoder component types
27+
from .input_encoders.bernoulliCell import BernoulliCell
28+
from .input_encoders.poissonCell import PoissonCell
29+
from .input_encoders.latencyCell import LatencyCell
30+
from .input_encoders.phasorCell import PhasorCell
31+
32+
## point to synapse component types
33+
from .synapses.denseSynapse import DenseSynapse
34+
from .synapses.staticSynapse import StaticSynapse
35+
from .synapses.hebbian.hebbianSynapse import HebbianSynapse
36+
from .synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
37+
from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
38+
from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
39+
from .synapses.hebbian.BCMSynapse import BCMSynapse
40+
from .synapses.STPDenseSynapse import STPDenseSynapse
41+
from .synapses.exponentialSynapse import ExponentialSynapse
42+
from .synapses.doubleExpSynapse import DoupleExpSynapse
43+
from .synapses.alphaSynapse import AlphaSynapse
44+
45+
## point to convolutional component types
46+
from .synapses.convolution.convSynapse import ConvSynapse
47+
from .synapses.convolution.staticConvSynapse import StaticConvSynapse
48+
from .synapses.convolution.hebbianConvSynapse import HebbianConvSynapse
49+
from .synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
50+
from .synapses.convolution.deconvSynapse import DeconvSynapse
51+
from .synapses.convolution.staticDeconvSynapse import StaticDeconvSynapse
52+
from .synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
53+
from .synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
54+
## point to modulated component types
55+
from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse
56+
from .synapses.modulated.REINFORCESynapse import REINFORCESynapse
57+
58+
## point to monitors
59+
from .monitor import Monitor
60+
61+
## point to patched component types
62+
from .synapses.patched.patchedSynapse import PatchedSynapse
63+
from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse
64+
from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse
65+
Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
# ## point to rate-coded cell componet types
2-
# from .graded.rateCell import RateCell
3-
# from .graded.gaussianErrorCell import GaussianErrorCell
4-
# from .graded.laplacianErrorCell import LaplacianErrorCell
5-
# from .graded.bernoulliErrorCell import BernoulliErrorCell
6-
# from .graded.rewardErrorCell import RewardErrorCell
7-
# ## point to standard spiking cell component types
8-
# from .spiking.sLIFCell import SLIFCell
9-
# from .spiking.IFCell import IFCell
10-
# from .spiking.LIFCell import LIFCell
11-
# from .spiking.WTASCell import WTASCell
12-
# from .spiking.quadLIFCell import QuadLIFCell
13-
# from .spiking.adExCell import AdExCell
14-
# from .spiking.fitzhughNagumoCell import FitzhughNagumoCell
15-
# from .spiking.izhikevichCell import IzhikevichCell
16-
# from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
17-
# from .spiking.RAFCell import RAFCell
1+
## point to rate-coded cell componet types
2+
from .graded.rateCell import RateCell
3+
from .graded.gaussianErrorCell import GaussianErrorCell
4+
from .graded.laplacianErrorCell import LaplacianErrorCell
5+
from .graded.bernoulliErrorCell import BernoulliErrorCell
6+
from .graded.rewardErrorCell import RewardErrorCell
7+
## point to standard spiking cell component types
8+
from .spiking.sLIFCell import SLIFCell
9+
from .spiking.IFCell import IFCell
10+
from .spiking.LIFCell import LIFCell
11+
from .spiking.WTASCell import WTASCell
12+
from .spiking.quadLIFCell import QuadLIFCell
13+
from .spiking.adExCell import AdExCell
14+
from .spiking.fitzhughNagumoCell import FitzhughNagumoCell
15+
from .spiking.izhikevichCell import IzhikevichCell
16+
from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
17+
from .spiking.RAFCell import RAFCell

ngclearn/components/neurons/spiking/IFCell.py

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from ngclearn.components.jaxComponent import JaxComponent
2-
from jax import numpy as jnp, random, jit, nn
3-
from functools import partial
2+
from jax import numpy as jnp, random, nn, Array, jit
43
from ngclearn.utils import tensorstats
54
from ngcsimlib import deprecate_args
6-
from ngcsimlib.logger import info, warn
75
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
86
step_euler, step_rk2
9-
# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
10-
# triangular_estimator,
11-
# straight_through_estimator)
7+
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
8+
triangular_estimator,
9+
straight_through_estimator)
1210

1311
from ngcsimlib.parser import compilable
1412
from ngcsimlib.compartment import Compartment
@@ -89,7 +87,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
8987
the value of `v_rest` (default: True)
9088
"""
9189

92-
@deprecate_args(thr_jitter=None)
90+
#@deprecate_args(thr_jitter=None)
9391
def __init__(
9492
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., refract_time=0.,
9593
integration_type="euler", surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs
@@ -135,7 +133,7 @@ def __init__(
135133
display_name="Refractory Time Period", units="ms")
136134
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
137135
units="ms") ## time-of-last-spike
138-
self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
136+
#self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
139137

140138
@compilable
141139
def advance_state(
@@ -180,33 +178,6 @@ def reset(self):
180178
self.tols.set(restVals)
181179
#surrogate = restVals + 1.
182180

183-
def save(self, directory, **kwargs):
184-
## do a protected save of constants, depending on whether they are floats or arrays
185-
tau_m = (self.tau_m if isinstance(self.tau_m, float)
186-
else jnp.asarray([[self.tau_m * 1.]]))
187-
thr = (self.thr if isinstance(self.thr, float)
188-
else jnp.asarray([[self.thr * 1.]]))
189-
v_rest = (self.v_rest if isinstance(self.v_rest, float)
190-
else jnp.asarray([[self.v_rest * 1.]]))
191-
v_reset = (self.v_reset if isinstance(self.v_reset, float)
192-
else jnp.asarray([[self.v_reset * 1.]]))
193-
v_decay = (self.v_decay if isinstance(self.v_decay, float)
194-
else jnp.asarray([[self.v_decay * 1.]]))
195-
resist_m = (self.resist_m if isinstance(self.resist_m, float)
196-
else jnp.asarray([[self.resist_m * 1.]]))
197-
tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
198-
else jnp.asarray([[self.tau_theta * 1.]]))
199-
theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
200-
else jnp.asarray([[self.theta_plus * 1.]]))
201-
202-
file_name = directory + "/" + self.name + ".npz"
203-
jnp.savez(file_name,
204-
tau_m=tau_m, thr=thr, v_rest=v_rest,
205-
v_reset=v_reset, v_decay=v_decay,
206-
resist_m=resist_m, tau_theta=tau_theta,
207-
theta_plus=theta_plus,
208-
key=self.key.value)
209-
210181
def load(self, directory, seeded=False, **kwargs):
211182
file_name = directory + "/" + self.name + ".npz"
212183
data = jnp.load(file_name)

ngclearn/components/neurons/spiking/LIFCell.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ngclearn.components.jaxComponent import JaxComponent
2-
from jax import numpy as jnp, random, nn
2+
from jax import numpy as jnp, random, nn, Array
3+
from ngclearn.utils import tensorstats
34
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
45
step_euler, step_rk2
56
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
@@ -151,20 +152,15 @@ def __init__(
151152
# else: ## default: straight_through
152153
# spike_fx, d_spike_fx = straight_through_estimator()
153154

154-
155155
## Compartment setup
156156
restVals = jnp.zeros((self.batch_size, self.n_units))
157157
self.j = Compartment(restVals, display_name="Current", units="mA")
158-
self.v = Compartment(restVals + self.v_rest,
159-
display_name="Voltage", units="mV")
158+
self.v = Compartment(restVals + self.v_rest, display_name="Voltage", units="mV")
160159
self.s = Compartment(restVals, display_name="Spikes")
161160
self.s_raw = Compartment(restVals, display_name="Raw Spike Pulses")
162-
self.rfr = Compartment(restVals + self.refract_T,
163-
display_name="Refractory Time Period", units="ms")
164-
self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift",
165-
units="mV")
166-
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
167-
units="ms") ## time-of-last-spike
161+
self.rfr = Compartment(restVals + self.refract_T, display_name="Refractory Time Period", units="ms")
162+
self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift", units="mV")
163+
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike
168164
# self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
169165

170166
@compilable
@@ -205,7 +201,7 @@ def advance_state(self, dt, t):
205201
thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus) #.get())
206202
self.thr_theta.set(thr_theta)
207203

208-
## update tols
204+
## update time-of-last spike variable(s)
209205
self.tols.set((1. - s) * self.tols.get() + (s * t))
210206

211207
if self.v_min is not None: ## ensures voltage never < v_rest
@@ -258,24 +254,33 @@ def help(cls): ## component help function
258254
"v_reset": "Reset membrane potential value",
259255
"conduct_leak": "Conductance leak / voltage decay factor",
260256
"tau_theta": "Threshold/homoestatic increment time constant",
261-
"theta_plus": "Amount to increment threshold by upon occurrence "
262-
"of spike",
257+
"theta_plus": "Amount to increment threshold by upon occurrence of a spike",
263258
"refract_time": "Length of relative refractory period (ms)",
264-
"one_spike": "Should only one spike be sampled/allowed to emit at "
265-
"any given time step?",
266-
"integration_type": "Type of numerical integration to use for the "
267-
"cell dynamics",
259+
"one_spike": "Should only one spike be sampled/allowed to emit at any given time step?",
260+
"integration_type": "Type of numerical integration to use for the cell dynamics",
268261
"surrgoate_type": "Type of surrogate function to use approximate "
269262
"derivative of spike w.r.t. voltage/current",
270-
"lower_bound_clamp": "Should voltage be lower bounded to be never "
271-
"be below `v_rest`"
263+
"v_min": "Minimum voltage allowed before voltage variables are min-clipped/clamped"
272264
}
273265
info = {cls.__name__: properties,
274266
"compartments": compartment_props,
275267
"dynamics": "tau_m * dv/dt = (v_rest - v) + j * resist_m",
276268
"hyperparameters": hyperparams}
277269
return info
278270

271+
def __repr__(self):
272+
comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
273+
maxlen = max(len(c) for c in comps) + 5
274+
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
275+
for c in comps:
276+
stats = tensorstats(getattr(self, c).value)
277+
if stats is not None:
278+
line = [f"{k}: {v}" for k, v in stats.items()]
279+
line = ", ".join(line)
280+
else:
281+
line = "None"
282+
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
283+
return lines
279284

280285
if __name__ == '__main__':
281286
from ngcsimlib.context import Context

0 commit comments

Comments
 (0)