Skip to content

Commit 6a4889a

Browse files
author
Alexander Ororbia
committed
refactored exp-stdp syn w/ unit-test
1 parent 9478695 commit 6a4889a

File tree

3 files changed

+116
-50
lines changed

3 files changed

+116
-50
lines changed

ngclearn/components/synapses/hebbian/expSTDPSynapse.py

Lines changed: 37 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,11 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
36
from ngclearn.components.synapses import DenseSynapse
47
from ngclearn.utils import tensorstats
58

6-
def _calc_update(dt, pre, x_pre, post, x_post, W, w_bound=1., x_tar=0.7,
7-
exp_beta=1., Aplus=1., Aminus=0.): ## internal dynamics method
8-
## equations 4 from Diehl and Cook - full exponential weight-dependent STDP
9-
## calculate post-synaptic term
10-
post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
11-
x_tar_vec = x_pre * 0 + x_tar # need to broadcast scalar x_tar to mat/vec form
12-
post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T,
13-
post)
14-
dWpost = (post_term1 - post_term2) * Aplus
15-
## calculate pre-synaptic term
16-
dWpre = 0.
17-
if Aminus > 0.:
18-
dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus
19-
## calc final weighted adjustment
20-
dW = (dWpost + dWpre)
21-
return dW
22-
239
class ExpSTDPSynapse(DenseSynapse):
2410
"""
2511
A synaptic cable that adjusts its efficacies via trace-based form of
@@ -78,9 +64,10 @@ class ExpSTDPSynapse(DenseSynapse):
7864
"""
7965

8066
# Define Functions
81-
def __init__(self, name, shape, A_plus, A_minus, exp_beta, eta=1.,
82-
pretrace_target=0., weight_init=None, resist_scale=1.,
83-
p_conn=1., w_bound=1., batch_size=1, **kwargs):
67+
def __init__(
68+
self, name, shape, A_plus, A_minus, exp_beta, eta=1., pretrace_target=0., weight_init=None, resist_scale=1.,
69+
p_conn=1., w_bound=1., batch_size=1, **kwargs
70+
):
8471
super().__init__(name, shape, weight_init, None, resist_scale,
8572
p_conn, batch_size=batch_size, **kwargs)
8673

@@ -105,16 +92,36 @@ def __init__(self, name, shape, A_plus, A_minus, exp_beta, eta=1.,
10592
self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity
10693

10794
@staticmethod
108-
def _compute_update(dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
109-
preSpike, postSpike, preTrace, postTrace, weights):
110-
dW = _calc_update(dt, preSpike, preTrace, postSpike, postTrace, weights,
111-
w_bound=w_bound, x_tar=preTrace_target, exp_beta=exp_beta,
112-
Aplus=Aplus, Aminus=Aminus)
95+
def _compute_update(
96+
dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
97+
):
98+
pre = preSpike
99+
x_pre = preTrace
100+
post = postSpike
101+
x_post = postTrace
102+
W = weights
103+
x_tar = preTrace_target
104+
## equations 4 from Diehl and Cook - full exponential weight-dependent STDP
105+
## calculate post-synaptic term
106+
post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
107+
x_tar_vec = x_pre * 0 + x_tar # need to broadcast scalar x_tar to mat/vec form
108+
post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T,
109+
post)
110+
dWpost = (post_term1 - post_term2) * Aplus
111+
## calculate pre-synaptic term
112+
dWpre = 0.
113+
if Aminus > 0.:
114+
dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus
115+
## calc final weighted adjustment
116+
dW = (dWpost + dWpre)
113117
return dW
114118

119+
@transition(output_compartments=["weights", "dWeights"])
115120
@staticmethod
116-
def _evolve(dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
117-
preSpike, postSpike, preTrace, postTrace, weights, eta):
121+
def evolve(
122+
dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace,
123+
weights, eta
124+
):
118125
dW = ExpSTDPSynapse._compute_update(
119126
dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
120127
preSpike, postSpike, preTrace, postTrace, weights
@@ -126,13 +133,9 @@ def _evolve(dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
126133
_W = jnp.clip(_W, eps, w_bound - eps)
127134
return weights, dW
128135

129-
@resolver(_evolve)
130-
def evolve(self, weights, dWeights):
131-
self.weights.set(weights)
132-
self.dWeights.set(dWeights)
133-
136+
@transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])
134137
@staticmethod
135-
def _reset(batch_size, shape):
138+
def reset(batch_size, shape):
136139
preVals = jnp.zeros((batch_size, shape[0]))
137140
postVals = jnp.zeros((batch_size, shape[1]))
138141
inputs = preVals
@@ -144,16 +147,6 @@ def _reset(batch_size, shape):
144147
dWeights = jnp.zeros(shape)
145148
return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights
146149

