Skip to content

Commit 33e0cc1

Browse files
author
Alexander Ororbia
committed
refactored lava components to new sim-lib
1 parent 464ab10 commit 33e0cc1

File tree

5 files changed

+57
-97
lines changed

5 files changed

+57
-97
lines changed

ngclearn/components/lava/neurons/LIFCell.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from ngclearn import resolver, Component, Compartment
2-
from ngclearn.utils import tensorstats
31
from ngclearn import numpy as jnp
4-
from ngclearn.utils.weight_distribution import initialize_params
52
from ngcsimlib.logger import info, warn
3+
from ngcsimlib.compilers.process import transition
4+
from ngcsimlib.component import Component
5+
from ngcsimlib.compartment import Compartment
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
8+
from ngclearn.utils import tensorstats
69

710
class LIFCell(Component): ## Lava-compliant leaky integrate-and-fire cell
811
"""
@@ -113,8 +116,9 @@ def __init__(self, name, n_units, dt, tau_m, thr_theta_init=None, resist_m=1.,
113116
def _init(self, thr_theta0):
114117
self.thr_theta.set(thr_theta0)
115118

119+
@transition(output_compartments=["v", "s", "rfr", "thr_theta"])
116120
@staticmethod
117-
def _advance_state(dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta,
121+
def advance_state(dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta,
118122
theta_plus, j_exc, j_inh, v, s, rfr, thr_theta):
119123
#j = j * (tau_m/dt) ## scale electrical current
120124
j = j_exc - j_inh ## sum the excitatory and inhibitory input channels
@@ -136,16 +140,9 @@ def _advance_state(dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr, tau
136140
#tols = (1. - s) * tols + (s * t)
137141
return v, s, rfr, thr_theta #, tols
138142

139-
@resolver(_advance_state)
140-
def advance_state(self, v, s, rfr, thr_theta): #, tols):
141-
self.v.set(v)
142-
self.s.set(s)
143-
self.rfr.set(rfr)
144-
self.thr_theta.set(thr_theta)
145-
#self.tols.set(tols)
146-
143+
@transition(output_compartments=["j_exc", "j_inh", "v", "s", "rfr"])
147144
@staticmethod
148-
def _reset(batch_size, n_units, v_rest, refract_T):
145+
def reset(batch_size, n_units, v_rest, refract_T):
149146
restVals = jnp.zeros((batch_size, n_units))
150147
j_exc = restVals #+ 0
151148
j_inh = restVals #+ 0
@@ -154,14 +151,6 @@ def _reset(batch_size, n_units, v_rest, refract_T):
154151
rfr = restVals + refract_T
155152
return j_exc, j_inh, v, s, rfr #, tols
156153

157-
@resolver(_reset)
158-
def reset(self, j_exc, j_inh, v, s, rfr):#, tols):
159-
self.j_exc.set(j_exc)
160-
self.j_inh.set(j_inh)
161-
self.v.set(v)
162-
self.s.set(s)
163-
self.rfr.set(rfr)
164-
165154
def save(self, directory, **kwargs):
166155
file_name = directory + "/" + self.name + ".npz"
167156
jnp.savez(file_name,

ngclearn/components/lava/synapses/hebbianSynapse.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from ngclearn import resolver, Component, Compartment
2-
from ngclearn.utils import tensorstats
31
from ngclearn import numpy as jnp
4-
from ngclearn.utils.weight_distribution import initialize_params
52
from ngcsimlib.logger import info, warn
3+
from ngcsimlib.compilers.process import transition
4+
from ngcsimlib.component import Component
5+
from ngcsimlib.compartment import Compartment
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
8+
from ngclearn.utils import tensorstats
69

710
class HebbianSynapse(Component): ## Lava-compliant Hebbian synapse
811
"""
9-
A synaptic cable that adjusts its efficacies via a two-factor Hebbian
10-
adjustment rule. This is a Lava-compliant synaptic cable that adjusts
11-
with a hard-coded form of (stochastic) gradient ascent.
12+
A synaptic cable that adjusts its efficacies via a two-factor Hebbian adjustment rule. This is a Lava-compliant
13+
synaptic cable that adjusts with a hard-coded form of (stochastic) gradient ascent.
1214
1315
| --- Synapse Input Compartments: (Takes wired-in signals) ---
1416
| inputs - input (pre-synaptic) stimulus
@@ -102,8 +104,9 @@ def _init(self, weights):
102104
self.post.set(postVals)
103105
self.weights.set(weights)
104106

