Skip to content

Commit 464ab10

Browse files
author
Alexander Ororbia
committed
refactored mstdp-et syn w/ unit-test
1 parent 17540ea commit 464ab10

File tree

4 files changed

+117
-53
lines changed

4 files changed

+117
-53
lines changed

ngclearn/components/__init__.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from .jaxComponent import JaxComponent
22

3-
43
## point to rate-coded cell component types
54
from .neurons.graded.rateCell import RateCell
65
from .neurons.graded.gaussianErrorCell import GaussianErrorCell
76
from .neurons.graded.laplacianErrorCell import LaplacianErrorCell
87
from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell
98
from .neurons.graded.rewardErrorCell import RewardErrorCell
109

11-
1210
## point to standard spiking cell component types
1311
from .neurons.spiking.sLIFCell import SLIFCell
1412
from .neurons.spiking.IFCell import IFCell
@@ -21,11 +19,9 @@
2119
from .neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
2220
from .neurons.spiking.RAFCell import RAFCell
2321

24-
## point to transformer/operater component types
22+
## point to transformer/operator component types
2523
from .other.varTrace import VarTrace
2624
from .other.expKernel import ExpKernel
27-
from ngclearn.components.synapses.modulated.eligibilityTrace import EligibilityTrace
28-
2925

3026
## point to input encoder component types
3127
from .input_encoders.bernoulliCell import BernoulliCell
@@ -43,7 +39,6 @@
4339
from .synapses.hebbian.BCMSynapse import BCMSynapse
4440
from .synapses.STPDenseSynapse import STPDenseSynapse
4541

46-
4742
## point to convolutional component types
4843
from .synapses.convolution.convSynapse import ConvSynapse
4944
from .synapses.convolution.staticConvSynapse import StaticConvSynapse
@@ -56,7 +51,6 @@
5651
## point to modulated component types
5752
from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse
5853

59-
6054
## point to monitors
6155
from .monitor import Monitor
6256

@@ -65,8 +59,3 @@
6559
from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse
6660
from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse
6761

68-
69-
70-
71-
72-

ngclearn/components/synapses/modulated/MSTDPETSynapse.py

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from jax import random, numpy as jnp, jit
2-
from ngclearn import resolver, Component, Compartment
2+
from ngcsimlib.compilers.process import transition
3+
from ngcsimlib.component import Component
4+
from ngcsimlib.compartment import Compartment
5+
6+
from ngclearn.utils.weight_distribution import initialize_params
7+
from ngcsimlib.logger import info
38
from ngclearn.components.synapses.hebbian import TraceSTDPSynapse
49
from ngclearn.utils import tensorstats
510

611
class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligility traces
712
"""
8-
A synaptic cable that adjusts its efficacies via trace-based form of
9-
three-factor learning, i.e., modulated spike-timing-dependent plasticity
10-
(M-STDP) or modulated STDP with eligibility traces (M-STDP-ET).
13+
A synaptic cable that adjusts its efficacies via trace-based form of three-factor learning, i.e., modulated
14+
spike-timing-dependent plasticity (M-STDP) or modulated STDP with eligibility traces (M-STDP-ET).
1115
1216
| --- Synapse Compartments: ---
1317
| inputs - input (takes in external signals)
@@ -20,11 +24,14 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
2024
| postSpike - post-synaptic spike to drive 2nd term of STDP update (takes in external signals)
2125
| preTrace - pre-synaptic trace value to drive 1st term of STDP update (takes in external signals)
2226
| postTrace - post-synaptic trace value to drive 2nd term of STDP update (takes in external signals)
23-
| dWeights - current delta matrix containing changes to be applied to synaptic efficacies
27+
| dWeights - current delta matrix containing (MS-STDP/MS-STDP-ET) changes to be applied to synaptic efficacies
2428
| eligibility - current state of eligibility trace
25-
| eta - global learning rate (multiplier beyond A_plus and A_minus)
29+
| eta - global learning rate (applied to change in weights for final MS-STDP/MS-STDP-ET adjustment)
2630
2731
| References:
32+
| Florian, Răzvan V. "Reinforcement learning through modulation of spike-timing-dependent synaptic plasticity."
33+
| Neural computation 19.6 (2007): 1468-1502.
34+
|
2835
| Morrison, Abigail, Ad Aertsen, and Markus Diesmann. "Spike-timing-dependent
2936
| plasticity in balanced random networks." Neural computation 19.6 (2007): 1437-1467.
3037
|
@@ -66,29 +73,30 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
6673
"""
6774

6875
# Define Functions
69-
def __init__(self, name, shape, A_plus, A_minus, eta=1., mu=0.,
70-
pretrace_target=0., tau_elg=0., elg_decay=1.,
71-
weight_init=None, resist_scale=1., p_conn=1., w_bound=1.,
72-
batch_size=1, **kwargs):
73-
super().__init__(name, shape, A_plus, A_minus, eta=eta, mu=mu,
74-
pretrace_target=pretrace_target, weight_init=weight_init,
75-
resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound,
76-
batch_size=batch_size, **kwargs)
76+
def __init__(
77+
self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1.,
78+
weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs
79+
):
80+
super().__init__(
81+
name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init,
82+
resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs
83+
)
7784
## MSTDP/MSTDP-ET meta-parameters
7885
self.tau_elg = tau_elg
7986
self.elg_decay = elg_decay
8087
## MSTDP/MSTDP-ET compartments
8188
self.modulator = Compartment(jnp.zeros((self.batch_size, 1)))
8289
self.eligibility = Compartment(jnp.zeros(shape))
8390

