Skip to content

Commit c5a13d4

Browse files
author
Alexander Ororbia
committed
tweaked trace-stdp and mstdpet
1 parent b154e4e commit c5a13d4

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

ngclearn/components/synapses/hebbian/traceSTDPSynapse.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
self.Aminus = A_minus ## LTD strength
8282
self.Rscale = resist_scale ## post-transformation scale factor
8383
self.w_bound = w_bound #1. ## soft weight constraint
84+
self.w_eps = 0. ## w_eps = 0.01
8485

8586
## Compartment setup
8687
preVals = jnp.zeros((self.batch_size, shape[0]))
@@ -113,6 +114,7 @@ def _compute_update(
113114
else:
114115
## calculate post-synaptic term
115116
dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus)
117+
116118
dWpre = 0.
117119
if Aminus > 0.:
118120
## calculate pre-synaptic term
@@ -124,16 +126,16 @@ def _compute_update(
124126
@transition(output_compartments=["weights", "dWeights"])
125127
@staticmethod
126128
def evolve(
127-
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights, eta
129+
dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights, eta
128130
):
129131
dWeights = TraceSTDPSynapse._compute_update(
130132
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
131133
)
132134
## do a gradient ascent update/shift
133135
weights = weights + dWeights * eta
134136
## enforce non-negativity
135-
eps = 0.01 # 0.001
136-
weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
137+
#w_eps = 0. # 0.01 # 0.001
138+
weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound))
137139
return weights, dWeights
138140

139141
@transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])

ngclearn/components/synapses/modulated/MSTDPETSynapse.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init,
8282
resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs
8383
)
84+
self.w_eps = 0.
8485
## MSTDP/MSTDP-ET meta-parameters
8586
self.tau_elg = tau_elg
8687
self.elg_decay = elg_decay
@@ -91,28 +92,23 @@ def __init__(
9192
@transition(output_compartments=["weights", "dWeights", "eligibility"])
9293
@staticmethod
9394
def evolve(
94-
dt, w_bound, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, preSpike, postSpike, preTrace,
95-
postTrace, weights, eta, modulator, eligibility
96-
):
97-
## compute local synaptic update (via STDP)
98-
dW_dt = TraceSTDPSynapse._compute_update(
99-
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
100-
) ## produce dW/dt (ODE for synaptic change dynamics)
95+
dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, preSpike, postSpike, preTrace,
96+
postTrace, weights, dWeights, eta, modulator, eligibility
97+
):
10198
if tau_elg > 0.: ## perform dynamics of M-STDP-ET
102-
## update eligibility trace given current local update
103-
# dElg_dt = -eligibility * elg_decay + dW_dt * update_scale
104-
# eligibility = eligibility + dElg_dt * dt/elg_tau
105-
eligibility = eligibility * jnp.exp(-dt / tau_elg) * elg_decay + dW_dt
106-
else: ## perform dynamics of M-STDP (no eligibility trace)
107-
eligibility = dW_dt
108-
## Perform a trace/update times a modulatory signal (e.g., reward)
109-
dWeights = eligibility * modulator
110-
99+
eligibility = eligibility * jnp.exp(-dt / tau_elg) * elg_decay + dWeights/tau_elg
100+
else: ## otherwise, just do M-STDP
101+
eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing
111102
## do a gradient ascent update/shift
112-
weights = weights + dWeights * eta ## modulate update
113-
## enforce non-negativity
114-
eps = 0.01
115-
weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
103+
weights = weights + eligibility * modulator * eta ## do modulated update
104+
dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
105+
dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
106+
)
107+
dWeights = dW_dt ## can think of this as eligibility at time t
108+
109+
#w_eps = 0. # 0.01
110+
weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound))
111+
116112
return weights, dWeights, eligibility
117113

118114
@transition(

0 commit comments

Comments
 (0)