@@ -61,6 +61,8 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
6161
6262 elg_decay: eligibility decay constant (default: 1)
6363
64+ tau_w: amount of synaptic decay to augment each MSTDP/MSTDP-ET update with
65+
6466 weight_init: a kernel to drive initialization of this synaptic cable's values;
6567 typically a tuple with 1st element as a string calling the name of
6668 initialization to use
@@ -74,26 +76,28 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
7476
7577 # Define Functions
7678 def __init__ (
77- self , name , shape , A_plus , A_minus , eta = 1. , mu = 0. , pretrace_target = 0. , tau_elg = 0. , elg_decay = 1. ,
79+ self , name , shape , A_plus , A_minus , eta = 1. , mu = 0. , pretrace_target = 0. , tau_elg = 0. , elg_decay = 1. , tau_w = 0. ,
7880 weight_init = None , resist_scale = 1. , p_conn = 1. , w_bound = 1. , batch_size = 1 , ** kwargs
7981 ):
8082 super ().__init__ (
8183 name , shape , A_plus , A_minus , eta = eta , mu = mu , pretrace_target = pretrace_target , weight_init = weight_init ,
8284 resist_scale = resist_scale , p_conn = p_conn , w_bound = w_bound , batch_size = batch_size , ** kwargs
8385 )
8486 self .w_eps = 0.
87+ self .tau_w = tau_w
8588 ## MSTDP/MSTDP-ET meta-parameters
8689 self .tau_elg = tau_elg
8790 self .elg_decay = elg_decay
8891 ## MSTDP/MSTDP-ET compartments
8992 self .modulator = Compartment (jnp .zeros ((self .batch_size , 1 )))
9093 self .eligibility = Compartment (jnp .zeros (shape ))
94+ self .outmask = Compartment (jnp .zeros ((1 , shape [1 ])))
9195
9296 @transition (output_compartments = ["weights" , "dWeights" , "eligibility" ])
9397 @staticmethod
9498 def evolve (
95- dt , w_bound , w_eps , preTrace_target , mu , Aplus , Aminus , tau_elg , elg_decay , preSpike , postSpike , preTrace ,
96- postTrace , weights , dWeights , eta , modulator , eligibility
99+ dt , w_bound , w_eps , preTrace_target , mu , Aplus , Aminus , tau_elg , elg_decay , tau_w , preSpike , postSpike ,
100+ preTrace , postTrace , weights , dWeights , eta , modulator , eligibility , outmask
97101 ):
98102 # dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
99103 # dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
@@ -105,21 +109,25 @@ def evolve(
105109 else : ## otherwise, just do M-STDP
106110 eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing
107111 ## do a gradient ascent update/shift
108- weights = weights + eligibility * modulator * eta ## do modulated update
109- #'''
112+ decayTerm = 0.
113+ if tau_w > 0. :
114+ decayTerm = weights * (1. / tau_w )
115+ weights = weights + (eligibility * modulator * eta ) * outmask - decayTerm ## do modulated update
116+
110117 dW_dt = TraceSTDPSynapse ._compute_update ( ## use Hebbian/STDP rule to obtain a non-modulated update
111118 dt , w_bound , preTrace_target , mu , Aplus , Aminus , preSpike , postSpike , preTrace , postTrace , weights
112119 )
113120 dWeights = dW_dt ## can think of this as eligibility at time t
114- #'''
115-
121+
116122 #w_eps = 0.01
117123 weights = jnp .clip (weights , w_eps , w_bound - w_eps ) # jnp.abs(w_bound))
118124
119125 return weights , dWeights , eligibility
120126
121127 @transition (
122- output_compartments = ["inputs" , "outputs" , "preSpike" , "postSpike" , "preTrace" , "postTrace" , "dWeights" , "eligibility" ]
128+ output_compartments = [
129+ "inputs" , "outputs" , "preSpike" , "postSpike" , "preTrace" , "postTrace" , "dWeights" , "eligibility" , "outmask"
130+ ]
123131 )
124132 @staticmethod
125133 def reset (batch_size , shape ):
@@ -134,7 +142,8 @@ def reset(batch_size, shape):
134142 postTrace = postVals
135143 dWeights = synVals
136144 eligibility = synVals
137- return inputs , outputs , preSpike , postSpike , preTrace , postTrace , dWeights , eligibility
145+ outmask = postVals + 1.
146+ return inputs , outputs , preSpike , postSpike , preTrace , postTrace , dWeights , eligibility , outmask
138147
139148 @classmethod
140149 def help (cls ): ## component help function
0 commit comments