91+
@transition(output_compartments=["weights", "dWeights", "eligibility"])
8492
@staticmethod
85-
def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus, tau_elg,
86-
elg_decay, preSpike, postSpike, preTrace, postTrace, weights,
87-
eta, modulator, eligibility):
93+
def evolve(
94+
dt, w_bound, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, preSpike, postSpike, preTrace,
95+
postTrace, weights, eta, modulator, eligibility
96+
):
8897
## compute local synaptic update (via STDP)
8998
dW_dt = TraceSTDPSynapse._compute_update(
90-
dt, w_bound, preTrace_target, mu, Aplus, Aminus,
91-
preSpike, postSpike, preTrace, postTrace, weights
99+
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
92100
) ## produce dW/dt (ODE for synaptic change dynamics)
93101
if tau_elg > 0.: ## perform dynamics of M-STDP-ET
94102
## update eligibility trace given current local update
@@ -107,14 +115,11 @@ def _evolve(dt, w_bound, preTrace_target, mu, Aplus, Aminus, tau_elg,
107115
weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
108116
return weights, dWeights, eligibility
109117

110-
@resolver(_evolve)
111-
def evolve(self, weights, dWeights, eligibility):
112-
self.weights.set(weights)
113-
self.dWeights.set(dWeights)
114-
self.eligibility.set(eligibility)
115-
118+
@transition(
119+
output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility"]
120+
)
116121
@staticmethod
117-
def _reset(batch_size, shape):
122+
def reset(batch_size, shape):
118123
preVals = jnp.zeros((batch_size, shape[0]))
119124
postVals = jnp.zeros((batch_size, shape[1]))
120125
synVals = jnp.zeros(shape)
@@ -126,20 +131,7 @@ def _reset(batch_size, shape):
126131
postTrace = postVals
127132
dWeights = synVals
128133
eligibility = synVals
129-
return (inputs, outputs, preSpike, postSpike, preTrace, postTrace,
130-
dWeights, eligibility)
131-
132-
@resolver(_reset)
133-
def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace,
134-
dWeights, eligibility):
135-
self.inputs.set(inputs)
136-
self.outputs.set(outputs)
137-
self.preSpike.set(preSpike)
138-
self.postSpike.set(postSpike)
139-
self.preTrace.set(preTrace)
140-
self.postTrace.set(postTrace)
141-
self.dWeights.set(dWeights)
142-
self.eligibility.set(eligibility)
134+
return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility
143135

144136
@classmethod
145137
def help(cls): ## component help function
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .MSTDPETSynapse import MSTDPETSynapse
1+
from .MSTDPETSynapse import MSTDPETSynapse
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from jax import numpy as jnp, random, jit
2+
from ngcsimlib.context import Context
3+
import numpy as np
4+
np.random.seed(42)
5+
from ngclearn.components import MSTDPETSynapse
6+
from ngcsimlib.compilers import compile_command, wrap_command
7+
from numpy.testing import assert_array_equal
8+
9+
from ngcsimlib.compilers.process import Process, transition
10+
from ngcsimlib.component import Component
11+
from ngcsimlib.compartment import Compartment
12+
from ngcsimlib.context import Context
13+
import ngclearn.utils.weight_distribution as dist
14+
15+
def test_MSTDPETSynapse1():
16+
name = "mstdpet_ctx"
17+
## create seeding keys
18+
dkey = random.PRNGKey(1234)
19+
dkey, *subkeys = random.split(dkey, 6)
20+
dt = 1. # ms
21+
# ---- build a simple Poisson cell system ----
22+
with Context(name) as ctx:
23+
a = MSTDPETSynapse(
24+
name="a", shape=(1,1), A_plus=1., A_minus=1., eta=0., key=subkeys[0]
25+
)
26+
27+
#"""
28+
advance_process = (Process()
29+
>> a.advance_state)
30+
# ctx.wrap_and_add_command(advance_process.pure, name="run")
31+
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
32+
33+
evolve_process = (Process()
34+
>> a.evolve)
35+
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
36+
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
37+
38+
reset_process = (Process()
39+
>> a.reset)
40+
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
41+
#"""
42+
43+
"""
44+
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
45+
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
46+
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
47+
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
48+
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
49+
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
50+
"""
51+
52+
a.weights.set(jnp.ones((1, 1)))
53+
54+
in_spike = jnp.ones((1, 1))
55+
in_trace = jnp.ones((1, 1,)) * 1.25
56+
out_spike = jnp.ones((1, 1))
57+
out_trace = jnp.ones((1, 1,)) * 0.65
58+
r_neg = -jnp.ones((1, 1))
59+
r_pos = jnp.ones((1, 1))
60+
61+
ctx.reset()
62+
a.preSpike.set(in_spike * 0)
63+
a.preTrace.set(in_trace)
64+
a.postSpike.set(out_spike)
65+
a.postTrace.set(out_trace)
66+
a.modulator.set(r_pos)
67+
ctx.run(t=1. * dt, dt=dt)
68+
ctx.adapt(t=1. * dt, dt=dt)
69+
#print(a.dWeights.value)
70+
assert_array_equal(a.dWeights.value, jnp.array([[1.25]]))
71+
72+
ctx.reset()
73+
a.preSpike.set(in_spike * 0)
74+
a.preTrace.set(in_trace)
75+
a.postSpike.set(out_spike)
76+
a.postTrace.set(out_trace)
77+
a.modulator.set(r_neg)
78+
ctx.run(t=1. * dt, dt=dt)
79+
ctx.adapt(t=1. * dt, dt=dt)
80+
#print(a.dWeights.value)
81+
assert_array_equal(a.dWeights.value, jnp.array([[-1.25]]))
82+
83+
#test_MSTDPETSynapse1()

0 commit comments

Comments
 (0)