Skip to content

Commit 7fbae79

Browse files
author
Alexander Ororbia
committed
refactored event-stdp w/ unit-test
1 parent 6a4889a commit 7fbae79

File tree

2 files changed

+97
-23
lines changed

2 files changed

+97
-23
lines changed

ngclearn/components/synapses/hebbian/eventSTDPSynapse.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from jax import numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
3-
from ngclearn.utils import tensorstats
4-
## import parent synapse class/component
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
56
from ngclearn.components.synapses import DenseSynapse
7+
from ngclearn.utils import tensorstats
68

79
class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
810
"""
@@ -80,8 +82,9 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
8082
self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity
8183

8284
@staticmethod
83-
def _compute_update(t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols,
84-
postSpike, weights):
85+
def _compute_update(
86+
t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights
87+
): ## synaptic adjustment calculation co-routine
8588
## check if a spike occurred in window of (t - presyn_win_len, t]
8689
m = (pre_tols > 0.) * 1. ## ignore default value of tols = 0 ms
8790
if presyn_win_len > 0.:
@@ -99,40 +102,30 @@ def _compute_update(t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols,
99102
dW = (dW * postSpike) ## gate to make sure only post-spikes trigger updates
100103
return dW
101104

105+
@transition(output_compartments=["weights", "dWeights"])
102106
@staticmethod
103-
def _evolve(t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols,
104-
postSpike, weights, eta):
107+
def evolve(
108+
t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights, eta
109+
):
105110
dWeights = EventSTDPSynapse._compute_update(
106111
t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights
107112
)
108113
weights = weights + dWeights * eta # * (1. - w) * eta
109-
weights = jnp.clip(weights, 0.01, w_bound) # not in source paper
114+
weights = jnp.clip(weights, 0.01, w_bound) ## Note: this step not in source paper
110115
return weights, dWeights
111116

112-
@resolver(_evolve)
113-
def evolve(self, weights, dWeights):
114-
self.weights.set(weights)
115-
self.dWeights.set(dWeights)
116-
117+
@transition(output_compartments=["inputs", "outputs", "pre_tols", "postSpike", "dWeights"])
117118
@staticmethod
118-
def _reset(batch_size, shape):
119+
def reset(batch_size, shape):
119120
preVals = jnp.zeros((batch_size, shape[0]))
120121
postVals = jnp.zeros((batch_size, shape[1]))
121122
inputs = preVals
122123
outputs = postVals
123-
pre_tols = preVals ## pre-synaptic time-of-last-spike record
124+
pre_tols = preVals ## pre-synaptic time-of-last-spike(s) record
124125
postSpike = postVals
125126
dWeights = jnp.zeros(shape)
126127
return inputs, outputs, pre_tols, postSpike, dWeights
127128

128-
@resolver(_reset)
129-
def reset(self, inputs, outputs, pre_tols, postSpike, dWeights):
130-
self.inputs.set(inputs)
131-
self.outputs.set(outputs)
132-
self.pre_tols.set(pre_tols)
133-
self.postSpike.set(postSpike)
134-
self.dWeights.set(dWeights)
135-
136129
@classmethod
137130
def help(cls): ## component help function
138131
properties = {
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import EventSTDPSynapse
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
14+
def test_eventSTDPSynapse1():
15+
name = "event_stdp_ctx"
16+
## create seeding keys
17+
dkey = random.PRNGKey(1234)
18+
dkey, *subkeys = random.split(dkey, 6)
19+
dt = 1. # ms
20+
trace_increment = 0.1
21+
# ---- build a simple Poisson cell system ----
22+
with Context(name) as ctx:
23+
a = EventSTDPSynapse(
24+
name="a", shape=(1,1), eta=0., presyn_win_len=2., key=subkeys[0]
25+
)
26+
27+
#"""
28+
evolve_process = (Process()
29+
>> a.evolve)
30+
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
31+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
32+
33+
advance_process = (Process()
34+
>> a.advance_state)
35+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
36+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
37+
38+
reset_process = (Process()
39+
>> a.reset)
40+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
41+
#"""
42+
43+
"""
44+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
45+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
46+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
47+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
48+
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
49+
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
50+
"""
51+
52+
t = 12. ## fake out current time
53+
## Case 1: outside of pre-syn time window
54+
input_tols = jnp.ones((1, 1,)) * 9.
55+
out_spike = jnp.ones((1, 1))
56+
57+
## check pre-synaptic STDP only
58+
truth = jnp.array([[-0.6296545]])
59+
ctx.reset()
60+
a.pre_tols.set(input_tols)
61+
a.postSpike.set(out_spike)
62+
ctx.run(t=t, dt=dt)
63+
ctx.adapt(t=t, dt=dt)
64+
#print(a.dWeights.value)
65+
assert_array_equal(a.dWeights.value, truth)
66+
67+
## Case 2: within pre-syn time window
68+
input_tols = jnp.ones((1, 1,)) * 11.
69+
out_spike = jnp.ones((1, 1))
70+
71+
## check pre-synaptic STDP only
72+
truth = jnp.array([[0.37034547]])
73+
ctx.reset()
74+
a.pre_tols.set(input_tols)
75+
a.postSpike.set(out_spike)
76+
ctx.run(t=t, dt=dt)
77+
ctx.adapt(t=t, dt=dt)
78+
#print(a.dWeights.value)
79+
assert_array_equal(a.dWeights.value, truth)
80+
81+
#test_eventSTDPSynapse1()

0 commit comments

Comments
 (0)