@@ -94,19 +94,28 @@ def __init__(
9494 def evolve (
9595 dt , w_bound , w_eps , preTrace_target , mu , Aplus , Aminus , tau_elg , elg_decay , preSpike , postSpike , preTrace ,
9696 postTrace , weights , dWeights , eta , modulator , eligibility
97- ):
97+ ):
98+ '''
99+ dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
100+ dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
101+ )
102+ dWeights = dW_dt ## can think of this as eligibility at time t
103+ '''
104+
98105 if tau_elg > 0. : ## perform dynamics of M-STDP-ET
99106 eligibility = eligibility * jnp .exp (- dt / tau_elg ) * elg_decay + dWeights / tau_elg
100107 else : ## otherwise, just do M-STDP
101108 eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing
102109 ## do a gradient ascent update/shift
103110 weights = weights + eligibility * modulator * eta ## do modulated update
111+ #'''
104112 dW_dt = TraceSTDPSynapse ._compute_update ( ## use Hebbian/STDP rule to obtain a non-modulated update
105113 dt , w_bound , preTrace_target , mu , Aplus , Aminus , preSpike , postSpike , preTrace , postTrace , weights
106114 )
107115 dWeights = dW_dt ## can think of this as eligibility at time t
108-
109- #w_eps = 0. # 0.01
116+ #'''
117+
118+ #w_eps = 0.01
110119 weights = jnp .clip (weights , w_eps , w_bound - w_eps ) # jnp.abs(w_bound))
111120
112121 return weights , dWeights , eligibility
0 commit comments