Skip to content

Commit 49bccb5

Browse files
author
Alexander Ororbia
committed
update to mstdp-et and var-trace
1 parent 061e713 commit 49bccb5

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

ngclearn/components/other/varTrace.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class VarTrace(JaxComponent): ## low-pass filter
5959
a_delta: value to increment a trace by in presence of a spike; note if set
6060
to a value <= 0, then a piecewise gated trace will be used instead
6161
62+
P_scale: if `a_delta=0`, then this scales the value that the trace snaps to upon receiving a pulse value
63+
6264
gamma_tr: an extra multiplier in front of the leak of the trace (Default: 1)
6365
6466
decay_type: string indicating the decay type to be applied to ODE
@@ -73,13 +75,14 @@ class VarTrace(JaxComponent): ## low-pass filter
7375
"""
7476

7577
# Define Functions
76-
def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
78+
def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay_type="exp",
7779
batch_size=1, **kwargs):
7880
super().__init__(name, **kwargs)
7981

8082
## Trace control coefficients
8183
self.tau_tr = tau_tr ## trace time constant
8284
self.a_delta = a_delta ## trace increment (if spike occurred)
85+
self.P_scale = P_scale ## trace scale if non-additive trace to be used
8386
self.gamma_tr = gamma_tr
8487
self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
8588

@@ -94,7 +97,7 @@ def __init__(self, name, n_units, tau_tr, a_delta, gamma_tr=1, decay_type="exp",
9497

9598
@transition(output_compartments=["outputs", "trace"])
9699
@staticmethod
97-
def advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace):
100+
def advance_state(dt, decay_type, tau_tr, a_delta, P_scale, gamma_tr, inputs, trace):
98101
decayFactor = 0.
99102
if "exp" in decay_type:
100103
decayFactor = jnp.exp(-dt/tau_tr)
@@ -104,7 +107,7 @@ def advance_state(dt, decay_type, tau_tr, a_delta, gamma_tr, inputs, trace):
104107
if a_delta > 0.:
105108
_x_tr = _x_tr + inputs * a_delta
106109
else:
107-
_x_tr = _x_tr * (1. - inputs) + inputs
110+
_x_tr = _x_tr * (1. - inputs) + inputs * P_scale
108111
trace = _x_tr
109112
return trace, trace
110113

ngclearn/components/synapses/modulated/MSTDPETSynapse.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
6161
6262
elg_decay: eligibility decay constant (default: 1)
6363
64+
tau_w: amount of synaptic decay to augment each MSTDP/MSTDP-ET update with
65+
6466
weight_init: a kernel to drive initialization of this synaptic cable's values;
6567
typically a tuple with 1st element as a string calling the name of
6668
initialization to use
@@ -74,26 +76,28 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
7476

7577
# Define Functions
7678
def __init__(
77-
self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1.,
79+
self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., tau_w=0.,
7880
weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs
7981
):
8082
super().__init__(
8183
name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init,
8284
resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs
8385
)
8486
self.w_eps = 0.
87+
self.tau_w = tau_w
8588
## MSTDP/MSTDP-ET meta-parameters
8689
self.tau_elg = tau_elg
8790
self.elg_decay = elg_decay
8891
## MSTDP/MSTDP-ET compartments
8992
self.modulator = Compartment(jnp.zeros((self.batch_size, 1)))
9093
self.eligibility = Compartment(jnp.zeros(shape))
94+
self.outmask = Compartment(jnp.zeros((1, shape[1])))
9195

9296
@transition(output_compartments=["weights", "dWeights", "eligibility"])
9397
@staticmethod
9498
def evolve(
95-
dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, preSpike, postSpike, preTrace,
96-
postTrace, weights, dWeights, eta, modulator, eligibility
99+
dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, tau_w, preSpike, postSpike,
100+
preTrace, postTrace, weights, dWeights, eta, modulator, eligibility, outmask
97101
):
98102
# dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
99103
# dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
@@ -105,21 +109,25 @@ def evolve(
105109
else: ## otherwise, just do M-STDP
106110
eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing
107111
## do a gradient ascent update/shift
108-
weights = weights + eligibility * modulator * eta ## do modulated update
109-
#'''
112+
decayTerm = 0.
113+
if tau_w > 0.:
114+
decayTerm = weights * (1. / tau_w)
115+
weights = weights + (eligibility * modulator * eta) * outmask - decayTerm ## do modulated update
116+
110117
dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
111118
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
112119
)
113120
dWeights = dW_dt ## can think of this as eligibility at time t
114-
#'''
115-
121+
116122
#w_eps = 0.01
117123
weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound))
118124

119125
return weights, dWeights, eligibility
120126

121127
@transition(
122-
output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility"]
128+
output_compartments=[
129+
"inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility", "outmask"
130+
]
123131
)
124132
@staticmethod
125133
def reset(batch_size, shape):
@@ -134,7 +142,8 @@ def reset(batch_size, shape):
134142
postTrace = postVals
135143
dWeights = synVals
136144
eligibility = synVals
137-
return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility
145+
outmask = postVals + 1.
146+
return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility, outmask
138147

139148
@classmethod
140149
def help(cls): ## component help function

0 commit comments

Comments
 (0)