Skip to content

Commit 94477b8

Browse files
author
Alexander Ororbia
committed
refactored mstdpet-syn and test passed
1 parent ebdea3e commit 94477b8

File tree

2 files changed

+59
-105
lines changed

2 files changed

+59
-105
lines changed

ngclearn/components/synapses/modulated/MSTDPETSynapse.py

Lines changed: 39 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
from jax import random, numpy as jnp, jit
2-
from ngcsimlib.compilers.process import transition
3-
from ngcsimlib.component import Component
42
from ngcsimlib.compartment import Compartment
5-
3+
from ngcsimlib.parser import compilable
64
from ngclearn.utils.weight_distribution import initialize_params
7-
from ngcsimlib.logger import info
5+
86
from ngclearn.components.synapses.hebbian import TraceSTDPSynapse
9-
from ngclearn.utils import tensorstats
107

118
class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligility traces
129
"""
@@ -72,78 +69,69 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
7269
7370
p_conn: probability of a connection existing (default: 1.); setting
7471
this to < 1. will result in a sparser synaptic structure
72+
73+
w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1)
7574
"""
7675

77-
# Define Functions
7876
def __init__(
7977
self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., tau_w=0.,
8078
weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs
8179
):
82-
super().__init__(
80+
super().__init__( # call to parent trace-stdp component
8381
name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init,
8482
resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs
8583
)
8684
self.w_eps = 0.
8785
self.tau_w = tau_w
8886
## MSTDP/MSTDP-ET meta-parameters
89-
self.tau_elg = tau_elg
90-
self.elg_decay = elg_decay
87+
self.tau_elg = tau_elg ## time constant for eligibility trace
88+
self.elg_decay = elg_decay ## decay factor eligibility trace
9189
## MSTDP/MSTDP-ET compartments
9290
self.modulator = Compartment(jnp.zeros((self.batch_size, 1)))
9391
self.eligibility = Compartment(jnp.zeros(shape))
9492
self.outmask = Compartment(jnp.zeros((1, shape[1])))
9593

96-
@transition(output_compartments=["weights", "dWeights", "eligibility"])
97-
@staticmethod
98-
def evolve(
99-
dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, tau_w, preSpike, postSpike,
100-
preTrace, postTrace, weights, dWeights, eta, modulator, eligibility, outmask
101-
):
102-
# dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
103-
# dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
104-
# )
94+
@compilable
95+
def evolve(self, dt, t):
96+
# dW_dt = self._compute_update()
10597
# dWeights = dW_dt ## can think of this as eligibility at time t
10698

107-
if tau_elg > 0.: ## perform dynamics of M-STDP-ET
108-
eligibility = eligibility * jnp.exp(-dt / tau_elg) * elg_decay + dWeights/tau_elg
99+
if self.tau_elg > 0.: ## perform dynamics of M-STDP-ET
100+
eligibility = self.eligibility.get() * jnp.exp(-dt / self.tau_elg) * self.elg_decay + self.dWeights.get()/self.tau_elg
109101
else: ## otherwise, just do M-STDP
110-
eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing
102+
eligibility = self.dWeights.get() ## dynamics of M-STDP had no eligibility tracing
111103
## do a gradient ascent update/shift
112104
decayTerm = 0.
113-
if tau_w > 0.:
114-
decayTerm = weights * (1. / tau_w)
115-
weights = weights + (eligibility * modulator * eta) * outmask - decayTerm ## do modulated update
105+
if self.tau_w > 0.:
106+
decayTerm = self.weights.get() * (1. / self.tau_w)
107+
## do modulated update
108+
weights = self.weights.get() + (eligibility * self.modulator.get() * self.eta) * self.outmask.get() - decayTerm
116109

117-
dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
118-
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
119-
)
110+
dW_dt = self._compute_update() ## apply a Hebbian/STDP rule to obtain a non-modulated update
120111
dWeights = dW_dt ## can think of this as eligibility at time t
121112

122113
#w_eps = 0.01
123-
weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound))
124-
125-
return weights, dWeights, eligibility
126-
127-
@transition(
128-
output_compartments=[
129-
"inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility", "outmask"
130-
]
131-
)
132-
@staticmethod
133-
def reset(batch_size, shape):
134-
preVals = jnp.zeros((batch_size, shape[0]))
135-
postVals = jnp.zeros((batch_size, shape[1]))
136-
synVals = jnp.zeros(shape)
137-
inputs = preVals
138-
outputs = postVals
139-
preSpike = preVals
140-
postSpike = postVals
141-
preTrace = preVals
142-
postTrace = postVals
143-
dWeights = synVals
144-
eligibility = synVals
145-
outmask = postVals + 1.
146-
return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility, outmask
114+
weights = jnp.clip(weights, self.w_eps, self.w_bound - self.w_eps) # jnp.abs(w_bound))
115+
self.weights.set(weights)
116+
self.dWeights.set(dWeights)
117+
self.eligibility.set(eligibility)
118+
119+
@compilable
120+
def reset(self):
121+
preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
122+
postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
123+
synVals = jnp.zeros(self.shape.get())
124+
125+
if not self.inputs.targeted:
126+
self.inputs.set(preVals)
127+
self.outputs.set(postVals)
128+
self.preSpike.set(preVals)
129+
self.postSpike.set(postVals)
130+
self.preTrace.set(preVals)
131+
self.postTrace.set(postVals)
132+
self.dWeights.set(synVals)
133+
self.eligibility.set(synVals)
134+
self.outmask.set(postVals + 1.)
147135

