Skip to content

Commit 9478695

Browse files
author
Alexander Ororbia
committed
refactored dense and trace-stdp syn w/ unit-test
1 parent 6119846 commit 9478695

File tree

3 files changed

+135
-62
lines changed

3 files changed

+135
-62
lines changed

ngclearn/components/synapses/denseSynapse.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
32
from ngclearn.components.jaxComponent import JaxComponent
43
from ngclearn.utils import tensorstats
54
from ngclearn.utils.weight_distribution import initialize_params
65
from ngcsimlib.logger import info
76

7+
from ngcsimlib.compilers.process import transition
8+
from ngcsimlib.component import Component
9+
from ngcsimlib.compartment import Compartment
10+
811
class DenseSynapse(JaxComponent): ## base dense synaptic cable
912
"""
1013
A dense synaptic cable; no form of synaptic evolution/adaptation
@@ -13,7 +16,7 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
1316
| --- Synapse Compartments: ---
1417
| inputs - input (takes in external signals)
1518
| outputs - output signals
16-
| weights - current value matrix of synaptic efficacies
19+
| weights - current value matrix of synaptic efficacies (strength values)
1720
| biases - current value vector of synaptic bias values
1821
1922
Args:
@@ -75,28 +78,21 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
7578
(1, shape[1]))
7679
if bias_init else 0.0)
7780

81+
@transition(output_compartments=["outputs"])
7882
@staticmethod
79-
def _advance_state(Rscale, inputs, weights, biases):
83+
def advance_state(Rscale, inputs, weights, biases):
8084
outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
8185
return outputs
8286

83-
@resolver(_advance_state)
84-
def advance_state(self, outputs):
85-
self.outputs.set(outputs)
86-
87+
@transition(output_compartments=["inputs", "outputs"])
8788
@staticmethod
88-
def _reset(batch_size, shape):
89+
def reset(batch_size, shape):
8990
preVals = jnp.zeros((batch_size, shape[0]))
9091
postVals = jnp.zeros((batch_size, shape[1]))
9192
inputs = preVals
9293
outputs = postVals
9394
return inputs, outputs
9495

95-
@resolver(_reset)
96-
def reset(self, inputs, outputs):
97-
self.inputs.set(inputs)
98-
self.outputs.set(outputs)
99-
10096
def save(self, directory, **kwargs):
10197
file_name = directory + "/" + self.name + ".npz"
10298
if self.bias_init != None:

ngclearn/components/synapses/hebbian/traceSTDPSynapse.py

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,11 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
36
from ngclearn.components.synapses import DenseSynapse
47
from ngclearn.utils import tensorstats
58

6-
def _calc_update(dt, pre, x_pre, post, x_post, W, w_bound=1., x_tar=0.0, mu=0.,
7-
Aplus=1., Aminus=0.):
8-
if mu > 0.:
9-
## equations 3, 5, & 6 from Diehl and Cook - full power-law STDP
10-
post_shift = jnp.power(w_bound - W, mu)
11-
pre_shift = jnp.power(W, mu)
12-
dWpost = (post_shift * jnp.matmul((x_pre - x_tar).T, post)) * Aplus
13-
dWpre = 0.
14-
if Aminus > 0.:
15-
dWpre = -(pre_shift * jnp.matmul(pre.T, x_post)) * Aminus
16-
else:
17-
## calculate post-synaptic term
18-
dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus)
19-
dWpre = 0.
20-
if Aminus > 0.:
21-
## calculate pre-synaptic term
22-
dWpre = -jnp.matmul(pre.T, x_post * Aminus)
23-
## calc final weighted adjustment
24-
dW = (dWpost + dWpre)
25-
return dW
269

2710
class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP
2811
"""
@@ -83,9 +66,10 @@ class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP
8366
"""
8467

8568
# Define Functions
86-
def __init__(self, name, shape, A_plus, A_minus, eta=1., mu=0.,
87-
pretrace_target=0., weight_init=None, resist_scale=1.,
88-
p_conn=1., w_bound=1., batch_size=1, **kwargs):
69+
def __init__(
70+
self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., weight_init=None, resist_scale=1.,
71+
p_conn=1., w_bound=1., batch_size=1, **kwargs
72+
):
8973
super().__init__(name, shape, weight_init, None, resist_scale,
9074
p_conn, batch_size=batch_size, **kwargs)
9175

@@ -109,19 +93,41 @@ def __init__(self, name, shape, A_plus, A_minus, eta=1., mu=0.,
10993
self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate
11094

11195
@staticmethod
112-
def _compute_update(dt, w_bound, preTrace_target, mu, Aplus, Aminus,
113-
preSpike, postSpike, preTrace, postTrace, weights):
114-
dW = _calc_update(dt, preSpike, preTrace, postSpike, postTrace, weights,
115-
w_bound=w_bound, x_tar=preTrace_target, mu=mu,
116-
Aplus=Aplus, Aminus=Aminus)
96+
def _compute_update(
97+
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
98+
):
99+
pre = preSpike
100+
x_pre = preTrace
101+
post = postSpike
102+
x_post = postTrace
103+
W = weights
104+
x_tar = preTrace_target
105+
if mu > 0.:
106+
## equations 3, 5, & 6 from Diehl and Cook - full power-law STDP
107+
post_shift = jnp.power(w_bound - W, mu)
108+
pre_shift = jnp.power(W, mu)
109+
dWpost = (post_shift * jnp.matmul((x_pre - x_tar).T, post)) * Aplus
110+
dWpre = 0.
111+
if Aminus > 0.:
112+
dWpre = -(pre_shift * jnp.matmul(pre.T, x_post)) * Aminus
113+
else:
114+
## calculate post-synaptic term
115+
dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus)
116+
dWpre = 0.
117+
if Aminus > 0.:
118+
## calculate pre-synaptic term
119+
dWpre = -jnp.matmul(pre.T, x_post * Aminus)
120+
## calc final weighted adjustment
121+
dW = (dWpost + dWpre)
117122
return dW
118123

124+
@transition(output_compartments=["weights", "dWeights"])
119125
@staticmethod
120-
def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus,
121-
preSpike, postSpike, preTrace, postTrace, weights, eta):
126+
def evolve(
127+
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights, eta
128+
):
122129
dWeights = TraceSTDPSynapse._compute_update(
123-
dt, w_bound, preTrace_target, mu, Aplus, Aminus,
124-
preSpike, postSpike, preTrace, postTrace, weights
130+
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
125131
)
126132
## do a gradient ascent update/shift
127133
weights = weights + dWeights * eta
@@ -130,13 +136,9 @@ def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus,
130136
weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
131137
return weights, dWeights
132138

133-
@resolver(_evolve)
134-
def evolve(self, weights, dWeights):
135-
self.weights.set(weights)
136-
self.dWeights.set(dWeights)
137-
139+
@transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])
138140
@staticmethod
139-
def _reset(batch_size, shape):
141+
def reset(batch_size, shape):
140142
preVals = jnp.zeros((batch_size, shape[0]))
141143
postVals = jnp.zeros((batch_size, shape[1]))
142144
inputs = preVals
@@ -148,16 +150,6 @@ def _reset(batch_size, shape):
148150
dWeights = jnp.zeros(shape)
149151
return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights
150152

151-
@resolver(_reset)
152-
def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights):
153-
self.inputs.set(inputs)
154-
self.outputs.set(outputs)
155-
self.preSpike.set(preSpike)
156-
self.postSpike.set(postSpike)
157-
self.preTrace.set(preTrace)
158-
self.postTrace.set(postTrace)
159-
self.dWeights.set(dWeights)
160-
161153
@classmethod
162154
def help(cls): ## component help function
163155
properties = {
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import TraceSTDPSynapse
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
14+
def test_traceSTDPSynapse1():
15+
name = "trace_stdp_ctx"
16+
## create seeding keys
17+
dkey = random.PRNGKey(1234)
18+
dkey, *subkeys = random.split(dkey, 6)
19+
dt = 1. # ms
20+
trace_increment = 0.1
21+
# ---- build a simple Poisson cell system ----
22+
with Context(name) as ctx:
23+
a = TraceSTDPSynapse(
24+
name="a", shape=(1,1), A_plus=1., A_minus=1., key=subkeys[0]
25+
)
26+
27+
#"""
28+
evolve_process = (Process()
29+
>> a.evolve)
30+
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
31+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
32+
33+
advance_process = (Process()
34+
>> a.advance_state)
35+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
36+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
37+
38+
reset_process = (Process()
39+
>> a.reset)
40+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
41+
#"""
42+
43+
"""
44+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
45+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
46+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
47+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
48+
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
49+
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
50+
"""
51+
52+
## set up non-compiled utility commands
53+
@Context.dynamicCommand
54+
def clamp(x):
55+
a.j.set(x)
56+
57+
in_spike = jnp.ones((1, 1))
58+
in_trace = jnp.ones((1, 1,)) * 1.25
59+
out_spike = jnp.ones((1, 1))
60+
out_trace = jnp.ones((1, 1,)) * 0.65
61+
62+
## check pre-synaptic STDP only
63+
truth = jnp.array([[1.25]])
64+
ctx.reset()
65+
a.preSpike.set(in_spike * 0)
66+
a.preTrace.set(in_trace)
67+
a.postSpike.set(out_spike)
68+
a.postTrace.set(out_trace)
69+
ctx.run(t=1., dt=dt)
70+
ctx.adapt(t=1., dt=dt)
71+
#print(a.dWeights.value)
72+
assert_array_equal(a.dWeights.value, truth)
73+
74+
truth = jnp.array([[-0.65]])
75+
ctx.reset()
76+
a.preSpike.set(in_spike)
77+
a.preTrace.set(in_trace)
78+
a.postSpike.set(out_spike * 0)
79+
a.postTrace.set(out_trace)
80+
ctx.run(t=1., dt=dt)
81+
ctx.adapt(t=1., dt=dt)
82+
#print(a.dWeights.value)
83+
assert_array_equal(a.dWeights.value, truth)
84+
85+
#test_traceSTDPSynapse1()

0 commit comments

Comments
 (0)