147-
@resolver(_reset)
148-
def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights):
149-
self.inputs.set(inputs)
150-
self.outputs.set(outputs)
151-
self.preSpike.set(preSpike)
152-
self.postSpike.set(postSpike)
153-
self.preTrace.set(preTrace)
154-
self.postTrace.set(postTrace)
155-
self.dWeights.set(dWeights)
156-
157150
@classmethod
158151
def help(cls): ## component help function
159152
properties = {
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import ExpSTDPSynapse
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
14+
def test_expSTDPSynapse1():
15+
name = "exp_stdp_ctx"
16+
## create seeding keys
17+
dkey = random.PRNGKey(1234)
18+
dkey, *subkeys = random.split(dkey, 6)
19+
dt = 1. # ms
20+
# ---- build a simple Poisson cell system ----
21+
with Context(name) as ctx:
22+
a = ExpSTDPSynapse(
23+
name="a", shape=(1,1), A_plus=1., A_minus=1., exp_beta=1.25, key=subkeys[0]
24+
)
25+
26+
#"""
27+
evolve_process = (Process()
28+
>> a.evolve)
29+
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
30+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
31+
32+
advance_process = (Process()
33+
>> a.advance_state)
34+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
35+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
36+
37+
reset_process = (Process()
38+
>> a.reset)
39+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
40+
#"""
41+
42+
"""
43+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
44+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
45+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
46+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
47+
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
48+
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
49+
"""
50+
51+
in_spike = jnp.ones((1, 1))
52+
in_trace = jnp.ones((1, 1,)) * 1.25
53+
out_spike = jnp.ones((1, 1))
54+
out_trace = jnp.ones((1, 1,)) * 0.65
55+
56+
## check pre-synaptic STDP only
57+
truth = jnp.array([[0.57342285]])
58+
ctx.reset()
59+
a.preSpike.set(in_spike * 0)
60+
a.preTrace.set(in_trace)
61+
a.postSpike.set(out_spike)
62+
a.postTrace.set(out_trace)
63+
ctx.run(t=1., dt=dt)
64+
ctx.adapt(t=1., dt=dt)
65+
#print(a.dWeights.value)
66+
assert_array_equal(a.dWeights.value, truth)
67+
68+
truth = jnp.array([[-0.29817986]])
69+
ctx.reset()
70+
a.preSpike.set(in_spike)
71+
a.preTrace.set(in_trace)
72+
a.postSpike.set(out_spike * 0)
73+
a.postTrace.set(out_trace)
74+
ctx.run(t=1., dt=dt)
75+
ctx.adapt(t=1., dt=dt)
76+
#print(a.dWeights.value)
77+
assert_array_equal(a.dWeights.value, truth)
78+
79+
#test_expSTDPSynapse1()

tests/components/synapses/hebbian/test_traceSTDPSynapse.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def test_traceSTDPSynapse1():
1717
dkey = random.PRNGKey(1234)
1818
dkey, *subkeys = random.split(dkey, 6)
1919
dt = 1. # ms
20-
trace_increment = 0.1
2120
# ---- build a simple Poisson cell system ----
2221
with Context(name) as ctx:
2322
a = TraceSTDPSynapse(
@@ -49,11 +48,6 @@ def test_traceSTDPSynapse1():
4948
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
5049
"""
5150

52-
## set up non-compiled utility commands
53-
@Context.dynamicCommand
54-
def clamp(x):
55-
a.j.set(x)
56-
5751
in_spike = jnp.ones((1, 1))
5852
in_trace = jnp.ones((1, 1,)) * 1.25
5953
out_spike = jnp.ones((1, 1))

0 commit comments

Comments
 (0)