@@ -81,6 +81,7 @@ def __init__(
8181 name , shape , A_plus , A_minus , eta = eta , mu = mu , pretrace_target = pretrace_target , weight_init = weight_init ,
8282 resist_scale = resist_scale , p_conn = p_conn , w_bound = w_bound , batch_size = batch_size , ** kwargs
8383 )
84+ self .w_eps = 0.
8485 ## MSTDP/MSTDP-ET meta-parameters
8586 self .tau_elg = tau_elg
8687 self .elg_decay = elg_decay
@@ -91,28 +92,23 @@ def __init__(
9192 @transition (output_compartments = ["weights" , "dWeights" , "eligibility" ])
9293 @staticmethod
9394 def evolve (
94- dt , w_bound , preTrace_target , mu , Aplus , Aminus , tau_elg , elg_decay , preSpike , postSpike , preTrace ,
95- postTrace , weights , eta , modulator , eligibility
96- ):
97- ## compute local synaptic update (via STDP)
98- dW_dt = TraceSTDPSynapse ._compute_update (
99- dt , w_bound , preTrace_target , mu , Aplus , Aminus , preSpike , postSpike , preTrace , postTrace , weights
100- ) ## produce dW/dt (ODE for synaptic change dynamics)
95+ dt , w_bound , w_eps , preTrace_target , mu , Aplus , Aminus , tau_elg , elg_decay , preSpike , postSpike , preTrace ,
96+ postTrace , weights , dWeights , eta , modulator , eligibility
97+ ):
10198 if tau_elg > 0. : ## perform dynamics of M-STDP-ET
102- ## update eligibility trace given current local update
103- # dElg_dt = -eligibility * elg_decay + dW_dt * update_scale
104- # eligibility = eligibility + dElg_dt * dt/elg_tau
105- eligibility = eligibility * jnp .exp (- dt / tau_elg ) * elg_decay + dW_dt
106- else : ## perform dynamics of M-STDP (no eligibility trace)
107- eligibility = dW_dt
108- ## Perform a trace/update times a modulatory signal (e.g., reward)
109- dWeights = eligibility * modulator
110-
99+ eligibility = eligibility * jnp .exp (- dt / tau_elg ) * elg_decay + dWeights / tau_elg
100+ else : ## otherwise, just do M-STDP
101+ eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing
111102 ## do a gradient ascent update/shift
112- weights = weights + dWeights * eta ## modulate update
113- ## enforce non-negativity
114- eps = 0.01
115- weights = jnp .clip (weights , eps , w_bound - eps ) # jnp.abs(w_bound))
103+ weights = weights + eligibility * modulator * eta ## do modulated update
104+ dW_dt = TraceSTDPSynapse ._compute_update ( ## use Hebbian/STDP rule to obtain a non-modulated update
105+ dt , w_bound , preTrace_target , mu , Aplus , Aminus , preSpike , postSpike , preTrace , postTrace , weights
106+ )
107+ dWeights = dW_dt ## can think of this as eligibility at time t
108+
109+ #w_eps = 0. # 0.01
110+ weights = jnp .clip (weights , w_eps , w_bound - w_eps ) # jnp.abs(w_bound))
111+
116112 return weights , dWeights , eligibility
117113
118114 @transition (
0 commit comments