107+
@transition(output_compartments=["outputs", "weights"])
105108
@staticmethod
106-
def _advance_state(dt, Rscale, w_bounds, w_decay, inputs, weights,
109+
def advance_state(dt, Rscale, w_bounds, w_decay, inputs, weights,
107110
pre, post, eta):
108111
outputs = jnp.matmul(inputs, weights) * Rscale
109112
########################################################################
@@ -119,13 +122,9 @@ def _advance_state(dt, Rscale, w_bounds, w_decay, inputs, weights,
119122
########################################################################
120123
return outputs, weights
121124

122-
@resolver(_advance_state)
123-
def advance_state(self, outputs, weights):
124-
self.outputs.set(outputs)
125-
self.weights.set(weights)
126-
125+
@transition(output_compartments=["inputs", "outputs", "pre", "post", "eta"])
127126
@staticmethod
128-
def _reset(batch_size, rows, cols, eta0):
127+
def reset(batch_size, rows, cols, eta0):
129128
preVals = jnp.zeros((batch_size, rows))
130129
postVals = jnp.zeros((batch_size, cols))
131130
return (
@@ -136,14 +135,6 @@ def _reset(batch_size, rows, cols, eta0):
136135
jnp.ones((1,1)) * eta0
137136
)
138137

139-
@resolver(_reset)
140-
def reset(self, inputs, outputs, pre, post, eta):
141-
self.inputs.set(inputs)
142-
self.outputs.set(outputs)
143-
self.pre.set(pre)
144-
self.post.set(post)
145-
self.eta.set(eta)
146-
147138
def save(self, directory, **kwargs):
148139
file_name = directory + "/" + self.name + ".npz"
149140
jnp.savez(file_name, weights=self.weights.value)

ngclearn/components/lava/synapses/staticSynapse.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from ngclearn import numpy as jnp
2-
from ngclearn import resolver, Component, Compartment
3-
from ngclearn.utils import tensorstats
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
45
from ngclearn.utils.weight_distribution import initialize_params
56
from ngcsimlib.logger import info, warn
7+
from ngclearn.components.synapses.hebbian import TraceSTDPSynapse
8+
from ngclearn.utils import tensorstats
69

710
class StaticSynapse(Component): ## Lava-compliant fixed/non-evolvable synapse
811
"""
9-
A static (dense) synaptic cable; no form of synaptic evolution/adaptation
10-
is in-built to this component. This a Lava-compliant version of the
11-
static synapse component from the synapses sub-package of components.
12+
A static (dense) synaptic cable; no form of synaptic evolution/adaptation is in-built to this component. This a
13+
Lava-compliant version of the static synapse component from the synapses sub-package of components.
1214
1315
| --- Synapse Input Compartments: (Takes wired-in signals) ---
1416
| inputs - input (pre-synaptic) stimulus
@@ -79,29 +81,22 @@ def _init(self, weights):
7981
self.outputs.set(postVals)
8082
self.weights.set(weights)
8183

84+
@transition(output_compartments=["outputs"])
8285
@staticmethod
83-
def _advance_state(dt, Rscale, inputs, weights):
86+
def advance_state(dt, Rscale, inputs, weights):
8487
outputs = jnp.matmul(inputs, weights) * Rscale
8588
return outputs
8689

87-
@resolver(_advance_state)
88-
def advance_state(self, outputs):
89-
self.outputs.set(outputs)
90-
90+
@transition(output_compartments=["inputs", "outputs"])
9191
@staticmethod
92-
def _reset(batch_size, rows, cols):
92+
def reset(batch_size, rows, cols):
9393
preVals = jnp.zeros((batch_size, rows))
9494
postVals = jnp.zeros((batch_size, cols))
9595
return (
9696
preVals, # inputs
9797
postVals, # outputs
9898
)
9999

100-
@resolver(_reset)
101-
def reset(self, inputs, outputs):
102-
self.inputs.set(inputs)
103-
self.outputs.set(outputs)
104-
105100
def save(self, directory, **kwargs):
106101
file_name = directory + "/" + self.name + ".npz"
107102
jnp.savez(file_name, weights=self.weights.value)

ngclearn/components/lava/synapses/traceSTDPSynapse.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from ngclearn import resolver, Component, Compartment
2-
from ngclearn.utils import tensorstats
31
from ngclearn import numpy as jnp
4-
from ngclearn.utils.weight_distribution import initialize_params
52
from ngcsimlib.logger import info, warn
3+
from ngcsimlib.compilers.process import transition
4+
from ngcsimlib.component import Component
5+
from ngcsimlib.compartment import Compartment
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
8+
from ngclearn.utils import tensorstats
69

710
class TraceSTDPSynapse(Component): ## Lava-compliant trace-STDP synapse
811
"""
9-
A synaptic cable that adjusts its efficacies via trace-based form of
10-
spike-timing-dependent plasticity (STDP). This is a Lava-compliant synaptic
11-
cable that adjusts with a hard-coded form of (stochastic) gradient ascent.
12+
A synaptic cable that adjusts its efficacies via trace-based form of spike-timing-dependent plasticity (STDP).
13+
This is a Lava-compliant synaptic cable that adjusts with a hard-coded form of (stochastic) gradient ascent.
1214
1315
| --- Synapse Input Compartments: (Takes wired-in signals) ---
1416
| inputs - input (pre-synaptic) stimulus
@@ -120,8 +122,9 @@ def _init(self, weights):
120122
self.x_post.set(postVals) ## post-synaptic trace
121123
self.weights.set(weights)
122124

125+
@transition(output_compartments=["outputs", "weights"])
123126
@staticmethod
124-
def _advance_state(dt, Rscale, Aplus, Aminus, w_bounds, w_decay, x_tar,
127+
def advance_state(dt, Rscale, Aplus, Aminus, w_bounds, w_decay, x_tar,
125128
inputs, weights, pre, x_pre, post, x_post, eta):
126129
outputs = jnp.matmul(inputs, weights) * Rscale
127130
########################################################################
@@ -139,13 +142,9 @@ def _advance_state(dt, Rscale, Aplus, Aminus, w_bounds, w_decay, x_tar,
139142
########################################################################
140143
return outputs, weights
141144

142-
@resolver(_advance_state)
143-
def advance_state(self, outputs, weights):
144-
self.outputs.set(outputs)
145-
self.weights.set(weights)
146-
145+
@transition(output_compartments=["inputs", "outputs", "pre", "post", "x_pre", "x_post", "eta"])
147146
@staticmethod
148-
def _reset(batch_size, rows, cols, eta0):
147+
def reset(batch_size, rows, cols, eta0):
149148
preVals = jnp.zeros((batch_size, rows))
150149
postVals = jnp.zeros((batch_size, cols))
151150
return (
@@ -158,16 +157,6 @@ def _reset(batch_size, rows, cols, eta0):
158157
jnp.ones((1, 1)) * eta0
159158
)
160159

161-
@resolver(_reset)
162-
def reset(self, inputs, outputs, pre, post, x_pre, x_post, eta):
163-
self.inputs.set(inputs)
164-
self.outputs.set(outputs)
165-
self.pre.set(pre)
166-
self.post.set(post)
167-
self.x_pre.set(x_pre)
168-
self.x_post.set(x_post)
169-
self.eta.set(eta)
170-
171160
def save(self, directory, **kwargs):
172161
file_name = directory + "/" + self.name + ".npz"
173162
jnp.savez(file_name, weights=self.weights.value)

ngclearn/components/lava/traces/gatedTrace.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from ngclearn import resolver, Component, Compartment
2-
from ngclearn.utils import tensorstats
3-
41
from ngclearn import numpy as jnp
5-
import time, sys
2+
from ngcsimlib.logger import info, warn
3+
from ngcsimlib.compilers.process import transition
4+
from ngcsimlib.component import Component
5+
from ngcsimlib.compartment import Compartment
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
8+
from ngclearn.utils import tensorstats
69

710
class GatedTrace(Component): ## gated/piecewise low-pass filter
811
"""
9-
A gated/piecewise variable trace (filter).
12+
A gated/piecewise variable trace (filter). This is a Lava-compliant trace component.
1013
1114
| --- Cell Input Compartments: (Takes wired-in signals) ---
1215
| inputs - input (takes wired-in external signals)
@@ -39,25 +42,18 @@ def __init__(self, name, n_units, dt, tau_tr, **kwargs):
3942
self.inputs = Compartment(restVals) # input compartment
4043
self.trace = Compartment(restVals)
4144

45+
@transition(output_compartments=["trace"])
4246
@staticmethod
43-
def _advance_state(dt, tau_tr, inputs, trace):
47+
def advance_state(dt, tau_tr, inputs, trace):
4448
trace = (trace * (1. - dt/tau_tr)) * (1. - inputs) + inputs
4549
return trace
4650

47-
@resolver(_advance_state)
48-
def advance_state(self, trace):
49-
self.trace.set(trace)
50-
51+
@transition(output_compartments=["inputs", "trace"])
5152
@staticmethod
52-
def _reset(batch_size, n_units):
53+
def reset(batch_size, n_units):
5354
restVals = jnp.zeros((batch_size, n_units))
5455
return restVals, restVals
5556

56-
@resolver(_reset)
57-
def reset(self, inputs, trace):
58-
self.inputs.set(inputs)
59-
self.trace.set(trace)
60-
6157
def __repr__(self):
6258
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
6359
maxlen = max(len(c) for c in comps) + 5

0 commit comments

Comments
 (0)