Skip to content

Commit ebdea3e

Browse files
author
Alexander Ororbia
committed
refactored event-stdp-syn and test passed
1 parent 99b3c43 commit ebdea3e

File tree

3 files changed

+61
-99
lines changed

3 files changed

+61
-99
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#from .hebbianSynapse import HebbianSynapse
22
from .traceSTDPSynapse import TraceSTDPSynapse
33
from .expSTDPSynapse import ExpSTDPSynapse
4-
#from .eventSTDPSynapse import EventSTDPSynapse
4+
from .eventSTDPSynapse import EventSTDPSynapse
55
from .BCMSynapse import BCMSynapse
66

ngclearn/components/synapses/hebbian/eventSTDPSynapse.py

Lines changed: 39 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
from jax import numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
1+
from jax import random, numpy as jnp, jit
42
from ngcsimlib.compartment import Compartment
3+
from ngcsimlib.parser import compilable
54

6-
from ngclearn.components.synapses import DenseSynapse
7-
from ngclearn.utils import tensorstats
5+
from ngclearn.components.synapses.denseSynapse import DenseSynapse
86

97
class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
108
"""
@@ -57,11 +55,11 @@ class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
5755
"""
5856

5957
# Define Functions
60-
def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
61-
presyn_win_len=2., w_bound=1., weight_init=None, resist_scale=1.,
62-
p_conn=1., batch_size=1, **kwargs):
63-
super().__init__(name, shape, weight_init, None, resist_scale, p_conn,
64-
batch_size=batch_size, **kwargs)
58+
def __init__(
59+
self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1., presyn_win_len=2., w_bound=1., weight_init=None,
60+
resist_scale=1., p_conn=1., batch_size=1, **kwargs
61+
):
62+
super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
6563

6664
## Synaptic hyper-parameters
6765
self.eta = eta ## global learning rate governing plasticity
@@ -78,53 +76,47 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
7876
postVals = jnp.zeros((self.batch_size, shape[1]))
7977
self.pre_tols = Compartment(preVals)
8078
self.postSpike = Compartment(postVals)
81-
self.dWeights = Compartment(self.weights.value * 0)
79+
self.dWeights = Compartment(self.weights.get() * 0)
8280
self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity
8381

84-
@staticmethod
85-
def _compute_update(
86-
t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights
87-
): ## synaptic adjustment calculation co-routine
82+
def _compute_update(self, t, dt): ## synaptic adjustment calculation co-routine
8883
## check if a spike occurred in window of (t - presyn_win_len, t]
89-
m = (pre_tols > 0.) * 1. ## ignore default value of tols = 0 ms
90-
if presyn_win_len > 0.:
91-
lbound = ((t - presyn_win_len) < pre_tols) * 1.
84+
m = (self.pre_tols.get() > 0.) * 1. ## ignore default value of tols = 0 ms
85+
if self.presyn_win_len > 0.:
86+
lbound = ((t - self.presyn_win_len) < self.pre_tols.get()) * 1.
9287
preSpike = lbound * m
9388
else:
94-
check_spike = (pre_tols == t) * 1.
89+
check_spike = (self.pre_tols.get() == t) * 1.
9590
preSpike = check_spike * m
9691
## this implements a generalization of the rule in eqn 18 of the paper
97-
pos_shift = w_bound - (weights * (1. + lmbda))
98-
pos_shift = pos_shift * Aplus
99-
neg_shift = -weights * (1. + lmbda)
100-
neg_shift = neg_shift * Aminus
92+
pos_shift = self.w_bound - (self.weights.get() * (1. + self.lmbda))
93+
pos_shift = pos_shift * self.Aplus
94+
neg_shift = -self.weights.get() * (1. + self.lmbda)
95+
neg_shift = neg_shift * self.Aminus
10196
dW = jnp.where(preSpike.T, pos_shift, neg_shift) # at pre-spikes => LTP, else decay
102-
dW = (dW * postSpike) ## gate to make sure only post-spikes trigger updates
97+
dW = (dW * self.postSpike.get()) ## gate to make sure only post-spikes trigger updates
10398
return dW
10499

105-
@transition(output_compartments=["weights", "dWeights"])
106-
@staticmethod
107-
def evolve(
108-
t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights, eta
109-
):
110-
dWeights = EventSTDPSynapse._compute_update(
111-
t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights
112-
)
113-
weights = weights + dWeights * eta # * (1. - w) * eta
114-
weights = jnp.clip(weights, 0.01, w_bound) ## Note: this step not in source paper
115-
return weights, dWeights
116-
117-
@transition(output_compartments=["inputs", "outputs", "pre_tols", "postSpike", "dWeights"])
118-
@staticmethod
119-
def reset(batch_size, shape):
120-
preVals = jnp.zeros((batch_size, shape[0]))
121-
postVals = jnp.zeros((batch_size, shape[1]))
122-
inputs = preVals
123-
outputs = postVals
124-
pre_tols = preVals ## pre-synaptic time-of-last-spike(s) record
125-
postSpike = postVals
126-
dWeights = jnp.zeros(shape)
127-
return inputs, outputs, pre_tols, postSpike, dWeights
100+
@compilable
101+
def evolve(self, t, dt):
102+
dWeights = self._compute_update(t, dt)
103+
weights = self.weights.get() + dWeights * self.eta # * (1. - w) * eta
104+
weights = jnp.clip(weights, 0.01, self.w_bound) ## Note: this step not in source paper
105+
106+
self.weights.set(weights)
107+
self.dWeights.set(dWeights)
108+
109+
@compilable
110+
def reset(self):
111+
preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
112+
postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
113+
114+
if not self.inputs.targeted:
115+
self.inputs.set(preVals)
116+
self.outputs.set(postVals)
117+
self.pre_tols.set(preVals) ## pre-synaptic time-of-last-spike(s) record
118+
self.postSpike.set(postVals)
119+
self.dWeights.set(jnp.zeros(self.shape.get()))
128120