148136
@classmethod
149137
def help(cls): ## component help function
@@ -195,17 +183,3 @@ def help(cls): ## component help function
195183
"dW^{stdp}_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i",
196184
"hyperparameters": hyperparams}
197185
return info
198-
199-
def __repr__(self):
200-
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
201-
maxlen = max(len(c) for c in comps) + 5
202-
lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
203-
for c in comps:
204-
stats = tensorstats(getattr(self, c).value)
205-
if stats is not None:
206-
line = [f"{k}: {v}" for k, v in stats.items()]
207-
line = ", ".join(line)
208-
else:
209-
line = "None"
210-
lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
211-
return lines

tests/components/synapses/modulated/test_MSTDPETSynapse.py

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,11 @@
22
from ngcsimlib.context import Context
33
import numpy as np
44
np.random.seed(42)
5-
from ngclearn.components import MSTDPETSynapse
6-
from ngcsimlib.compilers import compile_command, wrap_command
7-
from numpy.testing import assert_array_equal
85

9-
from ngcsimlib.compilers.process import Process, transition
10-
from ngcsimlib.component import Component
11-
from ngcsimlib.compartment import Compartment
12-
from ngcsimlib.context import Context
6+
from ngclearn import Context, MethodProcess
137
import ngclearn.utils.weight_distribution as dist
8+
from ngclearn.components.synapses.modulated.MSTDPETSynapse import MSTDPETSynapse
9+
from numpy.testing import assert_array_equal
1410

1511
def test_MSTDPETSynapse1():
1612
name = "mstdpet_ctx"
@@ -24,30 +20,14 @@ def test_MSTDPETSynapse1():
2420
name="a", shape=(1,1), A_plus=1., A_minus=1., eta=0.1, key=subkeys[0]
2521
)
2622

27-
#"""
28-
advance_process = (Process("advance_proc")
23+
evolve_process = (MethodProcess("evolve_process")
24+
>> a.evolve)
25+
26+
advance_process = (MethodProcess("advance_proc")
2927
>> a.advance_state)
30-
# ctx.wrap_and_add_command(advance_process.pure, name="run")
31-
ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
3228

33-
evolve_process = (Process("evolve_proc")
34-
>> a.evolve)
35-
#ctx.wrap_and_add_command(evolve_process.pure, name="run")
36-
ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
37-
38-
reset_process = (Process("reset_proc")
29+
reset_process = (MethodProcess("reset_proc")
3930
>> a.reset)
40-
ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
41-
#"""
42-
43-
"""
44-
reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
45-
ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
46-
advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
47-
ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
48-
evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
49-
ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
50-
"""
5131

5232
a.weights.set(jnp.ones((1, 1)) * 0.75)
5333

@@ -59,28 +39,28 @@ def test_MSTDPETSynapse1():
5939
r_pos = jnp.ones((1, 1))
6040

6141
#print(a.weights.value)
62-
ctx.reset()
42+
reset_process.run() # ctx.reset()
6343
a.preSpike.set(in_spike * 0)
6444
a.preTrace.set(in_trace)
6545
a.postSpike.set(out_spike)
6646
a.postTrace.set(out_trace)
6747
a.modulator.set(r_pos)
68-
ctx.run(t=1. * dt, dt=dt)
69-
ctx.adapt(t=1. * dt, dt=dt)
70-
ctx.adapt(t=1. * dt, dt=dt)
71-
#print(a.weights.value)
72-
assert_array_equal(a.weights.value, jnp.array([[0.875]]))
48+
advance_process.run(t=1., dt=dt) # ctx.run(t=1. * dt, dt=dt)
49+
evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
50+
evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
51+
#print(a.weights.get())
52+
assert_array_equal(a.weights.get(), jnp.array([[0.875]]))
7353

74-
ctx.reset()
54+
reset_process.run() # ctx.reset()
7555
a.preSpike.set(in_spike * 0)
7656
a.preTrace.set(in_trace)
7757
a.postSpike.set(out_spike)
7858
a.postTrace.set(out_trace)
7959
a.modulator.set(r_neg)
80-
ctx.run(t=1. * dt, dt=dt)
81-
ctx.adapt(t=1. * dt, dt=dt)
82-
ctx.adapt(t=1. * dt, dt=dt)
83-
#print(a.weights.value)
84-
assert_array_equal(a.weights.value, jnp.array([[0.75]]))
60+
advance_process.run(t=1., dt=dt) # ctx.run(t=1. * dt, dt=dt)
61+
evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
62+
evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
63+
#print(a.weights.get())
64+
assert_array_equal(a.weights.get(), jnp.array([[0.75]]))
8565

8666
#test_MSTDPETSynapse1()

0 commit comments

Comments
 (0)