129121
@classmethod
130122
def help(cls): ## component help function
@@ -166,20 +158,6 @@ def help(cls): ## component help function
166158
"hyperparameters": hyperparams}
167159
return info
168160

169-
def __repr__(self):
170-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
171-
maxlen = max(len(c) for c in comps) + 5
172-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
173-
for c in comps:
174-
stats = tensorstats(getattr(self, c).value)
175-
if stats is not None:
176-
line = [f"{k}: {v}" for k, v in stats.items()]
177-
line = ", ".join(line)
178-
else:
179-
line = "None"
180-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
181-
return lines
182-
183161
if __name__ == '__main__':
184162
from ngcsimlib.context import Context
185163
with Context("Bar") as bar:

tests/components/synapses/hebbian/test_eventSTDPSynapse.py

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22
from ngcsimlib.context import Context
33
import numpy as np
44
np.random.seed(42)
5-
from ngclearn.components import EventSTDPSynapse
6-
from ngcsimlib.compilers import compile_command, wrap_command
7-
from numpy.testing import assert_array_equal
85

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
6+
from ngclearn import Context, MethodProcess
7+
import ngclearn.utils.weight_distribution as dist
8+
from ngclearn.components.synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
9+
from numpy.testing import assert_array_equal
1310

1411
def test_eventSTDPSynapse1():
1512
name = "event_stdp_ctx"
@@ -24,60 +21,47 @@ def test_eventSTDPSynapse1():
2421
name="a", shape=(1,1), eta=0., presyn_win_len=2., key=subkeys[0]
2522
)
2623

27-
#"""
28-
evolve_process = (Process("evolve_proc")
29-
>> a.evolve)
30-
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
31-
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
24+
evolve_process = (MethodProcess("evolve_process")
25+
>> a.evolve)
3226

33-
advance_process = (Process("advance_proc")
27+
advance_process = (MethodProcess("advance_proc")
3428
>> a.advance_state)
35-
# ctx.wrap_and_add_command(advance_process.pure, name="run")
36-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3729

38-
reset_process = (Process("reset_proc")
30+
reset_process = (MethodProcess("reset_proc")
3931
>> a.reset)
40-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
41-
#"""
4232

43-
"""
44-
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
45-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
46-
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
47-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
48-
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
49-
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
50-
"""
5133
a.weights.set(jnp.ones((1, 1)) * 0.1)
5234

5335
t = 12. ## fake out current time
54-
## Case 1: outside of pre-syn time window
36+
## Case 1: outside pre-syn time window
5537
input_tols = jnp.ones((1, 1,)) * 9.
5638
out_spike = jnp.ones((1, 1))
5739

5840
## check pre-synaptic STDP only
5941
truth = jnp.array([[-0.101]])
60-
ctx.reset()
42+
reset_process.run() # ctx.reset()
6143
a.pre_tols.set(input_tols)
6244
a.postSpike.set(out_spike)
63-
ctx.run(t=t, dt=dt)
64-
ctx.adapt(t=t, dt=dt)
65-
#print(a.dWeights.value)
66-
assert_array_equal(a.dWeights.value, truth)
45+
advance_process.run(t=t, dt=dt) # ctx.run(t=t, dt=dt)
46+
evolve_process.run(t=t, dt=dt) # ctx.adapt(t=t, dt=dt)
47+
# print(a.dWeights.get())
48+
# print(truth)
49+
assert_array_equal(a.dWeights.get(), truth)
6750

6851
## Case 2: within pre-syn time window
6952
input_tols = jnp.ones((1, 1,)) * 11.
7053
out_spike = jnp.ones((1, 1))
7154

7255
## check pre-synaptic STDP only
7356
truth = jnp.array([[0.899]])
74-
ctx.reset()
57+
reset_process.run() # ctx.reset()
7558
a.pre_tols.set(input_tols)
7659
a.postSpike.set(out_spike)
77-
ctx.run(t=t, dt=dt)
78-
ctx.adapt(t=t, dt=dt)
79-
#print(a.dWeights.value)
80-
assert_array_equal(a.dWeights.value, truth)
60+
advance_process.run(t=t, dt=dt) # ctx.run(t=t, dt=dt)
61+
evolve_process.run(t=t, dt=dt) # ctx.adapt(t=t, dt=dt)
62+
# print(a.dWeights.get())
63+
# print(truth)
64+
assert_array_equal(a.dWeights.get(), truth)
8165

8266
#test_eventSTDPSynapse1()
8367

0 commit comments

Comments
 (0)